Implement DP-related logic

Change-Id: I63962a6f12d9330fcd81e2fa934fffcdb4d9f3c9
Author: tianyutong
Date: 2025-06-23 23:11:32 +08:00
parent 6f920623fc
commit b2e940a01b
3 changed files with 323 additions and 0 deletions


@@ -102,6 +102,11 @@ class TransformerConfig(ModelParallelConfig):
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
num_micro_batches_gard_factor: float = 0
"""If this is not zero, the num micro batches per dp implementation would be used.
Defaults to 0.
"""
####################
# initialization
####################
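
Note: this hunk only adds the config field; the code that consumes it is not part of this diff. As a minimal sketch under that assumption (hypothetical helper, assumed semantics): if the factor is each rank's share of the global batch, it could replace the uniform 1/data_parallel_size average during the data-parallel gradient reduction, for example:

import torch.distributed as dist

def reduce_grads_with_factor(params, grad_factor: float, group=None):
    # Hypothetical illustration, not part of this patch: weight this rank's
    # gradients by its share of the global batch, then SUM across DP ranks.
    for p in params:
        if p.grad is None:
            continue
        p.grad.mul_(grad_factor)
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)

With a homogeneous setup the factor degenerates to 1/data_parallel_size and this reduces to the usual gradient average.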


@@ -258,6 +258,66 @@ def validate_args(args, defaults={}):
args.data_parallel_splits = data_parallel_split
args.micro_batch_size_per_dp = micro_batch_sizes_split
args.num_micro_batches = None
<<<<<<< PATCH SET (ae126d Implement DP-related logic)
args.num_micro_batches_grad_factor = 0.
assert sum(data_parallel_split) == args.data_parallel_size, \
'the total rank count implied by micro_batch_size_per_dp (the sum of n0, n1, ...) should equal data-parallel-size.'
if args.num_micro_batches_per_dp is not None:
num_microbatches_splits = args.num_micro_batches_per_dp[1::2]
num_microbatches_data_parallel_splits = args.num_micro_batches_per_dp[::2]
args.num_micro_batches_per_dp = num_microbatches_splits
assert sum(num_microbatches_data_parallel_splits) == args.data_parallel_size , \
"the total rank count implied by num_micro_batches_per_dp (the sum of n0, n1, ...) should equal data-parallel-size."
assert num_microbatches_data_parallel_splits == data_parallel_split, \
"the rank counts in num_micro_batches_per_dp must match those in micro_batch_size_per_dp element by element; " \
"for example, if micro_batch_size_per_dp is (1 A 1 B), then num_micro_batches_per_dp should be (1 X 1 Y)."
total_num_microbatches_split = [num_microbatches_splits[i] for i, j in enumerate(num_microbatches_data_parallel_splits) for _ in range(j)]
distinct_nmbs = set(num_microbatches_splits)
assert len(distinct_nmbs) <= 2, \
"num_micro_batches_per_dp supports at most 2 heterogeneous device groups, " \
f"but got {len(distinct_nmbs)} distinct num-micro-batches values; that is not supported yet."
# Total samples per global batch, summed over all data-parallel ranks.
sum_micro_batches = sum([total_micro_batch_sizes_split[i] * total_num_microbatches_split[i] for i in range(len(total_micro_batch_sizes_split))])
assert args.rampup_batch_size is None, 'num_micro_batches_per_dp is not currently supported for use with rampup_batch_size.'
offset = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
for i in range(1, args.data_parallel_size + 1):
if args.rank < i * offset:
args.micro_batch_size = total_micro_batch_sizes_split[i - 1]
if args.num_micro_batches_per_dp is not None:
args.num_micro_batches = total_num_microbatches_split[i - 1]
args.num_micro_batches_grad_factor = total_micro_batch_sizes_split[i - 1] * \
total_num_microbatches_split[i - 1] / sum_micro_batches
break
if args.num_micro_batches_per_dp is None:
sum_of_micro_batch_sizes = sum(map(lambda x, y : x * y,
micro_batch_sizes_split,
data_parallel_split))
assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
'global batch size should be divisible by the sum of micro batch sizes per dp, ' \
f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
else:
sum_of_micro_batch_sizes = sum(map(lambda x, y, z : x * y * z,
micro_batch_sizes_split,
data_parallel_split,
num_microbatches_splits))
assert args.global_batch_size == sum_of_micro_batch_sizes, \
'global batch size should equal the sum of micro batch sizes per dp, ' \
f'but got global batch size {args.global_batch_size} and sum of micro batch sizes per dp {sum_of_micro_batch_sizes}.'
args.sum_micro_batch_sizes = sum_of_micro_batch_sizes
else:
args.num_micro_batches = None
args.data_parallel_splits = None
=======
args.min_num_micro_batches = None
assert sum(data_parallel_split) == args.data_parallel_size, \
'the total rank count implied by micro_batch_size_per_dp (the sum of n0, n1, ...) should equal data-parallel-size.'
@@ -413,6 +473,7 @@ def validate_args(args, defaults={}):
args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
args.recompute_method_per_stage = recompute_method_per_stage
>>>>>>> BASE (6f9206 Implement PP-related logic)
# Batch size.
assert args.micro_batch_size is not None
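
For reference, both flags validated above share the same flat "n0 v0 n1 v1 ..." layout: data-parallel rank counts at even indices, per-group values at odd indices, exactly the [::2] / [1::2] slicing used in the patch-set branch. A standalone sketch of that split and the per-rank expansion (hypothetical helper name, not part of this diff):

def split_per_dp_pairs(flat):
    counts = flat[::2]    # n0, n1, ...: how many DP ranks in each group
    values = flat[1::2]   # v0, v1, ...: micro batch size or num micro batches per group
    per_rank = [v for v, n in zip(values, counts) for _ in range(n)]
    return counts, values, per_rank

# e.g. num_micro_batches_per_dp = [2, 6, 2, 3] with data_parallel_size = 4:
counts, values, per_rank = split_per_dp_pairs([2, 6, 2, 3])
assert sum(counts) == 4          # must match data-parallel-size
assert per_rank == [6, 6, 3, 3]  # one value per DP rank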
@@ -810,7 +871,11 @@ def core_transformer_config_from_args(args, config_class=None):
else:
kw_args['num_query_groups'] = None
if args.num_micro_batches_per_dp:
<<<<<<< PATCH SET (ae126d Implement DP-related logic)
kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches_grad_factor
=======
kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches / float(args.sum_num_micro_batches)
>>>>>>> BASE (6f9206 Implement PP-related logic)
else:
kw_args['num_micro_batches_gard_factor'] = 0
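
A worked example of the factor passed above. Per the patch-set branch of validate_args, each rank's factor is mbs_r * nmb_r divided by the total over all DP ranks, which the earlier assertion ties to the global batch size (numbers below are illustrative only):

# 4 DP ranks in two groups: two ranks at micro batch size 4 with 6 micro batches,
# two ranks at micro batch size 8 with 3 micro batches.
mbs = [4, 4, 8, 8]
nmb = [6, 6, 3, 3]
total = sum(m * n for m, n in zip(mbs, nmb))          # 96 == global batch size
factors = [m * n / total for m, n in zip(mbs, nmb)]   # [0.25, 0.25, 0.25, 0.25]
assert abs(sum(factors) - 1.0) < 1e-9                 # weights sum to 1 across DP ranks

So a rank that processes a quarter of the global batch contributes a quarter of the weight, regardless of how its work is split between micro batch size and number of micro batches.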

pretrain_llama.py (new file, 253 lines)

@@ -0,0 +1,253 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import os
import torch
from functools import partial
from typing import Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
stimer = StragglerDetector()
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.
If args.use_legacy_models is True, the legacy GPT model is returned; otherwise the mcore GPT model is returned.
Args:
pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
Returns:
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
use_te = args.transformer_impl == "transformer_engine"
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
if args.use_legacy_models:
model = megatron.legacy.model.GPTModel(
config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
)
else: # using core models
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
else:
transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base
)
return model
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
return batch.values()
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.
Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses
Returns:
the loss scalar for this micro-batch
the number of non-padded tokens in this microbatch
a dict containing reporting metrics on the loss and number of tokens across
the data parallel ranks
"""
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss[0].isnan(), (
f'Rank {global_rank}: found NaN in local forward loss calculation. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)
# Reduce loss for logging.
reporting_loss = loss.clone().detach()
if args.num_micro_batches_per_dp is None:
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
loss[0] * args.context_parallel_size,
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
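
A toy, single-process illustration of the masked-loss packing used in loss_func above; the division into a per-token average for logging happens outside this file:

import torch

losses = torch.tensor([[0.5, 1.0, 1.5, 2.0]])     # per-token losses for one micro batch
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # last two positions are padding

loss_mask = loss_mask.view(-1)
total_tokens = loss_mask.sum()                                        # 2 real tokens
packed = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1),
                    total_tokens.view(1)])                            # tensor([1.5, 2.])
per_token_loss = packed[0] / packed[1]                                # 0.75

Note that when num_micro_batches_per_dp is set, the data-parallel all-reduce of reporting_loss above is skipped, so the logged 'lm loss' is this rank's local value.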
def forward_step(data_iterator, model: GPTModel):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def is_dataset_built_on_rank():
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
tokenizer = get_tokenizer()
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
blend=get_blend_from_list(args.data_path),
blend_per_split=[
get_blend_from_list(args.train_data_path),
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path)
],
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
mmap_bin_files=args.mmap_bin_files,
tokenizer=tokenizer,
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
s3_cache_path=args.s3_cache_path
)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build the train test and validation datasets.
Args:
train_val_test_num_samples : A list containing the number of samples in train test and validation.
"""
args = get_args()
config = core_gpt_dataset_config_from_args(args)
if args.mock_data:
dataset_type = MockGPTDataset
else:
dataset_type = GPTDataset
print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build()
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
)