# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Sample Generate GPT."""

import functools
import os
import sys
from pathlib import Path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))

import modelopt.torch.quantization as mtq
import torch
from datasets import load_dataset
from modelopt.torch.utils.distributed import set_data_parallel_group, set_tensor_parallel_group
from tqdm import tqdm

# [ModelOpt]: changing the default model provider to the ModelOpt version
from megatron.core import mpu
from megatron.inference.arguments import add_modelopt_args
from megatron.inference.checkpointing import load_modelopt_checkpoint
from megatron.inference.gpt.model_provider import model_provider
from megatron.inference.text_generation import generate_and_post_process
from megatron.training import get_args, get_model, initialize_megatron
from megatron.training.checkpointing import save_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model

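# Quantization formats supported by this example. The key is matched against
# args.export_quant_cfg (expected to be registered by add_modelopt_args) to select
# the corresponding ModelOpt quantization config.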
QUANT_CFG_CHOICES = {
    "int8": mtq.INT8_DEFAULT_CFG,
    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
    "int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
}

def add_trtllm_ckpt_export_args(parser):
    """Add additional arguments for TensorRT-LLM checkpoint export."""
    group = parser.add_argument_group(title="trtllm")

    group.add_argument(
        "--export-dir", type=str, help="The output TensorRT-LLM checkpoint directory."
    )
    group.add_argument(
        "--decoder", type=str, choices=["gptnext", "llama"], help="The decoder type of the model."
    )
    group.add_argument(
        "--inference-tensor-parallel",
        type=int,
        help="Tensor parallel size at inference time; can differ from the training config.",
        default=1,
    )

def add_text_generate_ptq_args(parser):
    """Add additional arguments for ModelOpt text generation PTQ."""
    group = parser.add_argument_group(title="ModelOpt text generation PTQ")
    group.add_argument(
        "--calib-dataset",
        type=str,
        default="cnn_dailymail",
        help="Calibration dataset from HuggingFace datasets.",
    )
    group.add_argument(
        "--calib-batch-size", type=int, default=4, help="Batch size to use for PTQ calibration."
    )
    group.add_argument(
        "--calib-size", type=int, default=512, help="Number of samples to use for PTQ calibration."
    )
    parser.add_argument(
        "--prompts",
        type=str,
        default=(
            "Born in north-east France, Soyer trained as a|Born in California, Soyer trained as a"
        ),
        help="Input texts. Please use | to separate different batches.",
    )
    add_modelopt_args(parser)
    add_trtllm_ckpt_export_args(parser)
    return parser

def get_calib_dataloader(
    data="cnn_dailymail", batch_size=4, calib_size=512, max_sequence_length=512
):
    """Yield batches of raw text drawn from a HuggingFace dataset for PTQ calibration."""
    if data == "pileval":
        dataset = load_dataset(
            "json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train"
        )
        text_column = "text"
    elif data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        text_column = "article"
    else:
        raise ValueError(f"Unsupported calibration dataset: {data}")

    calib_size = max(min(len(dataset), calib_size), batch_size)
    for i in range(calib_size // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
        for j in range(len(batch)):
            # Truncate each sample (by characters) to bound the calibration prompt length.
            batch[j] = batch[j][:max_sequence_length]
        yield batch

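# For example (illustrative only), get_calib_dataloader("cnn_dailymail", batch_size=4,
# calib_size=8) yields two batches, each a list of four articles truncated to 512 characters.
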
if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_text_generate_ptq_args,
        args_defaults={
            "tokenizer_type": "GPT2BPETokenizer",
            "no_load_rng": True,
            "no_load_optim": True,
        },
    )

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text generation.")
    args.exit_on_missing_checkpoint = True

    # Set up the model and load the checkpoint.
    # [ModelOpt]: make sure that output logits are allgathered.
    text_generation_model_provider = functools.partial(model_provider, parallel_output=False)
    model = get_model(text_generation_model_provider, wrap_with_ddp=False)

    if args.load is not None:
        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
        print_rank_0("Done loading checkpoint")

    # Remove the virtual pipeline parallel and other wrappers.
    assert len(model) == 1, "Above condition should have caught this"
    unwrapped_model = unwrap_model(model)

    all_prompts = args.prompts.split("|")

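    # Default forward loop: generate 128 tokens for each "|"-separated prompt. It is also
    # reused after quantization as a quick sanity check of the quantized model.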
    def custom_prompt_forward_loop_func(model):
        for prompt in tqdm(all_prompts):
            if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
                (
                    prompts_plus_generations,
                    prompts_plus_generations_segments,
                    logprobs,
                    _,
                ) = generate_and_post_process(
                    model,
                    prompts=[prompt],
                    tokens_to_generate=128,
                    return_output_log_probs=True,
                    temperature=1.0,
                )
                print_rank_0(prompts_plus_generations)
            else:
                generate_and_post_process(model)

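    # Calibration forward loop: run calibration batches through the model with
    # tokens_to_generate=0 so the ModelOpt calibrators can observe activations without
    # paying for autoregressive generation.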
    def hf_dataset_forward_loop_func(model):
        dataloader = get_calib_dataloader(
            args.calib_dataset, args.calib_batch_size, args.calib_size
        )
        for prompts in tqdm(dataloader, total=args.calib_size // args.calib_batch_size):
            if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
                (
                    prompts_plus_generations,
                    prompts_plus_generations_segments,
                    logprobs,
                    _,
                ) = generate_and_post_process(
                    model,
                    prompts=prompts,
                    tokens_to_generate=0,
                    return_output_log_probs=True,
                    temperature=1.0,
                )
            else:
                generate_and_post_process(model)

    ptq_forward_loop_func = custom_prompt_forward_loop_func
    if args.calib_dataset is not None:
        ptq_forward_loop_func = hf_dataset_forward_loop_func

    # Set the data parallel and tensor parallel groups for ModelOpt.
    set_data_parallel_group(mpu.get_data_parallel_group())
    set_tensor_parallel_group(mpu.get_tensor_model_parallel_group())

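    # [ModelOpt]: resolve the quantization config. The output layer is kept unquantized,
    # and for AWQ formats the weight block size is set to 128.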
    if args.export_quant_cfg in QUANT_CFG_CHOICES:
        mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg]
        if "*output_layer*" not in mtq_config["quant_cfg"]:
            mtq_config["quant_cfg"]["*output_layer*"] = {"enable": False}
        if "awq" in args.export_quant_cfg:
            weight_quantizer = mtq_config["quant_cfg"]["*weight_quantizer"]  # type: ignore
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            weight_quantizer["block_sizes"][-1] = 128
        print_rank_0("Quantizing the model...")
        mtq.quantize(unwrapped_model[0], mtq_config, ptq_forward_loop_func)

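    # Generate from the prompts once more as a quick check of the model after (optional) quantization.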
    custom_prompt_forward_loop_func(model[0])

    if args.save is not None and args.export_quant_cfg in QUANT_CFG_CHOICES:
        save_checkpoint(1, unwrapped_model, None, None, 0)

    print_rank_0(f"Fake Quantized Model:\n {unwrapped_model[0]}")

    if args.export_dir:
        assert args.decoder in ["gptnext", "llama"], f"Decoder type {args.decoder} not supported."
        Path(args.export_dir).mkdir(parents=True, exist_ok=True)
        print_rank_0("Exporting TensorRT-LLM checkpoints.")

        from modelopt.torch.export import export_tensorrt_llm_checkpoint

        # In TRT LLM, squared relu activation does not support bf16. So we use fp16 by default.
        export_tensorrt_llm_checkpoint(
            unwrapped_model[0],
            args.decoder,
            torch.bfloat16 if args.bf16 else torch.float16,
            export_dir=args.export_dir,
            inference_tensor_parallel=args.inference_tensor_parallel,
            inference_pipeline_parallel=1,
            use_nfs_workspace=True,
        )

        print_rank_0(f"TensorRT-LLM checkpoints saved to {args.export_dir}")