heterogeneous-distributed-t…/tools/checkpoint/hybrid_conversion.py
(tianyutong, d6ce507681: Initial Commit of Megatron-LM-0.8.0)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion.
# This functionality should be integrated with the megatron core checkpoint loader/saver.

import copy
import os
import re
import shutil
from collections import OrderedDict

import torch
import argparse
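
# Tensor-parallel split dimension for each parameter (matched by substring of the
# parameter name); -1 marks tensors that are replicated across TP ranks rather than split.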
tp_split_dim = {
    'word_embeddings.weight': 0,
    'norm.weight': -1,
    'final_norm.weight': -1,
    'output_layer.weight': 0,
    # mamba1/2
    'A_log': 0,
    'D': 0,
    'dt_bias': 0,
    'in_proj.weight': 0,
    'conv1d.weight': 0,
    'conv1d.bias': 0,
    'x_proj.weight': 1,
    'dt_proj.weight': 0,
    'dt_proj.bias': 0,
    'out_proj.weight': 1,
    'mixer.norm.weight': 0,
    # mlp
    'linear_fc1.layer_norm_weight': -1,
    'linear_fc1.weight': 0,
    'linear_fc2.weight': 1,
    # attention
    'self_attention.linear_proj.weight': 1,
    'self_attention.linear_qkv.layer_norm_weight': -1,
    'self_attention.linear_qkv.weight': 0,
}


def get_split_dim(tensor_name):
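    """Return the tensor-parallel split dimension for a parameter name (-1 means the tensor is replicated)."""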
    # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish
    if 'norm.weight' in tensor_name:
        if 'mixer.norm.weight' in tensor_name:
            return tp_split_dim['mixer.norm.weight']
        else:
            return tp_split_dim['norm.weight']

    for key in tp_split_dim.keys():
        if key in tensor_name:
            return tp_split_dim[key]
    raise Exception("Unknown tensor name {}".format(tensor_name))


def combine_tp_tensors(params, key, dim, tensors):
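    """Concatenate tensor-parallel shards of one parameter back into a single full tensor.

    Mamba in_proj and conv1d weights pack several logical sub-tensors (x, z, B, C, dt for
    in_proj; x, B, C for conv1d) along the split dimension, so each shard is first split
    into its components and the components are concatenated group-by-group before being
    re-packed into the combined layout.
    """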
    tp_size = len(tensors)

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        xs = []; zs = []
        for tensor in tensors:
            x, z = torch.split(tensor, [params.mamba_d_inner//tp_size,
                                        params.mamba_d_inner//tp_size], dim=dim)
            xs.append(x); zs.append(z)
        return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim)
    elif 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        xs = []; zs = []; Bs = []; Cs = []; dts = []
        for tensor in tensors:
            x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size,
                                                  params.mamba_d_inner // tp_size,
                                                  (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                                  (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                                  params.mamba2_n_heads // tp_size], dim=dim)
            xs.append(x); zs.append(z); Bs.append(B); Cs.append(C); dts.append(dt)

        for ii in range(len(Bs)):
            Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1]))
            Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1]))

        B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim)
        x = torch.cat(xs, dim=dim); z = torch.cat(zs, dim=dim); dt = torch.cat(dts, dim=dim)
        return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim)
    elif 'mixer.conv1d' in key and params.mamba_version == 2:
        xs = []; Bs = []; Cs = []
        for tensor in tensors:
            x, B, C = torch.split(tensor, [params.mamba_d_inner//tp_size,
                                           (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                           (params.mamba2_n_groups // tp_size) * params.mamba_d_state], dim=dim)
            xs.append(x); Bs.append(B); Cs.append(C)

        for ii in range(len(Bs)):
            if 'weight' in key:
                Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1]))
                Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1]))
            elif 'bias' in key:
                Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state))
                Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state))
            else:
                raise Exception("Unknown key")

        B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim)
        x = torch.cat(xs, dim=dim)
        return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim)

    else:
        return torch.cat(tensors, dim=dim)


def split_tensor_for_tp(params, key, dim, tensor):
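    """Split a full (unsharded) parameter into params.target_tp_size tensor-parallel shards.

    Inverse of combine_tp_tensors: packed Mamba parameters are unpacked into their
    components, each component is chunked across the target ranks, and the per-rank
    pieces are re-packed so every shard has the layout the model expects.
    """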
    tp_size = params.target_tp_size
    tensor_sliced = []

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)
        z_sliced = torch.chunk(z, tp_size, dim=dim)
        for (x, z) in zip(x_sliced, z_sliced):
            tensor_sliced.append(torch.cat((x, z), dim=dim))
    elif 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner,
                                              params.mamba2_n_groups * params.mamba_d_state,
                                              params.mamba2_n_groups * params.mamba_d_state,
                                              params.mamba2_n_heads], dim=dim)
        B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1]))
        C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1]))

        B_sliced = torch.chunk(B, tp_size, dim=dim)
        C_sliced = torch.chunk(C, tp_size, dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)
        z_sliced = torch.chunk(z, tp_size, dim=dim)
        dt_sliced = torch.chunk(dt, tp_size, dim=dim)

        tensor_sliced = []
        for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced):
            tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim))
    elif 'mixer.conv1d' in key and params.mamba_version == 2:
        x, B, C = torch.split(tensor, [params.mamba_d_inner,
                                       params.mamba2_n_groups * params.mamba_d_state,
                                       params.mamba2_n_groups * params.mamba_d_state], dim=dim)
        if 'weight' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1]))
            C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1]))
        elif 'bias' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state))
            C = torch.reshape(C, (-1, params.mamba_d_state))
        else:
            raise Exception("Unknown key")

        B_sliced = torch.chunk(B, tp_size, dim=dim)
        C_sliced = torch.chunk(C, tp_size, dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)

        tensor_sliced = []
        for (x, B, C) in zip(x_sliced, B_sliced, C_sliced):
            tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim))
    else:
        tensor_sliced = torch.chunk(tensor, tp_size, dim=dim)

    return tensor_sliced


def finalize_checkpoint(sample_model, model, params, verbose=False):
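    """Copy the non-'model' checkpoint fields (args, checkpoint_version, iteration,
    opt_param_scheduler, rng_state) from the sample checkpoint into the new one,
    updating the TP/PP sizes and optionally resetting iteration counters.
    Optimizer state is not carried over.
    """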
    # make sure the rest of the checkpoint is how we want it from the original (i.e., other than the 'model')
    reset_iterations = params.reset_iterations

    # checkpoint 'args'
    model['args'] = copy.deepcopy(sample_model['args'])
    model['args'].tensor_model_parallel_size = params.target_tp_size
    model['args'].pipeline_model_parallel_size = params.target_pp_size
    if reset_iterations:
        model['args'].iteration = 0
        model['args'].consumed_valid_samples = 0
        model['args'].consumed_train_samples = 0
        model['args'].train_iters = 0
        model['args'].train_samples = 0

    # checkpoint 'checkpoint_version'
    model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version'])

    # checkpoint 'iteration'
    model['iteration'] = copy.deepcopy(sample_model['iteration'])
    if reset_iterations:
        model['iteration'] = 0

    # checkpoint 'optimizer'
    # ignore

    # checkpoint 'opt_param_scheduler'
    if 'opt_param_scheduler' in sample_model.keys():
        model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler'])

    # checkpoint 'rng_state'
    model['rng_state'] = copy.deepcopy(sample_model['rng_state'])

    # report on argument difference
    if verbose:
        original_args = sample_model['args'].__dict__
        final_args = model['args'].__dict__
        for key in original_args:
            if key in final_args:
                if final_args[key] != original_args[key]:
                    print("KEY MISMATCH: {}".format(key))
                    print("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key]))
            else:
                print("KEY MISSING from final: {}, value {}".format(key, original_args[key]))
        print("")
        for key in final_args:
            if key not in original_args:
                print("KEY ADDED to final: {}, value {}".format(key, final_args[key]))

    return model


def main(args):
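    """Convert a hybrid Mamba-Transformer checkpoint to the target TP/PP layout.

    Reads the latest iteration from the load directory, loads every (tp, pp) rank file,
    merges the shards into one full model, then re-splits that model into the target
    number of tensor- and pipeline-parallel ranks and writes each rank's checkpoint.
    """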
print("\n====RUNNING CHECKPOINT CONVERSION====\n")
args.mamba_d_inner = args.d_model * 2
args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim
# get the latest iteration
tracker_filename = os.path.join(args.load_dir, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            raise Exception("Could not parse iteration number from {}".format(tracker_filename))
    out_iteration = iteration if not args.reset_iterations else 0

    # get model directory and model parallel ranks
    input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration))
    input_sub_models = os.listdir(input_model_dir)
    # input_sub_models = sorted(input_sub_models, key=lambda x: int(re.search(r'\d+', x).group()))

    # load one of the model parallel ranks to get arguments
    sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt")
    sample_model = torch.load(sample_model_file)
    print(f"Sample model {sample_model_file} is loaded.\n")

    # input tensor and pipeline parallel size
    input_tp_rank = sample_model['args'].tensor_model_parallel_size
    input_pp_rank = sample_model['args'].pipeline_model_parallel_size
    num_layers_per_pipeline_rank = sample_model['args'].num_layers // input_pp_rank

    # construct full model
    full_model = OrderedDict()
    for pp in range(input_pp_rank):
        print("[INFO] Processing input pipeline rank {}".format(pp))
        tp_models = []
        for tp in range(input_tp_rank):
            dir_name = "mp_rank_{:02d}".format(tp)
            if input_pp_rank > 1:
                dir_name += "_{:03d}".format(pp)
            model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt")

            tp_models.append(torch.load(model_file))
            print(f"Model {model_file} is loaded.")

        if input_tp_rank > 1:
            combined_tp_model = OrderedDict()
            for ii, (key, original_tensor) in enumerate(tp_models[0]['model'].items()):
                if "_extra_state" in key:
                    combined_tp_model[key] = original_tensor
                    continue

                split_dim = get_split_dim(key)
                original_shape = list(original_tensor.shape)
                combined_shape = copy.deepcopy(original_shape)
                combined_shape[split_dim] *= input_tp_rank
                # print("{}, {}, {}".format(ii, key, split_dim))

                if split_dim != -1:
                    # slice together model
                    # print("\tshape mismatch: original {}, combined {}".format(original_shape, combined_shape))
                    combined_tensor = combine_tp_tensors(args, key, split_dim,
                                                         [tp_models[jj]['model'][key].cpu() for jj in range(input_tp_rank)])
                    combined_tp_model[key] = combined_tensor
                else:
                    # copy model
                    combined_tp_model[key] = original_tensor
        else:
            combined_tp_model = tp_models[0]['model']
        # print("Combined tp model: {}".format(combined_tp_model.keys()))

        for ii, (key, original_tensor) in enumerate(combined_tp_model.items()):
            try:
                layer_num = int(re.findall(r'\d+', key)[0])
                new_key = key.replace(str(layer_num), str(layer_num + pp*num_layers_per_pipeline_rank), 1)
            except (IndexError, ValueError):
                new_key = key
            full_model[new_key] = original_tensor
    # print("Combined model: {}".format(full_model.keys()))
    print("\n[INFO] Loaded combined model\n")

    # sort by layer
    # full_model_sorted = dict(sorted(people.items(), key=lambda item: item[1]))

    # create new split model
    pp_offset = 0
    num_layers_per_pipeline_rank = sample_model['args'].num_layers // args.target_pp_size
    for pp in range(args.target_pp_size):
        print("[INFO] Processing output pipeline rank {}".format(pp))
        tp_models = []
        for ii in range(args.target_tp_size):
            tp_models.append({'model': OrderedDict()})

        for ii, (key, original_tensor) in enumerate(full_model.items()):
            try:
                layer_num = int(re.findall(r'\d+', key)[0])
                if layer_num >= num_layers_per_pipeline_rank * (pp+1):
                    break
                new_key = key.replace(str(layer_num), str(layer_num - (pp * num_layers_per_pipeline_rank)), 1)
            except (IndexError, ValueError):
                new_key = key
            if ii < pp_offset:
                continue
            else:
                pp_offset += 1

            if "_extra_state" in new_key:
                # copy
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = original_tensor
                continue

            split_dim = get_split_dim(new_key)
            original_shape = list(original_tensor.shape)
            v0 = original_shape[split_dim]
            split_size = v0 // args.target_tp_size
            split_shape = copy.deepcopy(original_shape)
            split_shape[split_dim] = split_size
            # print("{}, {}, {}".format(ii, new_key, split_dim))

            if split_dim != -1:
                # split model
                # print("\tshape mismatch: original {}, combined {}".format(original_shape, split_shape))
                tensor_sliced = split_tensor_for_tp(args, new_key, split_dim, original_tensor)
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = tensor_sliced[jj]
            else:
                # copy model
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = original_tensor
        # print(tp_models[0]['model'].keys())

        for tp in range(args.target_tp_size):
            dir_name = "mp_rank_{:02d}".format(tp)
            if args.target_pp_size > 1:
                dir_name += "_{:03d}".format(pp)

            model = finalize_checkpoint(sample_model, tp_models[tp], args, verbose=False)

            save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name)
            os.makedirs(save_dir, exist_ok=True)
            model_file = os.path.join(save_dir, "model_optim_rng.pt")
            torch.save(model, model_file)
            print(f"Model {model_file} is saved.")

    # shutil.copyfile(tracker_filename, os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt'))
    tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'w') as f:
        f.write(str(out_iteration))


if __name__ == "__main__":
    # example run command:
    # python hybrid_conversion.py
    # --load-dir mamba2-840m-test/checkpoints/
    # --save-dir mamba2-840m-test-conversion/checkpoints/
    # --target-pp-size 1
    # --target-tp-size 1

    parser = argparse.ArgumentParser()
    parser.add_argument('--load-dir', type=str)
    parser.add_argument('--save-dir', type=str)
    parser.add_argument('--target-tp-size', type=int, default=1)
    parser.add_argument('--target-pp-size', type=int, default=1)
    parser.add_argument('--reset-iterations', action='store_true')

    parser.add_argument('--d-model', type=int, default=4096)
    parser.add_argument('--mamba-version', type=int, default=2)
    parser.add_argument('--mamba-d-state', type=int, default=128)
    parser.add_argument('--mamba2-n-groups', type=int, default=8)
    parser.add_argument('--mamba2-head-dim', type=int, default=64)

    args = parser.parse_args()

    main(args)