From b2e940a01b2ec020873c96cb2f24edb7cef8f465 Mon Sep 17 00:00:00 2001
From: tianyutong
Date: Mon, 23 Jun 2025 23:11:32 +0800
Subject: [PATCH] Implement dp-related logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Change-Id: I63962a6f12d9330fcd81e2fa934fffcdb4d9f3c9
---
 .../core/transformer/transformer_config.py |   5 +
 megatron/training/arguments.py             |  65 +++++
 pretrain_llama.py                          | 253 ++++++++++++++++++
 3 files changed, 323 insertions(+)
 create mode 100644 pretrain_llama.py

diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 329c907..0b05731 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -102,6 +102,11 @@ class TransformerConfig(ModelParallelConfig):
     """Whether cross entropy loss is calculated over the actual number of non-padded tokens in
     the global batch, versus the default behavior of assuming all tokens are non-padded."""
 
+    num_micro_batches_gard_factor: float = 0
+    """If this is not zero, the num-micro-batches-per-dp implementation will be used.
+    Defaults to 0.
+    """
+
     ####################
     # initialization
     ####################
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 962bbca..1debb5a 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -258,6 +258,66 @@ def validate_args(args, defaults={}):
         args.data_parallel_splits = data_parallel_split
         args.micro_batch_size_per_dp = micro_batch_sizes_split
         args.num_micro_batches = None
+<<<<<<< PATCH SET (ae126d Implement dp-related logic)
+        args.num_micro_batches_grad_factor = 0.
+
+        assert sum(data_parallel_split) == args.data_parallel_size, \
+            'the length of micro_batch_size_per_dp (equal to sum of n0, n1, ... ) should be equal to data-parallel-size.'
+
+        if args.num_micro_batches_per_dp is not None:
+            num_microbatches_splits = args.num_micro_batches_per_dp[1::2]
+            num_microbatches_data_parallel_splits = args.num_micro_batches_per_dp[::2]
+            args.num_micro_batches_per_dp = num_microbatches_splits
+
+            assert sum(num_microbatches_data_parallel_splits) == args.data_parallel_size, \
+                "the length of num_micro_batches_per_dp (equal to sum of 'n0, n1, ...') should be equal to data-parallel-size."
+            assert num_microbatches_data_parallel_splits == data_parallel_split, \
+                "num micro batches' data parallel splits should match micro batch sizes' data parallel splits one by one. " \
+                "for example: if micro batch size per dp is (1 A 1 B), then num micro batches per dp should be (1 X 1 Y)."
+
+            total_num_microbatches_split = [num_microbatches_splits[i] for i, j in enumerate(num_microbatches_data_parallel_splits) for _ in range(j)]
+
+            nmbs_dict = {}
+            for i in num_microbatches_splits:
+                nmbs_dict[i] = 0
+            assert len(nmbs_dict) <= 2, \
+                "the number of heterogeneous devices in parameter num_micro_batches_per_dp should be less than or equal to 2, " \
+                f"but got {len(nmbs_dict)} distinct values for num micro batches. " \
+                "more than 2 heterogeneous devices in num_micro_batches_per_dp are not supported yet."
+
+            sum_micro_batches = sum([micro_batch_sizes_split[i] * total_num_microbatches_split[i] for i in range(len(micro_batch_sizes_split))])
+
+            assert args.rampup_batch_size is None, 'num_micro_batches_per_dp is not currently supported for use with rampup_batch_size.'
+
+        offset = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
+        for i in range(1, args.data_parallel_size + 1):
+            if args.rank < i * offset:
+                args.micro_batch_size = total_micro_batch_sizes_split[i - 1]
+                if args.num_micro_batches_per_dp is not None:
+                    args.num_micro_batches = total_num_microbatches_split[i - 1]
+                    args.num_micro_batches_grad_factor = total_micro_batch_sizes_split[i - 1] * \
+                        total_num_microbatches_split[i - 1] / sum_micro_batches
+                break
+        if args.num_micro_batches_per_dp is None:
+            sum_of_micro_batch_sizes = sum(map(lambda x, y: x * y,
+                                               micro_batch_sizes_split,
+                                               data_parallel_split))
+            assert args.global_batch_size % sum_of_micro_batch_sizes == 0, \
+                'global batch size should be divisible by sum of micro batch size per dp! ' \
+                f'but got global batch size {args.global_batch_size} and sum of micro batch size per dp {sum_of_micro_batch_sizes}.'
+        else:
+            sum_of_micro_batch_sizes = sum(map(lambda x, y, z: x * y * z,
+                                               micro_batch_sizes_split,
+                                               data_parallel_split,
+                                               num_microbatches_splits))
+            assert args.global_batch_size == sum_of_micro_batch_sizes, \
+                'global batch size should be equal to sum of micro batch size per dp! ' \
+                f'but got global batch size {args.global_batch_size} and sum of micro batch size per dp {sum_of_micro_batch_sizes}.'
+        args.sum_micro_batch_sizes = sum_of_micro_batch_sizes
+    else:
+        args.num_micro_batches = None
+        args.data_parallel_splits = None
+=======
         args.min_num_micro_batches = None
         assert sum(data_parallel_split) == args.data_parallel_size, \
             'the length of micro_batch_size_per_dp (equal to sum of n0, n1, ... ) should be equal to data-parallel-size.'
@@ -413,6 +473,7 @@ def validate_args(args, defaults={}):
         args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
         args.recompute_method_per_stage = recompute_method_per_stage
 
+>>>>>>> BASE (6f9206 Implement pp-related logic)
 
     # Batch size.
     assert args.micro_batch_size is not None
@@ -810,7 +871,11 @@ def core_transformer_config_from_args(args, config_class=None):
     else:
         kw_args['num_query_groups'] = None
     if args.num_micro_batches_per_dp:
+<<<<<<< PATCH SET (ae126d Implement dp-related logic)
+        kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches_grad_factor
+=======
         kw_args['num_micro_batches_gard_factor'] = args.num_micro_batches / float(args.sum_num_micro_batches)
+>>>>>>> BASE (6f9206 Implement pp-related logic)
     else:
         kw_args['num_micro_batches_gard_factor'] = 0
 
diff --git a/pretrain_llama.py b/pretrain_llama.py
new file mode 100644
index 0000000..da129cd
--- /dev/null
+++ b/pretrain_llama.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+"""Pretrain GPT.""" + +import os +import torch +from functools import partial + +from typing import Union +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +import megatron.legacy.model +from megatron.core.models.gpt import GPTModel +from megatron.training import pretrain +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + + +stimer = StragglerDetector() + +def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. + + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + use_te = args.transformer_impl == "transformer_engine" + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + if args.use_legacy_models: + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: # using core models + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + else: + transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base + ) + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on 
+    batch = get_batch_on_this_tp_rank(data_iterator)
+
+    # slice batch along sequence dimension for context parallelism
+    batch = get_batch_on_this_cp_rank(batch)
+
+    return batch.values()
+
+
+def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+    """Loss function.
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+
+    Returns:
+        the loss scalar for this micro-batch
+        the number of non-padded tokens in this microbatch
+        a dict containing reporting metrics on the loss and number of tokens across
+            the data parallel ranks
+    """
+    args = get_args()
+
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    total_tokens = loss_mask.sum()
+    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
+    if args.context_parallel_size > 1:
+        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
+    # Check individual rank losses are not NaN prior to DP all-reduce.
+    if args.check_for_nan_in_loss_and_grad:
+        global_rank = torch.distributed.get_rank()
+        assert not loss[0].isnan(), (
+            f'Rank {global_rank}: found NaN in local forward loss calculation. '
+            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
+        )
+
+    # Reduce loss for logging.
+    reporting_loss = loss.clone().detach()
+    if args.num_micro_batches_per_dp is None:
+        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+    local_num_tokens = loss[1].clone().detach().to(torch.int)
+    return (
+        loss[0] * args.context_parallel_size,
+        local_num_tokens,
+        {'lm loss': (reporting_loss[0], reporting_loss[1])},
+    )
+
+
+def forward_step(data_iterator, model: GPTModel):
+    """Forward training step.
+
+    Args:
+        data_iterator : Input data iterator
+        model (GPTModel): The GPT Model
+    """
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator', log_level=2).start()
+    global stimer
+    with stimer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+            data_iterator)
+    timers('batch-generator').stop()
+
+    with stimer:
+        output_tensor = model(tokens, position_ids, attention_mask,
+                              labels=labels)
+
+    return output_tensor, partial(loss_func, loss_mask)
+
+
+def is_dataset_built_on_rank():
+    return (
+        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
+    ) and mpu.get_tensor_model_parallel_rank() == 0
+
+
+def core_gpt_dataset_config_from_args(args):
+    tokenizer = get_tokenizer()
+
+    return GPTDatasetConfig(
+        random_seed=args.seed,
+        sequence_length=args.seq_length,
+        blend=get_blend_from_list(args.data_path),
+        blend_per_split=[
+            get_blend_from_list(args.train_data_path),
+            get_blend_from_list(args.valid_data_path),
+            get_blend_from_list(args.test_data_path)
+        ],
+        split=args.split,
+        num_dataset_builder_threads=args.num_dataset_builder_threads,
+        path_to_cache=args.data_cache_path,
+        mmap_bin_files=args.mmap_bin_files,
+        tokenizer=tokenizer,
+        reset_position_ids=args.reset_position_ids,
+        reset_attention_mask=args.reset_attention_mask,
+        eod_mask_loss=args.eod_mask_loss,
+        create_attention_mask=args.create_attention_mask_in_dataloader,
+        s3_cache_path=args.s3_cache_path
+    )
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build the train, test, and validation datasets.
+
+    Args:
+        train_val_test_num_samples : A list containing the number of samples in train, test, and validation.
+    """
+    args = get_args()
+
+    config = core_gpt_dataset_config_from_args(args)
+
+    if args.mock_data:
+        dataset_type = MockGPTDataset
+    else:
+        dataset_type = GPTDataset
+
+    print_rank_0("> building train, validation, and test datasets for GPT ...")
+
+    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
+        dataset_type,
+        train_val_test_num_samples,
+        is_dataset_built_on_rank,
+        config
+    ).build()
+
+    print_rank_0("> finished creating GPT datasets ...")
+
+    return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+
+    # Temporary for transition to core datasets
+    train_valid_test_datasets_provider.is_distributed = True
+
+    pretrain(
+        train_valid_test_datasets_provider,
+        model_provider,
+        ModelType.encoder_or_decoder,
+        forward_step,
+        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+    )
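The sketch below is illustrative and not part of the patch. It shows how the "n0 v0 n1 v1 ..." split format that validate_args slices with [::2] and [1::2] expands into one value per data-parallel rank, and how a per-rank gradient scaling factor of the form micro_batch_size * num_micro_batches / total (the idea behind num_micro_batches_grad_factor) falls out of that expansion. The helper name expand_splits and the example values are hypothetical, not taken from the patch.

# Illustrative sketch of the heterogeneous-DP split expansion (hypothetical helper).
def expand_splits(flat):
    """Expand [n0, v0, n1, v1, ...] into one value per data-parallel rank."""
    counts, values = flat[::2], flat[1::2]
    per_rank = []
    for count, value in zip(counts, values):
        per_rank.extend([value] * count)
    return per_rank

# Example: 4 DP ranks split into two device groups.
micro_batch_size_per_dp = [2, 4, 2, 2]    # 2 ranks with micro batch size 4, 2 ranks with 2
num_micro_batches_per_dp = [2, 8, 2, 16]  # the same groups run 8 and 16 micro batches

mbs_per_rank = expand_splits(micro_batch_size_per_dp)   # [4, 4, 2, 2]
nmb_per_rank = expand_splits(num_micro_batches_per_dp)  # [8, 8, 16, 16]

# Total samples per step; validate_args asserts the global batch size equals this sum
# when num_micro_batches_per_dp is given.
global_batch = sum(m * n for m, n in zip(mbs_per_rank, nmb_per_rank))  # 128

# Per-rank share of the global batch, the quantity the grad factor is meant to capture.
grad_factor = [m * n / global_batch for m, n in zip(mbs_per_rank, nmb_per_rank)]
print(grad_factor)  # [0.25, 0.25, 0.25, 0.25]

In this example the per-rank products are equal (4 * 8 = 2 * 16), so every rank contributes the same share of the global batch; unequal products would yield proportionally different factors.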