# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion.
# This functionality should be integrated with the megatron core checkpoint loader/saver.

import argparse
import copy
import os
import re
import shutil
from collections import OrderedDict

import torch

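
# Tensor-parallel split dimension for each parameter type: 0 or 1 is the dimension
# along which the tensor is partitioned across tensor-parallel ranks; -1 means the
# tensor is replicated rather than split.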
tp_split_dim = {
    'word_embeddings.weight': 0,
    'norm.weight': -1,
    'final_norm.weight': -1,
    'output_layer.weight': 0,
    # mamba1/2
    'A_log': 0,
    'D': 0,
    'dt_bias': 0,
    'in_proj.weight': 0,
    'conv1d.weight': 0,
    'conv1d.bias': 0,
    'x_proj.weight': 1,
    'dt_proj.weight': 0,
    'dt_proj.bias': 0,
    'out_proj.weight': 1,
    'mixer.norm.weight': 0,
    # mlp
    'linear_fc1.layer_norm_weight': -1,
    'linear_fc1.weight': 0,
    'linear_fc2.weight': 1,
    # attention
    'self_attention.linear_proj.weight': 1,
    'self_attention.linear_qkv.layer_norm_weight': -1,
    'self_attention.linear_qkv.weight': 0,
}


def get_split_dim(tensor_name):
    # 'norm.weight' is a substring of both 'mixer.norm.weight' and the plain layer/final
    # norm weights, so those two cases must be distinguished explicitly.
    if 'norm.weight' in tensor_name:
        if 'mixer.norm.weight' in tensor_name:
            return tp_split_dim['mixer.norm.weight']
        else:
            return tp_split_dim['norm.weight']

    for key in tp_split_dim.keys():
        if key in tensor_name:
            return tp_split_dim[key]
    raise Exception("Unknown tensor name {}".format(tensor_name))
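
# For example (hypothetical tensor names from a Megatron hybrid state dict),
# get_split_dim() resolves:
#   'decoder.layers.0.mixer.in_proj.weight'  -> 0   (split along the output dimension)
#   'decoder.layers.0.mixer.out_proj.weight' -> 1   (split along the input dimension)
#   'decoder.final_norm.weight'              -> -1  (replicated, not split)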


def combine_tp_tensors(params, key, dim, tensors):
    tp_size = len(tensors)

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        xs = []; zs = []
        for tensor in tensors:
            x, z = torch.split(tensor, [params.mamba_d_inner // tp_size,
                                        params.mamba_d_inner // tp_size], dim=dim)
            xs.append(x); zs.append(z)
        return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim)

    elif 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        xs = []; zs = []; Bs = []; Cs = []; dts = []
        for tensor in tensors:
            x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size,
                                                  params.mamba_d_inner // tp_size,
                                                  (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                                  (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                                  params.mamba2_n_heads // tp_size], dim=dim)
            xs.append(x); zs.append(z); Bs.append(B); Cs.append(C); dts.append(dt)

        for ii in range(len(Bs)):
            Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1]))
            Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1]))
        B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim)
        x = torch.cat(xs, dim=dim); z = torch.cat(zs, dim=dim); dt = torch.cat(dts, dim=dim)

        return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim)

    elif 'mixer.conv1d' in key and params.mamba_version == 2:
        xs = []; Bs = []; Cs = []
        for tensor in tensors:
            x, B, C = torch.split(tensor, [params.mamba_d_inner // tp_size,
                                           (params.mamba2_n_groups // tp_size) * params.mamba_d_state,
                                           (params.mamba2_n_groups // tp_size) * params.mamba_d_state], dim=dim)
            xs.append(x); Bs.append(B); Cs.append(C)

        for ii in range(len(Bs)):
            if 'weight' in key:
                Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1]))
                Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1]))
            elif 'bias' in key:
                Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state))
                Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state))
            else:
                raise Exception("Unknown key {}".format(key))
        B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim)
        x = torch.cat(xs, dim=dim)

        return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim)

    else:
        return torch.cat(tensors, dim=dim)
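
# Note on the Mamba2 packed layouts handled above and below: in_proj packs
# [x, z, B, C, dt] and conv1d packs [x, B, C] along the output dimension, and each
# tensor-parallel rank holds its own contiguous slice of every component. Combining
# or re-splitting therefore has to unpack the components, concatenate/chunk each one
# per rank, and repack them, rather than doing a plain torch.cat/torch.chunk.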


def split_tensor_for_tp(params, key, dim, tensor):
    tp_size = params.target_tp_size
    tensor_sliced = []

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)
        z_sliced = torch.chunk(z, tp_size, dim=dim)
        for (x, z) in zip(x_sliced, z_sliced):
            tensor_sliced.append(torch.cat((x, z), dim=dim))

    elif 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner,
                                              params.mamba2_n_groups * params.mamba_d_state,
                                              params.mamba2_n_groups * params.mamba_d_state,
                                              params.mamba2_n_heads], dim=dim)
        B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1]))
        C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1]))

        B_sliced = torch.chunk(B, tp_size, dim=dim)
        C_sliced = torch.chunk(C, tp_size, dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)
        z_sliced = torch.chunk(z, tp_size, dim=dim)
        dt_sliced = torch.chunk(dt, tp_size, dim=dim)

        tensor_sliced = []
        for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced):
            tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim))

    elif 'mixer.conv1d' in key and params.mamba_version == 2:
        x, B, C = torch.split(tensor, [params.mamba_d_inner,
                                       params.mamba2_n_groups * params.mamba_d_state,
                                       params.mamba2_n_groups * params.mamba_d_state], dim=dim)
        if 'weight' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1]))
            C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1]))
        elif 'bias' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state))
            C = torch.reshape(C, (-1, params.mamba_d_state))
        else:
            raise Exception("Unknown key {}".format(key))

        B_sliced = torch.chunk(B, tp_size, dim=dim)
        C_sliced = torch.chunk(C, tp_size, dim=dim)
        x_sliced = torch.chunk(x, tp_size, dim=dim)

        tensor_sliced = []
        for (x, B, C) in zip(x_sliced, B_sliced, C_sliced):
            tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim))

    else:
        tensor_sliced = torch.chunk(tensor, tp_size, dim=dim)

    return tensor_sliced
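
# A minimal round-trip sanity check for the two helpers above (a sketch with
# hypothetical sizes, not part of the original conversion flow): splitting a
# combined in_proj weight for TP=2 and recombining it should reproduce it exactly.
#
#   import types
#   p = types.SimpleNamespace(mamba_version=2, mamba_d_inner=512, mamba_d_state=128,
#                             mamba2_n_groups=8, mamba2_n_heads=8, target_tp_size=2)
#   key = 'decoder.layers.0.mixer.in_proj.weight'
#   rows = 2 * p.mamba_d_inner + 2 * p.mamba2_n_groups * p.mamba_d_state + p.mamba2_n_heads
#   w = torch.randn(rows, 256)
#   shards = split_tensor_for_tp(p, key, 0, w)
#   assert torch.equal(combine_tp_tensors(p, key, 0, list(shards)), w)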


def finalize_checkpoint(sample_model, model, params, verbose=False):
    # make sure the rest of the checkpoint matches the original (i.e., everything other than 'model')
    reset_iterations = params.reset_iterations

    # checkpoint 'args'
    model['args'] = copy.deepcopy(sample_model['args'])
    model['args'].tensor_model_parallel_size = params.target_tp_size
    model['args'].pipeline_model_parallel_size = params.target_pp_size
    if reset_iterations:
        model['args'].iteration = 0
        model['args'].consumed_valid_samples = 0
        model['args'].consumed_train_samples = 0
        model['args'].train_iters = 0
        model['args'].train_samples = 0

    # checkpoint 'checkpoint_version'
    model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version'])

    # checkpoint 'iteration'
    model['iteration'] = copy.deepcopy(sample_model['iteration'])
    if reset_iterations:
        model['iteration'] = 0

    # checkpoint 'optimizer'
    # ignore

    # checkpoint 'opt_param_scheduler'
    if 'opt_param_scheduler' in sample_model.keys():
        model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler'])

    # checkpoint 'rng_state'
    model['rng_state'] = copy.deepcopy(sample_model['rng_state'])

    # report on argument differences between the original and converted checkpoints
    if verbose:
        original_args = sample_model['args'].__dict__
        final_args = model['args'].__dict__
        for key in original_args:
            if key in final_args:
                if final_args[key] != original_args[key]:
                    print("KEY MISMATCH: {}".format(key))
                    print("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key]))
            else:
                print("KEY MISSING from final: {}, value {}".format(key, original_args[key]))
        print("")
        for key in final_args:
            if key not in original_args:
                print("KEY ADDED to final: {}, value {}".format(key, final_args[key]))

    return model


def main(args):
    print("\n====RUNNING CHECKPOINT CONVERSION====\n")

    args.mamba_d_inner = args.d_model * 2
    args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim

    # get the latest iteration from the tracker file
    tracker_filename = os.path.join(args.load_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        raise Exception("Invalid iteration {} in tracker file {}".format(metastring, tracker_filename))
    out_iteration = iteration if not args.reset_iterations else 0
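
    # The input checkpoint is expected to use the standard Megatron layout, e.g. for
    # TP=2, PP=2 at iteration 1000 (hypothetical paths):
    #   <load_dir>/latest_checkpointed_iteration.txt
    #   <load_dir>/iter_0001000/mp_rank_00_000/model_optim_rng.pt
    #   <load_dir>/iter_0001000/mp_rank_01_000/model_optim_rng.pt
    #   <load_dir>/iter_0001000/mp_rank_00_001/model_optim_rng.pt
    #   <load_dir>/iter_0001000/mp_rank_01_001/model_optim_rng.pt
    # (the trailing '_<pp>' suffix is only present when pipeline_model_parallel_size > 1)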

    # get model directory and model parallel ranks
    input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration))
    input_sub_models = os.listdir(input_model_dir)
    # input_sub_models = sorted(input_sub_models, key=lambda x: int(re.search(r'\d+', x).group()))

    # load one of the model parallel ranks to get arguments
    sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt")
    sample_model = torch.load(sample_model_file)
    print(f"Sample model {sample_model_file} is loaded.\n")

    # input tensor and pipeline parallel sizes
    input_tp_rank = sample_model['args'].tensor_model_parallel_size
    input_pp_rank = sample_model['args'].pipeline_model_parallel_size
    num_layers_per_pipeline_rank = sample_model['args'].num_layers // input_pp_rank

    # construct the full (unsharded) model state dict
    full_model = OrderedDict()

    for pp in range(input_pp_rank):
        print("[INFO] Processing input pipeline rank {}".format(pp))
        tp_models = []
        for tp in range(input_tp_rank):
            dir_name = "mp_rank_{:02d}".format(tp)
            if input_pp_rank > 1:
                dir_name += "_{:03d}".format(pp)
            model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt")

            tp_models.append(torch.load(model_file))
            print(f"Model {model_file} is loaded.")

        if input_tp_rank > 1:
            combined_tp_model = OrderedDict()
            for ii, (key, original_tensor) in enumerate(tp_models[0]['model'].items()):
                if "_extra_state" in key:
                    combined_tp_model[key] = original_tensor
                    continue

                split_dim = get_split_dim(key)
                original_shape = list(original_tensor.shape)
                combined_shape = copy.deepcopy(original_shape)
                combined_shape[split_dim] *= input_tp_rank
                # print("{}, {}, {}".format(ii, key, split_dim))

                if split_dim != -1:
                    # stitch the tensor-parallel shards back together
                    # print("\tshape mismatch: original {}, combined {}".format(original_shape, combined_shape))
                    combined_tensor = combine_tp_tensors(args, key, split_dim,
                                                         [tp_models[jj]['model'][key].cpu() for jj in range(input_tp_rank)])
                    combined_tp_model[key] = combined_tensor
                else:
                    # replicated tensor: copy it as-is
                    combined_tp_model[key] = original_tensor
        else:
            combined_tp_model = tp_models[0]['model']
        # print("Combined tp model: {}".format(combined_tp_model.keys()))

        for ii, (key, original_tensor) in enumerate(combined_tp_model.items()):
            try:
                layer_num = int(re.findall(r'\d+', key)[0])
                new_key = key.replace(str(layer_num), str(layer_num + pp * num_layers_per_pipeline_rank), 1)
            except IndexError:
                new_key = key
            full_model[new_key] = original_tensor
        # print("Combined model: {}".format(full_model.keys()))
    print("\n[INFO] Loaded combined model\n")
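
    # Example of the layer renumbering above (hypothetical key names, 24-layer model
    # loaded with input PP=2, so 12 layers per pipeline rank): the key
    # 'decoder.layers.3.mixer.A_log' read from pipeline rank 1 becomes
    # 'decoder.layers.15.mixer.A_log' (3 + 1 * 12) in the combined model.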

    # sort by layer
    # full_model_sorted = dict(sorted(people.items(), key=lambda item: item[1]))

    # create new split model
    pp_offset = 0
    num_layers_per_pipeline_rank = sample_model['args'].num_layers // args.target_pp_size

    for pp in range(args.target_pp_size):
        print("[INFO] Processing output pipeline rank {}".format(pp))
        tp_models = []
        for ii in range(args.target_tp_size):
            tp_models.append({'model': OrderedDict()})

        for ii, (key, original_tensor) in enumerate(full_model.items()):
            try:
                layer_num = int(re.findall(r'\d+', key)[0])
                if layer_num >= num_layers_per_pipeline_rank * (pp + 1):
                    break
                new_key = key.replace(str(layer_num), str(layer_num - (pp * num_layers_per_pipeline_rank)), 1)
            except IndexError:
                new_key = key

            if ii < pp_offset:
                continue
            else:
                pp_offset += 1

            if "_extra_state" in new_key:
                # copy the extra state unchanged to every tensor-parallel rank
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = original_tensor
                continue

            split_dim = get_split_dim(new_key)
            original_shape = list(original_tensor.shape)
            v0 = original_shape[split_dim]
            split_size = v0 // args.target_tp_size
            split_shape = copy.deepcopy(original_shape)
            split_shape[split_dim] = split_size
            # print("{}, {}, {}".format(ii, new_key, split_dim))

            if split_dim != -1:
                # split the tensor across the target tensor-parallel ranks
                # print("\tshape mismatch: original {}, split {}".format(original_shape, split_shape))
                tensor_sliced = split_tensor_for_tp(args, new_key, split_dim, original_tensor)
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = tensor_sliced[jj]
            else:
                # replicated tensor: copy it to every rank
                for jj in range(args.target_tp_size):
                    tp_models[jj]['model'][new_key] = original_tensor
        # print(tp_models[0]['model'].keys())

        for tp in range(args.target_tp_size):
            dir_name = "mp_rank_{:02d}".format(tp)
            if args.target_pp_size > 1:
                dir_name += "_{:03d}".format(pp)

            model = finalize_checkpoint(sample_model, tp_models[tp], args, verbose=False)

            save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name)
            os.makedirs(save_dir, exist_ok=True)
            model_file = os.path.join(save_dir, "model_optim_rng.pt")
            torch.save(model, model_file)
            print(f"Model {model_file} is saved.")

    # shutil.copyfile(tracker_filename, os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt'))
    tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'w') as f:
        f.write(str(out_iteration))


if __name__ == "__main__":
    # example run command:
    # python hybrid_conversion.py
    # --load-dir mamba2-840m-test/checkpoints/
    # --save-dir mamba2-840m-test-conversion/checkpoints/
    # --target-pp-size 1
    # --target-tp-size 1

    parser = argparse.ArgumentParser()
    parser.add_argument('--load-dir', type=str)
    parser.add_argument('--save-dir', type=str)
    parser.add_argument('--target-tp-size', type=int, default=1)
    parser.add_argument('--target-pp-size', type=int, default=1)
    parser.add_argument('--reset-iterations', action='store_true')

    parser.add_argument('--d-model', type=int, default=4096)
    parser.add_argument('--mamba-version', type=int, default=2)
    parser.add_argument('--mamba-d-state', type=int, default=128)
    parser.add_argument('--mamba2-n-groups', type=int, default=8)
    parser.add_argument('--mamba2-head-dim', type=int, default=64)

    args = parser.parse_args()

    main(args)