diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py
index d40912c..ed07dcc 100644
--- a/megatron/core/model_parallel_config.py
+++ b/megatron/core/model_parallel_config.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 from dataclasses import dataclass
-from typing import Callable, ContextManager, Optional
+from typing import Callable, ContextManager, List, Optional
 
 import torch
 
@@ -36,6 +36,33 @@ class ModelParallelConfig:
     (https://arxiv.org/abs/2205.05198) for more details.
     """

+    hetero_pipeline_stages: Optional[List[List[int]]] = None
+    """Incompatible with --num-layers-per-virtual-pipeline-stage.
+    hetero-pipeline-stages must be in the form:
+    n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ...
+    The order should be consistent with --hetero-device-types.
+    """
+
+    recompute_granularity_per_stage: Optional[List[int]] = None
+    """Used with recompute-granularity=full to set the recompute granularity
+    of each stage. This argument must be in the form: n0, flag0, n1, flag1, ...
+    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
+    Granularity flag: 0 turns full recompute off for the stage, 1 turns it on.
+    """
+
+    recompute_method_per_stage: Optional[List[int]] = None
+    """Used with recompute-granularity=full to set the recompute method
+    of each stage. This argument must be in the form: n0, method0, n1, method1, ...
+    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
+    Method: 0 means uniform, 1 means block.
+    """
+
+    recompute_num_layers_per_stage: Optional[List[int]] = None
+    """Used with recompute-granularity=full to set the number of recomputed layers
+    of each stage. This argument must be in the form: n0, layers0, n1, layers1, ...
+    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
+    """
+
     context_parallel_size: int = 1
     """Splits network input along sequence dimension across GPU ranks."""
 
diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py
index b832221..032b3d0 100644
--- a/megatron/core/transformer/transformer_block.py
+++ b/megatron/core/transformer/transformer_block.py
@@ -52,6 +52,8 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
     )
 
     if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+        assert config.hetero_pipeline_stages is None, \
+            "Heterogeneous pipeline parallelism is not supported with virtual pipeline model parallelism."
         # Interleaved pipeline parallelism:
         # Number of layers in each model chunk is the number of layers in the stage,
         # divided by the number of model chunks in a stage.
@@ -73,8 +75,13 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
     else:
         # Non-interleaved pipeline parallelism:
         # Each stage gets a contiguous set of layers.
-
-        num_layers_to_build = num_layers_per_pipeline_rank
+        if config.hetero_pipeline_stages is None:
+            num_layers_to_build = num_layers_per_pipeline_rank
+        else:
+            pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
+            pipeline_stages = [item for sublist in config.hetero_pipeline_stages for item in sublist]
+            num_layers_to_build = pipeline_stages[pipeline_rank]
+            torch.distributed.barrier()
 
     return num_layers_to_build
 
@@ -139,6 +146,31 @@ class TransformerBlock(MegatronModule):
         # required for pipeline parallel schedules
         self.input_tensor = None
 
+        # Per-stage recompute overrides. Stages are indexed as
+        # vp_rank * pipeline_model_parallel_size + pipeline_rank when virtual
+        # pipeline parallelism is used, and as pipeline_rank otherwise.
+        if self.config.virtual_pipeline_model_parallel_size is not None:
+            stage_index = (
+                parallel_state.get_virtual_pipeline_model_parallel_rank() * self.config.pipeline_model_parallel_size
+                + parallel_state.get_pipeline_model_parallel_rank()
+            )
+        else:
+            stage_index = parallel_state.get_pipeline_model_parallel_rank()
+
+        if self.config.recompute_method_per_stage is not None:
+            if self.config.recompute_method_per_stage[stage_index] == 0:
+                self.config.recompute_method = 'uniform'
+            elif self.config.recompute_method_per_stage[stage_index] == 1:
+                self.config.recompute_method = 'block'
+
+        if self.config.recompute_num_layers_per_stage is not None:
+            self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[stage_index]
+
+        if self.config.recompute_granularity_per_stage is not None and \
+                self.config.recompute_granularity_per_stage[parallel_state.get_pipeline_model_parallel_rank()] == 0:
+            self.config.recompute_granularity = None
+            self.config.recompute_method = None
+
         self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
 
         if get_cpu_offload_context is not None:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index f2c5f7c..329c907 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
     # activation recomputation
     ####################
     recompute_granularity: str = None
-    recompute_granularity: str = None
     """Determines which type of activation recompute to use. Megatron-core supports 'selective'
     activation checkpointing where only the memory intensive part of attention is checkpointed.
    These memory intensive activations are also less compute intensive which makes activation
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 631179e..5885f1b 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -139,6 +139,8 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         )
 
         if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+            assert self.config.hetero_pipeline_stages is None, \
+                "Heterogeneous pipeline parallelism is not supported with virtual pipeline model parallelism."
             vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
             vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
 
@@ -150,7 +152,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
         else:
             # Each stage gets a contiguous set of layers.
             if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                offset = pipeline_rank * num_layers_per_pipeline_rank
+                if self.config.hetero_pipeline_stages is None:
+                    offset = pipeline_rank * num_layers_per_pipeline_rank
+                else:
+                    pipeline_stages = [item for sublist in self.config.hetero_pipeline_stages for item in sublist]
+                    offset = sum(pipeline_stages[:pipeline_rank])
             else:
                 offset = 0
 
diff --git a/megatron/legacy/model/transformer.py b/megatron/legacy/model/transformer.py
index db46a72..9737a0e 100644
--- a/megatron/legacy/model/transformer.py
+++ b/megatron/legacy/model/transformer.py
@@ -1360,6 +1360,19 @@ def _get_num_layers(args, model_type, is_decoder=False):
             num_layers = args.decoder_num_layers
     return num_layers
 
+def _get_layer_info(args):
+    assert args.hetero_pipeline_stages is not None, "_get_layer_info requires --hetero-pipeline-stages to be set."
+    pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+    pipeline_stages = [item for sublist in args.hetero_pipeline_stages for item in sublist]
+    offset = sum(pipeline_stages[:pipeline_rank])
+    num_layers = pipeline_stages[pipeline_rank]
+    torch.distributed.barrier()
+    for i in range(torch.distributed.get_world_size()):
+        if i == torch.distributed.get_rank():
+            print("pipeline_rank:", pipeline_rank, "offset:", offset, "num_layers:", num_layers, flush=True)
+        torch.distributed.barrier()
+    return offset, num_layers
+
 def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                     layer_number):
 
@@ -1403,12 +1416,38 @@ class ParallelTransformer(MegatronModule):
         self.retro_add_retriever = args.retro_add_retriever
 
         # Store activation checkpoiting flag.
-        self.recompute_granularity = config.recompute_granularity
-        self.recompute_method = config.recompute_method
-        self.recompute_num_layers = config.recompute_num_layers
+        # Per-stage recompute overrides. Stages are indexed as
+        # vp_rank * pipeline_model_parallel_size + pipeline_rank when virtual
+        # pipeline parallelism is used, and as pipeline_rank otherwise.
+        if args.virtual_pipeline_model_parallel_size is not None:
+            stage_index = (
+                mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size
+                + mpu.get_pipeline_model_parallel_rank()
+            )
+        else:
+            stage_index = mpu.get_pipeline_model_parallel_rank()
+
+        if args.recompute_method_per_stage is not None:
+            if args.recompute_method_per_stage[stage_index] == 0:
+                self.recompute_method = 'uniform'
+            elif args.recompute_method_per_stage[stage_index] == 1:
+                self.recompute_method = 'block'
+        else:
+            self.recompute_method = config.recompute_method
+
+        if args.recompute_num_layers_per_stage is not None:
+            self.recompute_num_layers = args.recompute_num_layers_per_stage[stage_index]
+        else:
+            self.recompute_num_layers = config.recompute_num_layers
         self.distribute_saved_activations = \
             config.distribute_saved_activations and not config.sequence_parallel
+        if args.recompute_granularity_per_stage is not None and \
+                args.recompute_granularity_per_stage[mpu.get_pipeline_model_parallel_rank()] == 0:
+            self.recompute_granularity = None
+            self.recompute_method = None
+        else:
+            self.recompute_granularity = config.recompute_granularity
+
         self.sequence_parallel = config.sequence_parallel
 
         # Transformer Engine Init.
@@ -1460,8 +1499,11 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_core_attention = config.recompute_granularity == 'selective'
 
         # Number of layers.
-        self.num_layers = _get_num_layers(args, model_type,
-                                          layer_type==LayerType.decoder)
+        if args.hetero_pipeline_stages is None:
+            self.num_layers = _get_num_layers(args, model_type,
+                                              layer_type == LayerType.decoder)
+        else:
+            offset, self.num_layers = _get_layer_info(args)
 
         self.drop_path_rates = [
             rate.item() for rate in
@@ -1541,6 +1583,8 @@ class ParallelTransformer(MegatronModule):
                 'num_layers_per_stage must be divisible by ' \
                 'virtual_pipeline_model_parallel_size'
             assert args.model_type != ModelType.encoder_and_decoder
+            assert args.hetero_pipeline_stages is None, \
+                "Heterogeneous pipeline parallelism is not supported with virtual pipeline model parallelism."
             # Number of layers in each model chunk is the number of layers in the stage,
             # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
@@ -1559,6 +1603,8 @@ class ParallelTransformer(MegatronModule):
             # Each stage gets a contiguous set of layers.
             if args.model_type == ModelType.encoder_and_decoder and \
                     mpu.get_pipeline_model_parallel_world_size() > 1:
+                assert args.hetero_pipeline_stages is None, \
+                    "Heterogeneous pipeline parallelism is not supported for encoder-decoder models."
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
@@ -1566,7 +1612,10 @@ class ParallelTransformer(MegatronModule):
                 else:
                     num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                     offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
             else:
-                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+                if args.hetero_pipeline_stages is None:
+                    offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+                else:
+                    offset, self.num_layers = _get_layer_info(args)
         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,
@@ -1579,6 +1628,7 @@ class ParallelTransformer(MegatronModule):
             # disconnect the input tensor from the output tensor.
             self.num_layers = 1
             self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+            self.recompute_granularity = None
         else:
             self.layers = torch.nn.ModuleList(
                 [build_layer(i + 1 + offset) for i in range(self.num_layers)])
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index fae7d3c..962bbca 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -318,6 +318,102 @@ def validate_args(args, defaults={}):
         args.num_micro_batches = None
         args.data_parallel_splits = None
 
+    if args.hetero_pipeline_stages is not None:
+        assert args.micro_batch_size_per_dp is None, \
+            "micro_batch_size_per_dp should be None when hetero_pipeline_stages is used"
+        args.hetero_data_parallel_splits = None
+
+        # Parse the flat list "n0 layers_0_0 ... n1 layers_1_0 ..." into one group of per-stage layer counts per device type.
+        stages = []
+        hetero_pipeline_stages = []
+        hetero_pipeline_stage_splits = []
+        counter = 0
+        num_layers = 0
+        for item in args.hetero_pipeline_stages:
+            if counter == 0:
+                hetero_pipeline_stage_splits.append(item)
+                counter = item
+            else:
+                stages.append(item)
+                num_layers += item
+                counter -= 1
+                if counter == 0:
+                    hetero_pipeline_stages.append(stages)
+                    stages = []
+        args.hetero_pipeline_stages = hetero_pipeline_stages
+        args.hetero_pipeline_stage_splits = hetero_pipeline_stage_splits
+
+        for split, stages in zip(args.hetero_pipeline_stage_splits, args.hetero_pipeline_stages):
+            assert split == len(stages), \
+                f"hetero_pipeline_stage_split {split} should be equal to the length of hetero_pipeline_stage {stages}"
+        assert num_layers == args.num_layers, \
+            f"sum of hetero_pipeline_stages {num_layers} should be equal to num_layers {args.num_layers}"
+        assert args.pipeline_model_parallel_size == sum(args.hetero_pipeline_stage_splits), \
+            f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should be equal to the sum of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits}"
+        # assert len(args.hetero_pipeline_stage_splits) == len(args.hetero_device_types), \
+        #     f"length of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits} should be equal to the length of hetero_device_types {args.hetero_device_types}"
+
+    if args.recompute_granularity_per_stage is not None:
+        assert args.recompute_granularity == 'full', \
+            'recompute-granularity-per-stage is only ' \
+            'applicable to full recompute granularity mode'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you ' \
+            'need to use a recompute method'
+
+        pipeline_size_split = args.recompute_granularity_per_stage[::2]
+        recompute_granularity_split = args.recompute_granularity_per_stage[1::2]
+
+        for i in recompute_granularity_split:
+            assert i == 0 or i == 1, 'elements of recompute-granularity-per-stage must be 0 or 1.'
+        assert sum(pipeline_size_split) == args.pipeline_model_parallel_size, \
+            'recompute-granularity-per-stage setting: ' \
+            'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
+        args.recompute_granularity_per_stage = [flag for n, flag in
+            zip(pipeline_size_split, recompute_granularity_split) for _ in range(n)]
+
+    if args.recompute_num_layers_per_stage is not None:
+        assert args.recompute_granularity == 'full', \
+            'recompute-num-layers-per-stage is only ' \
+            'applicable to full recompute granularity'
+        assert args.recompute_method_per_stage is not None, \
+            'recompute_method_per_stage must be used together with ' \
+            'recompute_num_layers_per_stage'
+
+        recompute_num_layers_stage_split = args.recompute_num_layers_per_stage[::2]
+        recompute_num_layers_layer_split = args.recompute_num_layers_per_stage[1::2]
+        recompute_methods_stage_split = args.recompute_method_per_stage[::2]
+        recompute_methods_method_split = args.recompute_method_per_stage[1::2]
+
+        assert len(recompute_num_layers_stage_split) == len(recompute_num_layers_layer_split), \
+            'args.recompute_num_layers_per_stage setting must match the form: n0, layers0, n1, layers1, ...'
+        assert len(recompute_methods_stage_split) == len(recompute_methods_method_split), \
+            'args.recompute_method_per_stage setting must match the form: n0, method0, n1, method1, ...'
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
+                'args.recompute_num_layers_per_stage setting: ' \
+                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual-pipeline-model-parallel-size.'
+            assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
+                'args.recompute_method_per_stage setting: ' \
+                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual-pipeline-model-parallel-size.'
+        else:
+            assert args.pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
+                'args.recompute_num_layers_per_stage setting: ' \
+                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
+            assert args.pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
+                'args.recompute_method_per_stage setting: ' \
+                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
+
+        recompute_num_layers_per_stage = []
+        for n, layers in zip(recompute_num_layers_stage_split, recompute_num_layers_layer_split):
+            recompute_num_layers_per_stage.extend([layers] * n)
+        recompute_method_per_stage = []
+        for n, method in zip(recompute_methods_stage_split, recompute_methods_method_split):
+            recompute_method_per_stage.extend([method] * n)
+
+        args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
+        args.recompute_method_per_stage = recompute_method_per_stage
+
     # Batch size.
     assert args.micro_batch_size is not None
     assert args.micro_batch_size > 0
@@ -1124,6 +1220,25 @@ def _add_training_args(parser):
                        'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
                        'to recompute within each pipeline stage.')
+    group.add_argument('--hetero-pipeline-stages', nargs='*', type=int, default=None,
+                       help='Incompatible with --num-layers-per-virtual-pipeline-stage. '
+                       'hetero-pipeline-stages must be in the form: '
+                       'n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ... '
+                       'The order should be consistent with --hetero-device-types.')
+    group.add_argument('--recompute-granularity-per-stage', nargs='*', type=int, default=None,
+                       help='Used with recompute-granularity=full to set the recompute granularity '
+                       'of each stage. This argument must be in the form: n0, flag0, n1, flag1, ... '
+                       'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
+                       'Granularity flag: 0 turns full recompute off for the stage, 1 turns it on.')
+    group.add_argument('--recompute-method-per-stage', nargs='*', type=int, default=None,
+                       help='Used with recompute-granularity=full to set the recompute method '
+                       'of each stage. This argument must be in the form: n0, method0, n1, method1, ... '
+                       'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
+                       'Method: 0 means uniform, 1 means block.')
+    group.add_argument('--recompute-num-layers-per-stage', nargs='*', type=int, default=None,
+                       help='Used with recompute-granularity=full to set the number of recomputed layers '
+                       'of each stage. This argument must be in the form: n0, layers0, n1, layers1, ... '
+                       'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.')
     group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
                        help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
                        dest='clone_scatter_output_in_embedding')
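Reviewer note: below is a minimal, self-contained sketch of how validate_args interprets --hetero-pipeline-stages and how _get_layer_info then derives each pipeline rank's layer count and offset. The helper names parse_hetero_pipeline_stages and layers_and_offset are illustrative only and are not part of this patch.

from typing import List, Tuple

def parse_hetero_pipeline_stages(flat: List[int]) -> Tuple[List[int], List[List[int]]]:
    """Split the flat form 'n0 layers_0_0 ... n1 layers_1_0 ...' into
    per-device-type stage counts and groups of per-stage layer counts."""
    splits: List[int] = []
    groups: List[List[int]] = []
    group: List[int] = []
    counter = 0
    for item in flat:
        if counter == 0:
            splits.append(item)   # next `item` values are per-stage layer counts
            counter = item
        else:
            group.append(item)
            counter -= 1
            if counter == 0:
                groups.append(group)
                group = []
    return splits, groups

def layers_and_offset(groups: List[List[int]], pipeline_rank: int) -> Tuple[int, int]:
    """Number of layers owned by `pipeline_rank` and its global layer offset."""
    stages = [n for g in groups for n in g]  # flatten across device types
    return stages[pipeline_rank], sum(stages[:pipeline_rank])

# Example: two device types with 2 stages each; a 14-layer model split 4/4/3/3.
splits, groups = parse_hetero_pipeline_stages([2, 4, 4, 2, 3, 3])
assert splits == [2, 2] and groups == [[4, 4], [3, 3]]
assert layers_and_offset(groups, 2) == (3, 8)  # rank 2 builds layers 9-11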
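Similarly, a sketch of how the (count, value) pair lists accepted by --recompute-granularity-per-stage, --recompute-method-per-stage, and --recompute-num-layers-per-stage are expanded into one value per stage in validate_args; expand_per_stage is a hypothetical name for the inline loops above.

from typing import List

def expand_per_stage(pairs: List[int]) -> List[int]:
    """Expand [n0, v0, n1, v1, ...] into one value per pipeline stage."""
    counts, values = pairs[::2], pairs[1::2]
    assert len(counts) == len(values), 'argument must come as (count, value) pairs'
    out: List[int] = []
    for n, v in zip(counts, values):
        out.extend([v] * n)
    return out

# '--recompute-method-per-stage 1 0 3 1' with pipeline-model-parallel-size 4:
# one stage uses method 0 (uniform), the remaining three use method 1 (block).
assert expand_per_stage([1, 0, 3, 1]) == [0, 1, 1, 1]
# '--recompute-num-layers-per-stage 2 2 2 0': the first two stages recompute
# 2 layers each, the last two recompute none.
assert expand_per_stage([2, 2, 2, 0]) == [2, 2, 0, 0]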
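Putting the pieces together, a hypothetical launch for a 14-layer model on a 4-stage heterogeneous pipeline (two device types, two stages each) could pass flags like the following; all values are illustrative and chosen only to satisfy the assertions in validate_args:

    --num-layers 14
    --pipeline-model-parallel-size 4
    --hetero-pipeline-stages 2 4 4 2 3 3
    --recompute-granularity full
    --recompute-method uniform
    --recompute-granularity-per-stage 2 1 2 0
    --recompute-method-per-stage 2 0 2 1
    --recompute-num-layers-per-stage 2 2 2 1

Here the first two stages (4 layers each) run full recompute with the uniform method over 2 layers, while the last two stages (3 layers each) have full recompute turned off by the granularity flag.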