Implement pipeline parallelism (PP) related logic
Change-Id: If203d8b02ebe3b4669a4333398661f476cc3abd4
@@ -1,7 +1,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
from typing import Callable, ContextManager, Optional, List

import torch
@@ -36,6 +36,33 @@ class ModelParallelConfig:
    (https://arxiv.org/abs/2205.05198) for more details.
    """

    hetero_pipeline_stages: Optional[List[List[int]]] = None
    """Incompatible with --num-layers-per-virtual-pipeline-stage.
    hetero-pipeline-stages must be in the form:
    n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ...
    The order should be consistent with --hetero-device-types.
    """

    recompute_granularity_per_stage: Optional[List[int]] = None
    """Used with recompute-granularity=full, setting the recompute granularity
    of each stage. This argument must be in the form: n0, flag0, n1, flag1, ...
    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
    Granularity flag: 0 means turning off full recompute, 1 means turning it on.
    """

    recompute_method_per_stage: Optional[List[int]] = None
    """Used with recompute-granularity=full, setting the recompute method
    of each stage. This argument must be in the form: n0, method0, n1, method1, ...
    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
    Method: 0 means uniform, 1 means block.
    """

    recompute_num_layers_per_stage: Optional[List[int]] = None
    """Used with recompute-granularity=full, setting the number of layers to
    recompute in each stage. This argument must be in the form: n0, layers0, n1, layers1, ...
    The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

@@ -52,6 +52,8 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
    )

    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
        assert config.hetero_pipeline_stages is None, \
            "Heterogeneous pipeline parallelism is not supported for virtual pipeline model parallel."
        # Interleaved pipeline parallelism:
        # Number of layers in each model chunk is the number of layers in the stage,
        # divided by the number of model chunks in a stage.
@@ -73,8 +75,13 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
    else:
        # Non-interleaved pipeline parallelism:
        # Each stage gets a contiguous set of layers.

        num_layers_to_build = num_layers_per_pipeline_rank
        if config.hetero_pipeline_stages is None:
            num_layers_to_build = num_layers_per_pipeline_rank
        else:
            pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
            pipeline_stages = [item for sublist in config.hetero_pipeline_stages for item in sublist]
            num_layers_to_build = pipeline_stages[pipeline_rank]
            torch.distributed.barrier()

    return num_layers_to_build
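
A standalone sketch (assumed stage sizes, not from the commit) of the hetero branch above: flatten the nested per-device-type lists, then index by pipeline rank:

```python
# Hedged sketch, illustrative values only.
hetero_pipeline_stages = [[4, 4], [2, 2]]   # assumed config value
pipeline_stages = [item for sublist in hetero_pipeline_stages for item in sublist]
assert pipeline_stages == [4, 4, 2, 2]
# Pipeline rank 0 builds 4 layers; rank 3 builds 2:
assert pipeline_stages[0] == 4 and pipeline_stages[3] == 2
```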
@@ -139,6 +146,39 @@ class TransformerBlock(MegatronModule):
        # required for pipeline parallel schedules
        self.input_tensor = None

        if self.config.recompute_method_per_stage is not None:
            if self.config.virtual_pipeline_model_parallel_size is not None:
                if self.config.recompute_method_per_stage[
                        parallel_state.get_virtual_pipeline_model_parallel_rank() *
                        self.config.pipeline_model_parallel_size +
                        parallel_state.get_pipeline_model_parallel_rank()] == 0:
                    self.config.recompute_method = 'uniform'
                elif self.config.recompute_method_per_stage[
                        parallel_state.get_virtual_pipeline_model_parallel_rank() *
                        self.config.pipeline_model_parallel_size +
                        parallel_state.get_pipeline_model_parallel_rank()] == 1:
                    self.config.recompute_method = 'block'
            else:
                if self.config.recompute_method_per_stage[parallel_state.get_pipeline_model_parallel_rank()] == 0:
                    self.config.recompute_method = 'uniform'
                elif self.config.recompute_method_per_stage[parallel_state.get_pipeline_model_parallel_rank()] == 1:
                    self.config.recompute_method = 'block'

        if self.config.recompute_num_layers_per_stage is not None:
            if self.config.virtual_pipeline_model_parallel_size is not None:
                self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
                    parallel_state.get_virtual_pipeline_model_parallel_rank() *
                    self.config.pipeline_model_parallel_size +
                    parallel_state.get_pipeline_model_parallel_rank()]
            else:
                self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
                    parallel_state.get_pipeline_model_parallel_rank()]

        if self.config.recompute_granularity_per_stage is not None and self.config.recompute_granularity_per_stage[
                parallel_state.get_pipeline_model_parallel_rank()] == 0:
            self.config.recompute_granularity = None
            self.config.recompute_method = None

        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        if get_cpu_offload_context is not None:
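
The repeated subscript above is a flat index into the per-stage list when interleaved (virtual) pipelining is active; a small hedged sketch with made-up sizes of how that layout works:

```python
# Hedged sketch: entries are laid out chunk-major, i.e. all pipeline stages of
# virtual chunk 0, then all stages of virtual chunk 1, and so on.
def per_stage_index(vp_rank, pp_size, pp_rank):
    return vp_rank * pp_size + pp_rank

# Pipeline size 4 with 2 virtual chunks -> an 8-entry per-stage list;
# virtual chunk 1, pipeline rank 2 reads entry 6:
assert per_stage_index(1, 4, 2) == 6
```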
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
    # activation recomputation
    ####################
    recompute_granularity: str = None
    recompute_granularity: str = None
    """Determines which type of activation recompute to use. Megatron-core supports 'selective'
    activation checkpointing where only the memory intensive part of attention is checkpointed.
    These memory intensive activations are also less compute intensive which makes activation
@@ -139,6 +139,8 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
        )

        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            assert self.config.hetero_pipeline_stages is None, \
                "Heterogeneous pipeline parallelism is not supported for virtual pipeline model parallel."
            vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
            vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
@@ -150,7 +152,11 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
        else:
            # Each stage gets a contiguous set of layers.
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                offset = pipeline_rank * num_layers_per_pipeline_rank
                if self.config.hetero_pipeline_stages is None:
                    offset = pipeline_rank * num_layers_per_pipeline_rank
                else:
                    pipeline_stages = [item for sublist in self.config.hetero_pipeline_stages for item in sublist]
                    offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
            else:
                offset = 0
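
The hetero offset is a prefix sum over the flattened stage sizes; a hedged standalone sketch with assumed values:

```python
# Hedged sketch, illustrative values only.
pipeline_stages = [4, 4, 2, 2]    # flattened hetero stage sizes
pipeline_rank = 2
# Layers before this rank: 4 + 4 = 8, so its first layer is global layer 8.
offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
assert offset == 8
```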
@@ -1360,6 +1360,19 @@ def _get_num_layers(args, model_type, is_decoder=False):
        num_layers = args.decoder_num_layers
    return num_layers

def _get_layer_info(args):
    assert args.hetero_pipeline_stages is not None, \
        "_get_layer_info requires --hetero-pipeline-stages to be set."
    pipeline_rank = mpu.get_pipeline_model_parallel_rank()
    pipeline_stages = [item for sublist in args.hetero_pipeline_stages for item in sublist]
    offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
    num_layers = pipeline_stages[pipeline_rank]
    # Print each rank's layer assignment in rank order.
    torch.distributed.barrier()
    for i in range(torch.distributed.get_world_size()):
        if i == torch.distributed.get_rank():
            print("pipeline_rank:", pipeline_rank, "offset:", offset, "num_layers:", num_layers, flush=True)
        torch.distributed.barrier()
    return offset, num_layers


def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
@@ -1403,12 +1416,43 @@ class ParallelTransformer(MegatronModule):
        self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpointing flag.
        self.recompute_granularity = config.recompute_granularity
        self.recompute_method = config.recompute_method
        self.recompute_num_layers = config.recompute_num_layers
        # self.recompute_granularity = config.recompute_granularity
        # self.recompute_method = config.recompute_method
        # self.recompute_num_layers = config.recompute_num_layers
        if args.recompute_method_per_stage is not None:
            if args.virtual_pipeline_model_parallel_size is not None:
                if args.recompute_method_per_stage[
                        mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size
                        + mpu.get_pipeline_model_parallel_rank()] == 0:
                    self.recompute_method = 'uniform'
                elif args.recompute_method_per_stage[
                        mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size
                        + mpu.get_pipeline_model_parallel_rank()] == 1:
                    self.recompute_method = 'block'
            else:
                if args.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 0:
                    self.recompute_method = 'uniform'
                elif args.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 1:
                    self.recompute_method = 'block'
        else:
            self.recompute_method = config.recompute_method

        if args.recompute_num_layers_per_stage is not None:
            if args.virtual_pipeline_model_parallel_size is not None:
                self.recompute_num_layers = args.recompute_num_layers_per_stage[
                    mpu.get_virtual_pipeline_model_parallel_rank() * args.pipeline_model_parallel_size
                    + mpu.get_pipeline_model_parallel_rank()]
            else:
                self.recompute_num_layers = args.recompute_num_layers_per_stage[mpu.get_pipeline_model_parallel_rank()]
        else:
            self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        if args.recompute_granularity_per_stage is not None and args.recompute_granularity_per_stage[
                mpu.get_pipeline_model_parallel_rank()] == 0:
            self.recompute_granularity = None
            self.recompute_method = None
        else:
            self.recompute_granularity = config.recompute_granularity

        self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
@@ -1460,8 +1504,11 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(args, model_type,
                                          layer_type == LayerType.decoder)
        if args.hetero_pipeline_stages is None:
            self.num_layers = _get_num_layers(args, model_type,
                                              layer_type == LayerType.decoder)
        else:
            offset, self.num_layers = _get_layer_info(args)

        self.drop_path_rates = [
            rate.item() for rate in
@@ -1541,6 +1588,8 @@ class ParallelTransformer(MegatronModule):
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            assert args.hetero_pipeline_stages is None, \
                "Heterogeneous pipeline parallelism is not supported for virtual pipeline model parallel."
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
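
A one-line hedged check (assumed sizes) of the chunk division above:

```python
# Hedged sketch: a stage holding 4 layers with 2 virtual chunks
# puts 4 // 2 = 2 layers in each model chunk.
num_layers, virtual_pipeline_model_parallel_size = 4, 2
assert num_layers // virtual_pipeline_model_parallel_size == 2
```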
@@ -1559,6 +1608,8 @@ class ParallelTransformer(MegatronModule):
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.hetero_pipeline_stages is None, \
                    "Heterogeneous pipeline parallelism is not supported for encoder-decoder models."
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
@@ -1566,7 +1617,10 @@ class ParallelTransformer(MegatronModule):
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
                if args.hetero_pipeline_stages is None:
                    offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
                else:
                    offset, self.num_layers = _get_layer_info(args)

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
@@ -1579,6 +1633,7 @@ class ParallelTransformer(MegatronModule):
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
            self.recompute_granularity = None
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])
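
A hedged sketch (assumed offset and layer count from a hetero stage split) of the global, 1-based layer numbering produced by the list comprehension above:

```python
# Hedged sketch: global 1-based layer numbers built on one rank,
# given assumed offset/num_layers from the hetero stage split.
offset, num_layers = 8, 2
layer_numbers = [i + 1 + offset for i in range(num_layers)]
assert layer_numbers == [9, 10]
```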
@@ -318,6 +318,102 @@ def validate_args(args, defaults={}):
    args.num_micro_batches = None
    args.data_parallel_splits = None

    if args.hetero_pipeline_stages is not None:
        assert args.micro_batch_size_per_dp is None, \
            "micro_batch_size_per_dp should be None when hetero_pipeline_stages is used"
        args.hetero_data_parallel_splits = None

        stages = []
        hetero_pipeline_stages = []
        hetero_pipeline_stage_splits = []
        counter = 0
        num_layers = 0
        for item in args.hetero_pipeline_stages:
            if counter == 0:
                hetero_pipeline_stage_splits.append(item)
                counter = item
            else:
                stages.append(item)
                num_layers += item
                counter -= 1
                if counter == 0:
                    hetero_pipeline_stages.append(stages)
                    stages = []
        args.hetero_pipeline_stages = hetero_pipeline_stages
        args.hetero_pipeline_stage_splits = hetero_pipeline_stage_splits

        for split, stages in zip(args.hetero_pipeline_stage_splits, args.hetero_pipeline_stages):
            assert split == len(stages), \
                f"hetero_pipeline_stage_split {split} should be equal to the length of hetero_pipeline_stage {stages}"
        assert num_layers == args.num_layers, \
            f"sum of hetero_pipeline_stages {num_layers} should be equal to num_layers {args.num_layers}"
        assert args.pipeline_model_parallel_size == sum(args.hetero_pipeline_stage_splits), \
            f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should be equal to the sum of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits}"
        # assert len(args.hetero_pipeline_stage_splits) == len(args.hetero_device_types), \
        #     f"length of hetero_pipeline_stage_splits {args.hetero_pipeline_stage_splits} should be equal to the length of hetero_device_types {args.hetero_device_types}"

    if args.recompute_granularity_per_stage is not None:
        assert args.recompute_granularity == 'full', \
            'recompute-granularity-per-stage is only ' \
            'applicable to full recompute granularity mode'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you ' \
            'need to use a recompute method '

        pipeline_size_split = args.recompute_granularity_per_stage[::2]
        recompute_granularity_split = args.recompute_granularity_per_stage[1::2]

        for i in recompute_granularity_split:
            assert i == 1 or i == 0, 'element of recompute-granularity-per-stage must be 0 or 1.'
        assert sum(pipeline_size_split) == args.pipeline_model_parallel_size, \
            'recompute-granularity-per-stage setting: ' \
            'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
        args.recompute_granularity_per_stage = [recompute_granularity_split[i]
                                                for i, j in enumerate(pipeline_size_split)
                                                for _ in range(j)]

    if args.recompute_num_layers_per_stage is not None:
        assert args.recompute_granularity == 'full', \
            'recompute-num-layers-per-stage is only ' \
            'applicable to full recompute granularity'
        assert args.recompute_method_per_stage is not None, \
            'recompute_method_per_stage must be used with ' \
            'recompute_num_layers_per_stage '

        recompute_num_layers_stage_split = args.recompute_num_layers_per_stage[::2]
        recompute_num_layers_layer_split = args.recompute_num_layers_per_stage[1::2]
        recompute_methods_stage_split = args.recompute_method_per_stage[::2]
        recompute_methods_method_split = args.recompute_method_per_stage[1::2]

        assert len(recompute_num_layers_stage_split) == len(recompute_num_layers_layer_split), \
            'args.recompute_num_layers_per_stage setting must match form: n0, layers0, n1, layers1, ...'
        assert len(recompute_methods_stage_split) == len(recompute_methods_method_split), \
            'args.recompute_method_per_stage setting must match form: n0, method0, n1, method1, ...'
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
                'args.recompute_num_layers_per_stage setting: ' \
                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual_pipeline_model_parallel_size'
            assert args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
                'args.recompute_method_per_stage setting: ' \
                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size * virtual_pipeline_model_parallel_size'
        else:
            assert args.pipeline_model_parallel_size == sum(recompute_num_layers_stage_split), \
                'args.recompute_num_layers_per_stage setting: ' \
                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'
            assert args.pipeline_model_parallel_size == sum(recompute_methods_stage_split), \
                'args.recompute_method_per_stage setting: ' \
                'the sum of n0, n1, ... should be equal to pipeline-model-parallel-size.'

        recompute_num_layers_per_stage = []
        for i in range(len(recompute_num_layers_stage_split)):
            for j in range(recompute_num_layers_stage_split[i]):
                recompute_num_layers_per_stage.append(recompute_num_layers_layer_split[i])
        recompute_method_per_stage = []
        for i in range(len(recompute_methods_stage_split)):
            for j in range(recompute_methods_stage_split[i]):
                recompute_method_per_stage.append(recompute_methods_method_split[i])

        args.recompute_num_layers_per_stage = recompute_num_layers_per_stage
        args.recompute_method_per_stage = recompute_method_per_stage

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
@@ -1124,6 +1220,25 @@ def _add_training_args(parser):
                            'uniformly divided recompute unit, '
                            '2) block: the number of individual Transformer layers '
                            'to recompute within each pipeline stage.')
    group.add_argument('--hetero-pipeline-stages', nargs='*', type=int, default=None,
                       help='Incompatible with --num-layers-per-virtual-pipeline-stage. '
                            'hetero-pipeline-stages must be in the form: '
                            'n0 layers_0_0 layers_0_1 ... n1 layers_1_0 layers_1_1 ... '
                            'The order should be consistent with --hetero-device-types.')
    group.add_argument('--recompute-granularity-per-stage', nargs='*', type=int, default=None,
                       help='Used with recompute-granularity=full, setting recompute granularity '
                            'of each stage. This argument must be in the form: n0, flag0, n1, flag1, ... '
                            'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
                            'Granularity flag: 0 means turning off full recompute, 1 means turning it on.')
    group.add_argument('--recompute-method-per-stage', nargs='*', type=int, default=None,
                       help='Used with recompute-granularity=full, setting recompute method '
                            'of each stage. This argument must be in the form: n0, method0, n1, method1, ... '
                            'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size. '
                            'Method: 0 means uniform, 1 means block.')
    group.add_argument('--recompute-num-layers-per-stage', nargs='*', type=int, default=None,
                       help='Used with recompute-granularity=full, setting recompute num layers '
                            'of each stage. This argument must be in the form: n0, layers0, n1, layers1, ... '
                            'The sum of n0, n1, ... should be equal to pipeline-model-parallel-size.')
    group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
                       help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
                       dest='clone_scatter_output_in_embedding')
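
A hedged sketch (hypothetical flag values, standalone parser) of how argparse receives these nargs='*' integer lists before validate_args restructures them:

```python
# Hedged sketch: hypothetical values, minimal parser for illustration only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--hetero-pipeline-stages', nargs='*', type=int, default=None)
parser.add_argument('--recompute-granularity-per-stage', nargs='*', type=int, default=None)

args = parser.parse_args(
    '--hetero-pipeline-stages 2 4 4 2 2 2 '
    '--recompute-granularity-per-stage 2 1 2 0'.split())
assert args.hetero_pipeline_stages == [2, 4, 4, 2, 2, 2]
assert args.recompute_granularity_per_stage == [2, 1, 2, 0]
```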