Implement pipeline-parallelism (PP) related logic

Change-Id: If203d8b02ebe3b4669a4333398661f476cc3abd4
This commit is contained in:
tianyutong
2025-05-23 10:09:01 +08:00
parent 36ec2b5d10
commit 6f920623fc
6 changed files with 253 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
from typing import Callable, ContextManager, Optional, List
import torch
@@ -36,6 +36,33 @@ class ModelParallelConfig:
(https://arxiv.org/abs/2205.05198) for more details.
"""
hetero_pipeline_stages: Optional[List[List[int]]] = None
"""Incompatible with --num-layers-per-virtual-pipeline-stage.
hetero-pipeline-stages must be in the form:
n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ...
The order should be consistent with --hetero-device-types.
"""
recompute_granularity_per_stage: Optional[List[int]] = None
"""used with recompute-granularity=full, setting recompute granularity'
of each stage. This argument must be in the form: n0, flag0, n1, flag1,...'
the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
granularity flag: 0 means turning off full recompute, 1 means turning on
"""
recompute_method_per_stage: Optional[List[int]] = None
"""used with recompute-granularity=full, setting recompute method
of each stage. This argument must be in the form: n0, method0, n1, method1, ...
the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
method: 0 means uniform, 1 means block
"""
recompute_num_layers_per_stage: Optional[List[int]] = None
"""used with recompute-granularity=full, setting recompute num layers
of each stage. This argument must be in the form: n0, layers0, n1, layers1, ...
the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
"""
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""

View File

@@ -52,6 +52,8 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
assert config.hetero_pipeline_stages is None, \
"Heterogenous pipeline parallelism is not supported for virtual pipeline model parallel."
# Interleaved pipeline parallelism:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
@@ -73,8 +75,13 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
else:
# Non-interleaved pipeline parallelism:
# Each stage gets a contiguous set of layers.
num_layers_to_build = num_layers_per_pipeline_rank
if config.hetero_pipeline_stages is None:
num_layers_to_build = num_layers_per_pipeline_rank
else:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
pipeline_stages = [item for sublist in config.hetero_pipeline_stages for item in sublist]
num_layers_to_build = pipeline_stages[pipeline_rank]
torch.distributed.barrier()
return num_layers_to_build
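
A minimal standalone sketch of the layer-count lookup in the hetero branch above (the helper name _num_layers_for_rank is hypothetical, introduced only for illustration):

    def _num_layers_for_rank(hetero_pipeline_stages, pipeline_rank):
        # Flatten the per-device-type stage lists and take this rank's entry,
        # mirroring the non-interleaved branch of get_num_layers_to_build.
        pipeline_stages = [item for sublist in hetero_pipeline_stages for item in sublist]
        return pipeline_stages[pipeline_rank]

    # With stages [[8, 8], [6, 6]], pipeline ranks 0..3 build 8, 8, 6 and 6 layers.
    assert [_num_layers_for_rank([[8, 8], [6, 6]], r) for r in range(4)] == [8, 8, 6, 6]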
@@ -139,6 +146,39 @@ class TransformerBlock(MegatronModule):
# required for pipeline parallel schedules
self.input_tensor = None
if self.config.recompute_method_per_stage is not None:
if self.config.virtual_pipeline_model_parallel_size is not None:
if self.config.recompute_method_per_stage[
parallel_state.get_virtual_pipeline_model_parallel_rank() *
self.config.pipeline_model_parallel_size +
parallel_state.get_pipeline_model_parallel_rank()] == 0:
self.config.recompute_method = 'uniform'
elif self.config.recompute_method_per_stage[
parallel_state.get_virtual_pipeline_model_parallel_rank() *
self.config.pipeline_model_parallel_size +
parallel_state.get_pipeline_model_parallel_rank()] == 1:
self.config.recompute_method = 'block'
else:
if self.config.recompute_method_per_stage[parallel_state.get_pipeline_model_parallel_rank()] == 0:
self.config.recompute_method = 'uniform'
elif self.config.recompute_method_per_stage[parallel_state.get_pipeline_model_parallel_rank()] == 1:
self.config.recompute_method = 'block'
if self.config.recompute_num_layers_per_stage is not None:
if self.config.virtual_pipeline_model_parallel_size is not None:
self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
parallel_state.get_virtual_pipeline_model_parallel_rank() *
self.config.pipeline_model_parallel_size +
parallel_state.get_pipeline_model_parallel_rank()]
else:
self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
parallel_state.get_pipeline_model_parallel_rank()]
if self.config.recompute_granularity_per_stage is not None and self.config.recompute_granularity_per_stage[
parallel_state.get_pipeline_model_parallel_rank()] == 0:
self.config.recompute_granularity = None
self.config.recompute_method = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
if get_cpu_offload_context is not None:
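
When virtual pipeline parallelism is enabled, the per-stage recompute lists above are indexed with a flat offset. A hedged sketch of that index (the helper _per_stage_index is hypothetical):

    def _per_stage_index(vp_rank, pipeline_model_parallel_size, pp_rank):
        # Entries for virtual rank k occupy positions
        # [k * pp_size, (k + 1) * pp_size) of the per-stage lists.
        return vp_rank * pipeline_model_parallel_size + pp_rank

    # With pp_size=4 and vp_size=2, virtual rank 1 on pipeline rank 2 reads entry 6,
    # so the per-stage lists need pp_size * vp_size = 8 entries in total.
    assert _per_stage_index(1, 4, 2) == 6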

View File

@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
# activation recomputation
####################
recompute_granularity: str = None
recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation

View File

@@ -139,6 +139,8 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
assert self.config.hetero_pipeline_stages is None, \
"Heterogenous pipeline parallelism is not supported for virtual pipeline model parallel."
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
@@ -150,7 +152,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
if self.config.hetero_pipeline_stages is None:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
pipeline_stages = [item for sublist in self.config.hetero_pipeline_stages for item in sublist]
offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
else:
offset = 0

View File

@@ -1360,6 +1360,19 @@ def _get_num_layers(args, model_type, is_decoder=False):
num_layers = args.decoder_num_layers
return num_layers
def _get_layer_info(args):
assert args.hetero_pipeline_stages is not None, "_get_layer_info requires --hetero-pipeline-stages to be set."
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
pipeline_stages = [item for sublist in args.hetero_pipeline_stages for item in sublist]
offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
num_layers = pipeline_stages[pipeline_rank]
torch.distributed.barrier()
for i in range(torch.distributed.get_world_size()):
if i == torch.distributed.get_rank():
print("pipeline_rank:", pipeline_rank, "offset:", offset, "num_layers:", num_layers, flush=True)
torch.distributed.barrier()
return offset, num_layers
def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
layer_number):
@@ -1403,12 +1416,43 @@ class ParallelTransformer(MegatronModule):
self.retro_add_retriever = args.retro_add_retriever
# Store activation checkpoiting flag.
self.recompute_granularity = config.recompute_granularity
self.recompute_method = config.recompute_method
self.recompute_num_layers = config.recompute_num_layers
# self.recompute_granularity = config.recompute_granularity
# self.recompute_method = config.recompute_method
# self.recompute_num_layers = config.recompute_num_layers
if args.recompute_method_per_stage != None:
if args.virtual_pipeline_model_parallel_size != None:
if args.recompute_method_per_stage[
mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size + mpu.get_pipeline_model_parallel_rank()] == 0:
self.recompute_method = 'uniform'
elif args.recompute_method_per_stage[
mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size + mpu.get_pipeline_model_parallel_rank()] == 1:
self.recompute_method = 'block'
else:
if args.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 0:
self.recompute_method = 'uniform'
elif args.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 1:
self.recompute_method = 'block'
else:
self.recompute_method = config.recompute_method
if args.recompute_num_layers_per_stage != None:
if args.virtual_pipeline_model_parallel_size != None:
self.recompute_num_layers = args.recompute_num_layers_per_stage[
mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size + mpu.get_pipeline_model_parallel_rank()]
else:
self.recompute_num_layers = args.recompute_num_layers_per_stage[mpu.get_pipeline_model_parallel_rank()]
else:
self.recompute_num_layers = config.recompute_num_layers
self.distribute_saved_activations = \
config.distribute_saved_activations and not config.sequence_parallel
if args.recompute_granularity_per_stage != None and args.recompute_granularity_per_stage[
mpu.get_pipeline_model_parallel_rank()] == 0:
self.recompute_granularity = None
self.recompute_method = None
else:
self.recompute_granularity = config.recompute_granularity
self.sequence_parallel = config.sequence_parallel
# Transformer Engine Init.
@@ -1460,8 +1504,11 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
# Number of layers.
self.num_layers = _get_num_layers(args, model_type,
layer_type==LayerType.decoder)
if args.hetero_pipeline_stages is None:
self.num_layers = _get_num_layers(args, model_type,
layer_type == LayerType.decoder)
else:
offset, self.num_layers = _get_layer_info(args)
self.drop_path_rates = [
rate.item() for rate in
@@ -1541,6 +1588,8 @@ class ParallelTransformer(MegatronModule):
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
assert args.hetero_pipeline_stages is None, \
"Heterogenous pipeline parallelism is not supported for virtual pipeline model parallel."
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
@@ -1559,6 +1608,8 @@ class ParallelTransformer(MegatronModule):
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.hetero_pipeline_stages is None, \
"Heterogenous pipeline parallelism is not supported for encoder-decoder models."
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
@@ -1566,7 +1617,10 @@ class ParallelTransformer(MegatronModule):
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
if args.hetero_pipeline_stages is None:
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
else:
offset, self.num_layers = _get_layer_info(args)
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
@@ -1579,6 +1633,7 @@ class ParallelTransformer(MegatronModule):
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
self.recompute_granularity = None
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
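
Combining the offset with build_layer(i + 1 + offset) gives each stage a disjoint, 1-based range of global layer numbers. A hypothetical sketch of the resulting numbering per rank:

    def _global_layer_numbers(hetero_pipeline_stages, pipeline_rank):
        # 1-based layer numbers constructed on this rank, matching i + 1 + offset.
        stages = [item for sublist in hetero_pipeline_stages for item in sublist]
        offset = sum(stages[:pipeline_rank])
        return [i + 1 + offset for i in range(stages[pipeline_rank])]

    # With stages [[8, 8], [6, 6]], rank 2 builds layers 17..22.
    assert _global_layer_numbers([[8, 8], [6, 6]], 2) == [17, 18, 19, 20, 21, 22]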

View File

@@ -318,6 +318,102 @@ def validate_args(args, defaults={}):
args.num_micro_batches = None
args.data_parallel_splits = None
if args.hetero_pipeline_stages is not None:
assert args.micro_batch_size_per_dp is None, \
"micro_batch_size_per_dp should be None when use hetero_pipeline_stages"
args.hetero_data_parallel_splits = None
stages = []
hetero_pipeline_stages = []
hetero_pipeline_stage_splits = []
counter = 0
num_layers = 0
for item in args.hetero_pipeline_stages:
if counter == 0:
hetero_pipeline_stage_splits.append(item)
counter = item
else:
stages.append(item)
num_layers += item
counter -= 1
if counter == 0:
hetero_pipeline_stages.append(stages)
stages = []
args.hetero_pipeline_stages = hetero_pipeline_stages
args.hetero_pipeline_stage_splits = hetero_pipeline_stage_splits
for split, stages in zip(args.hetero_pipeline_stage_splits, args.hetero_pipeline_stages):
assert split == len(stages), \
f"hetero_pipeline_stage_split {split} should be equal to the length of hetero_pipeline_stage {stages}"
assert num_layers == args.num_layers, f"sum of hetero_pipeline_stages {num_layers} should be equal to num_layers {args.num_layers}"
assert args.pipeline_model_parallel_size == sum(args.hetero_pipeline_stage_splits), \
f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should be equal to the sum of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits}"
# assert len(args.hetero_pipeline_stage_splits) == len(args.hetero_device_types), \
# f"length of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits} should be equal to the length of hetero_device_types {args.hetero_device_types}"
if args.recompute_granularity_per_stage != None:
assert args.recompute_granularity == 'full', \
'recompute-granularity-per-stage is only '\
'applicable to the full recompute granularity mode'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
pipeline_size_split = args.recompute_granularity_per_stage[::2]
recompute_granularity_split = args.recompute_granularity_per_stage[1::2]
for i in recompute_granularity_split:
assert i == 1 or i == 0, 'element of recompute-granularity-per-stage must be 0 or 1.'
assert sum(pipeline_size_split) == args.pipeline_model_parallel_size, \
'recompute-granularity-per-stage setting: ' \
'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
args.recompute_granularity_per_stage = [recompute_granularity_split[i] for i,j in enumerate(pipeline_size_split) for _ in range(j)]
if args.recompute_num_layers_per_stage != None:
assert args.recompute_granularity == 'full', \
'recompute-num-layers-per-stage is only '\
'applicable to the full recompute granularity mode'
assert args.recompute_method_per_stage is not None, \
'recompute_method_per_stage must be used with '\
'recompute_num_layers_per_stage '
recompute_num_layers_stage_split = args.recompute_num_layers_per_stage[::2]
recompute_num_layers_layer_split = args.recompute_num_layers_per_stage[1::2]
recompute_methods_stage_split = args.recompute_method_per_stage[::2]
recompute_methods_method_split = args.recompute_method_per_stage[1::2]
assert len(recompute_num_layers_stage_split) == len(recompute_num_layers_layer_split), \
'args.recompute_num_layers_per_stage setting must match form: n0, layers0, n1, layers1, ...'
assert len(recompute_methods_stage_split) == len(recompute_methods_method_split), \
'args.recompute_method_per_stage setting must match form: n0, method0, n1, method1, ...'
if args.virtual_pipeline_model_parallel_size != None:
assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
'args.recompute_num_layers_per_stage setting: ' \
'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual_pipeline_model_parallel_size'
assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
'args.recompute_method_per_stage setting: ' \
'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual_pipeline_model_parallel_size'
else:
assert args.pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
'args.recompute_num_layers_per_stage setting: ' \
'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
assert args.pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
'args.recompute_method_per_stage setting: ' \
'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
recompute_num_layers_per_stage = []
for i in range(len(recompute_num_layers_stage_split)):
for j in range(recompute_num_layers_stage_split[i]):
recompute_num_layers_per_stage.append(recompute_num_layers_layer_split[i])
recompute_method_per_stage = []
for i in range(len(recompute_methods_stage_split)):
for j in range(recompute_methods_stage_split[i]):
recompute_method_per_stage.append(recompute_methods_method_split[i])
args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
args.recompute_method_per_stage = recompute_method_per_stage
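
Both per-stage recompute arguments use the same '(count, value)' encoding and are expanded to one value per stage; a compact sketch of that expansion (the helper _expand_per_stage is hypothetical, equivalent to the two loops above):

    def _expand_per_stage(pairs):
        # [n0, v0, n1, v1, ...] -> v0 repeated n0 times, v1 repeated n1 times, ...
        counts, values = pairs[::2], pairs[1::2]
        return [v for c, v in zip(counts, values) for _ in range(c)]

    # --recompute-num-layers-per-stage 2 4 2 0  ->  [4, 4, 0, 0]
    assert _expand_per_stage([2, 4, 2, 0]) == [4, 4, 0, 0]
    # --recompute-method-per-stage 1 1 3 0      ->  [1, 0, 0, 0]
    assert _expand_per_stage([1, 1, 3, 0]) == [1, 0, 0, 0]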
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
@@ -1124,6 +1220,25 @@ def _add_training_args(parser):
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to recompute within each pipeline stage.')
group.add_argument('--hetero-pipeline-stages', nargs='*', type=int, default=None,
help='Incompatible with --num-layers-per-virtual-pipeline-stage. '
'hetero-pipeline-stages must be in the form: '
'n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ... '
'The order should be consistent with --hetero-device-types.')
group.add_argument('--recompute-granularity-per-stage', nargs='*', type=int, default=None,
help='Used with recompute-granularity=full; sets the recompute granularity '
'of each stage. This argument must be in the form: n0, flag0, n1, flag1, ... '
'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
'Granularity flag: 0 means turning off full recompute, 1 means turning on.')
group.add_argument('--recompute-method-per-stage', nargs='*', type=int, default=None,
help='Used with recompute-granularity=full; sets the recompute method '
'of each stage. This argument must be in the form: n0, method0, n1, method1, ... '
'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
'Method: 0 means uniform, 1 means block.')
group.add_argument('--recompute-num-layers-per-stage', nargs='*', type=int, default=None,
help='Used with recompute-granularity=full; sets the number of layers to recompute '
'in each stage. This argument must be in the form: n0, layers0, n1, layers1, ... '
'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.')
group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
dest='clone_scatter_output_in_embedding')