Implement DP-related logic
Change-Id: I63962a6f12d9330fcd81e2fa934fffcdb4d9f3c9
@@ -102,6 +102,11 @@ class TransformerConfig(ModelParallelConfig):
    """Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
    global batch, versus the default behavior of assuming all tokens are non-padded."""

    num_micro_batches_grad_factor: float = 0
    """If this is not zero, the per-data-parallel-rank num-micro-batches implementation is used.
    Defaults to 0.
    """

    ####################
    # initialization
    ####################

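For intuition, num_micro_batches_grad_factor is the share of the global batch that one data-parallel rank processes per step: its micro batch size times its number of micro batches, divided by that product summed over the ranks, which is how the validate_args hunk below fills it in. A minimal sketch of that arithmetic with made-up per-rank values (the names below are illustrative, not identifiers from the patch):

# Hypothetical heterogeneous setup: DP size 4, two device types.
micro_batch_sizes = [4, 4, 2, 2]    # micro batch size per DP rank
num_micro_batches = [8, 8, 16, 16]  # micro batches per step per DP rank

total = sum(m * n for m, n in zip(micro_batch_sizes, num_micro_batches))

# Each rank's num_micro_batches_grad_factor: its fraction of the global batch.
factors = [m * n / total for m, n in zip(micro_batch_sizes, num_micro_batches)]

assert abs(sum(factors) - 1.0) < 1e-9
print(factors)  # [0.25, 0.25, 0.25, 0.25] here, since 4 * 8 == 2 * 16
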
@@ -258,6 +258,66 @@ def validate_args(args, defaults={}):
        args.data_parallel_splits = data_parallel_split
        args.micro_batch_size_per_dp = micro_batch_sizes_split
        args.num_micro_batches = None
<<<<<<< PATCH SET (ae126d Implement DP-related logic)
        args.num_micro_batches_grad_factor = 0.

        assert sum(data_parallel_split) == args.data_parallel_size, \
            'the sum of the rank counts (n0 + n1 + ...) in micro_batch_size_per_dp should be equal to data-parallel-size.'

        if args.num_micro_batches_per_dp is not None:
            num_microbatches_splits = args.num_micro_batches_per_dp[1::2]
            num_microbatches_data_parallel_splits = args.num_micro_batches_per_dp[::2]
            args.num_micro_batches_per_dp = num_microbatches_splits

            assert sum(num_microbatches_data_parallel_splits) == args.data_parallel_size, \
                'the sum of the rank counts (n0 + n1 + ...) in num_micro_batches_per_dp should be equal to data-parallel-size.'
            assert num_microbatches_data_parallel_splits == data_parallel_split, \
                "the data-parallel splits of num_micro_batches_per_dp should match the data-parallel splits of micro_batch_size_per_dp one by one. " \
                "For example: if micro_batch_size_per_dp is (1 A 1 B), then num_micro_batches_per_dp should be (1 X 1 Y)."

            total_num_microbatches_split = [num_microbatches_splits[i] for i, j in enumerate(num_microbatches_data_parallel_splits) for _ in range(j)]

            nmbs_dict = {}
            for i in num_microbatches_splits:
                nmbs_dict[i] = 0
            assert len(nmbs_dict) <= 2, \
                'the number of heterogeneous device types in num_micro_batches_per_dp should be less than or equal to 2, ' \
                f'but got {len(nmbs_dict)} distinct num-micro-batches values; ' \
                'more than 2 heterogeneous device types in num_micro_batches_per_dp is not supported yet.'

            sum_micro_batches = sum([micro_batch_sizes_split[i] * total_num_microbatches_split[i] for i in range(len(micro_batch_sizes_split))])

            assert args.rampup_batch_size is None, 'num_micro_batches_per_dp is not currently supported for use with rampup_batch_size.'

        offset = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
        for i in range(1, args.data_parallel_size + 1):
            if args.rank < i * offset:
                args.micro_batch_size = total_micro_batch_sizes_split[i - 1]
                if args.num_micro_batches_per_dp is not None:
                    args.num_micro_batches = total_num_microbatches_split[i - 1]
                    args.num_micro_batches_grad_factor = total_micro_batch_sizes_split[i - 1] * \
                        total_num_microbatches_split[i - 1] / sum_micro_batches
                break
        if args.num_micro_batches_per_dp is None:
            sum_of_micro_batch_sizes = sum(map(lambda x, y: x * y,
                                               micro_batch_sizes_split,
                                               data_parallel_split))
            assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
                'global batch size should be divisible by the sum of micro batch sizes per dp, ' \
                f'but global batch size is {args.global_batch_size} and the sum of micro batch sizes per dp is {sum_of_micro_batch_sizes}.'
        else:
            sum_of_micro_batch_sizes = sum(map(lambda x, y, z: x * y * z,
                                               micro_batch_sizes_split,
                                               data_parallel_split,
                                               num_microbatches_splits))
            assert args.global_batch_size == sum_of_micro_batch_sizes, \
                'global batch size should be equal to the sum of micro batch sizes per dp, ' \
                f'but global batch size is {args.global_batch_size} and the sum of micro batch sizes per dp is {sum_of_micro_batch_sizes}.'
            args.sum_micro_batch_sizes = sum_of_micro_batch_sizes
    else:
        args.num_micro_batches = None
        args.data_parallel_splits = None
=======
        args.min_num_micro_batches = None
        assert sum(data_parallel_split) == args.data_parallel_size, \
            'the sum of the rank counts (n0 + n1 + ...) in micro_batch_size_per_dp should be equal to data-parallel-size.'
@@ -413,6 +473,7 @@ def validate_args(args, defaults={}):

    args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
    args.recompute_method_per_stage = recompute_method_per_stage
>>>>>>> BASE (6f9206 Implement PP-related logic)

    # Batch size.
    assert args.micro_batch_size is not None

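To make the list handling in the validate_args hunk above concrete: micro_batch_size_per_dp and num_micro_batches_per_dp are both read as interleaved "count value count value ..." sequences, with rank counts at the even indices and per-rank values at the odd indices, and the counts must sum to data-parallel-size. A small, self-contained sketch of that slicing and of the global-batch-size check on the patch-set side (the input values are hypothetical):

# Hypothetical interleaved lists: 2 ranks of one device type, 2 of another.
micro_batch_size_per_dp = [2, 4, 2, 2]    # 2 ranks with mbs=4, 2 ranks with mbs=2
num_micro_batches_per_dp = [2, 8, 2, 16]  # 2 ranks with 8 micro batches, 2 with 16

dp_counts = num_micro_batches_per_dp[::2]    # [2, 2]; must sum to data_parallel_size
nmb_values = num_micro_batches_per_dp[1::2]  # [8, 16]
mbs_values = micro_batch_size_per_dp[1::2]   # [4, 2]

# Expand to one entry per DP rank, like total_num_microbatches_split above.
per_rank_nmb = [v for v, c in zip(nmb_values, dp_counts) for _ in range(c)]

# The patch-set branch requires global_batch_size to equal this sum exactly.
expected_global_batch = sum(m * n * c for m, n, c in zip(mbs_values, nmb_values, dp_counts))

print(per_rank_nmb)           # [8, 8, 16, 16]
print(expected_global_batch)  # 4*8*2 + 2*16*2 = 128
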
@@ -810,7 +871,11 @@ def core_transformer_config_from_args(args, config_class=None):
    else:
        kw_args['num_query_groups'] = None
    if args.num_micro_batches_per_dp:
<<<<<<< PATCH SET (ae126d Implement DP-related logic)
        kw_args['num_micro_batches_grad_factor'] = args.num_micro_batches_grad_factor
=======
        kw_args['num_micro_batches_grad_factor'] = args.num_micro_batches / float(args.sum_num_micro_batches)
>>>>>>> BASE (6f9206 Implement PP-related logic)
    else:
        kw_args['num_micro_batches_grad_factor'] = 0

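The rank-to-micro-batch-size selection in validate_args above walks the data-parallel slots in order and takes the first slot whose rank range covers the current rank, which assumes the tensor-parallel x pipeline-parallel ranks of each data-parallel replica are laid out contiguously. A self-contained sketch of that lookup (the sizes and split list are invented for illustration):

# Hypothetical parallel sizes; mirrors the rank -> DP-slot lookup in the hunk above.
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
data_parallel_size = 4
total_micro_batch_sizes_split = [4, 4, 2, 2]  # one micro batch size per DP slot

def micro_batch_size_for(rank):
    offset = tensor_model_parallel_size * pipeline_model_parallel_size
    for i in range(1, data_parallel_size + 1):
        if rank < i * offset:
            return total_micro_batch_sizes_split[i - 1]
    raise ValueError(f"rank {rank} out of range")

print([micro_batch_size_for(r) for r in range(16)])
# ranks 0-3 get 4, ranks 4-7 get 4, ranks 8-11 get 2, ranks 12-15 get 2
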
pretrain_llama.py (new normal file, 253 lines)
@@ -0,0 +1,253 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""

import os
import torch
from functools import partial

from typing import Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)


stimer = StragglerDetector()


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If use_legacy_models is set to True, it will return the legacy GPT model; otherwise the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.


    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if use_te:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base
        )

    return model


def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this micro-batch
        a dict containing reporting metrics on the loss and number of tokens across
        the data parallel ranks
    """
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

    if args.context_parallel_size > 1:
        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        assert not loss[0].isnan(), (
            f'Rank {global_rank}: found NaN in local forward loss calculation. '
            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
        )

    # Reduce loss for logging. When heterogeneous per-DP micro-batching is enabled
    # (num_micro_batches_per_dp is set), the data-parallel all-reduce is skipped here.
    reporting_loss = loss.clone().detach()
    if args.num_micro_batches_per_dp is None:
        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

    local_num_tokens = loss[1].clone().detach().to(torch.int)
    return (
        loss[0] * args.context_parallel_size,
        local_num_tokens,
        {'lm loss': (reporting_loss[0], reporting_loss[1])},
    )


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
            data_iterator)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask,
                              labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def is_dataset_built_on_rank():
    return (
        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    ) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    tokenizer = get_tokenizer()

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
        s3_cache_path=args.s3_cache_path
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, test, and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train, test, and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
    )
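
As a reference for the loss_func contract above, here is a standalone restatement of its masking arithmetic for a single process (no context- or data-parallel reduction, NaN check omitted); masked_lm_loss is a hypothetical helper written only for this sketch, not part of the file:

import torch

def masked_lm_loss(per_token_losses: torch.Tensor, loss_mask: torch.Tensor):
    # Sum the per-token losses where the mask is 1 and count the kept tokens,
    # mirroring the two values loss_func packs into its reporting tensor.
    losses = per_token_losses.float().view(-1)
    mask = loss_mask.view(-1).float()
    summed = torch.sum(losses * mask)
    total_tokens = mask.sum()
    return summed, total_tokens.to(torch.int)

per_token_losses = torch.tensor([[2.0, 1.0, 4.0]])
loss_mask = torch.tensor([[1, 1, 0]])
loss, n_tokens = masked_lm_loss(per_token_losses, loss_mask)
print(loss.item(), n_tokens.item())  # 3.0 2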