# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import json
import random
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision import transforms as T
from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage

from megatron.core import mpu
from megatron.energon import Batch, CaptioningSample, DefaultTaskEncoder, OCRSample, VQASample
from megatron.energon.transforms import CustomTransform, MergeTransform
from megatron.training import get_args
from megatron.training.tokenizer import build_tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# NLTK is optional; it is only needed when split_sentences is enabled.
try:
    import nltk
    nltk_available = True
except ImportError:
    nltk_available = False

# ImageNet's mean and std.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]

def convert_to_rgb(image):
    return image.convert("RGB")


def _transform_train(img_h, img_w):
    return Compose([
        ToPILImage(),
        RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0)),
        convert_to_rgb,
    ])

def _transform_train_aug(img_h, img_w):
    return Compose([
        ToPILImage(),
        RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0)),
        convert_to_rgb,
        # Note: torchvision's RandAugment takes (num_ops, magnitude, ...); the
        # `isPIL`/`augs` kwargs belong to a BLIP-style custom RandomAugment
        # (op list: Identity, AutoContrast, Brightness, Sharpness, Equalize,
        # ShearX/Y, TranslateX/Y, Rotate) and would raise a TypeError here, so
        # the default torchvision op set is used instead.
        RandAugment(2, 5),
    ])

def _transform_test(img_h, img_w):
    return Compose([
        ToPILImage(),
        Resize((img_h, img_w)),
        convert_to_rgb,
    ])

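
# Example (sketch, not part of the pipeline): applying the helpers above to a
# dummy HWC uint8 array, mirroring how get_visual_transform() uses them.
#
#   import numpy as np
#   dummy = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
#   train_img = _transform_train(336, 336)(dummy)  # PIL.Image, 336x336, RGB
#   test_img = _transform_test(336, 336)(dummy)    # PIL.Image, 336x336, RGB
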
class RandomResize(CustomTransform):
    """Resizes the image by a random scale factor in the given interval, but at most max_size."""

    def __init__(self, min_scale: float, max_scale: float, max_size: int):
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._max_size = max_size

    def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
        scale = random.uniform(self._min_scale, self._max_scale)
        new_size = tuple(int(x * scale) for x in dst_size)

        if max(new_size) > self._max_size:
            scale = self._max_size / max(new_size)
            new_size = tuple(int(x * scale) for x in dst_size)

        matrix = self.scale(scale, scale) @ matrix
        dst_size = np.array(new_size, dtype=dst_size.dtype)

        return matrix, dst_size, (self.__class__.__name__, scale)

class RandomResizeLongEdge(CustomTransform):
    """Resizes the image's longer edge to a random length between min_size and max_size pixels."""

    def __init__(self, min_size: int, max_size: int):
        self._min_size = min_size
        self._max_size = max_size

    def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
        new_long = random.randint(self._min_size, self._max_size)
        if dst_size[0] > dst_size[1]:  # h > w
            new_w, new_h = int(new_long * dst_size[1] / dst_size[0]), new_long
        else:  # w >= h
            new_w, new_h = new_long, int(new_long * dst_size[0] / dst_size[1])

        new_size = (new_h, new_w)
        matrix = self.scale(new_w / dst_size[1], new_h / dst_size[0]) @ matrix
        dst_size = np.array(new_size, dtype=dst_size.dtype)

        return matrix, dst_size, (self.__class__.__name__, new_size)

class RandomPad(CustomTransform):
    """Pads the image to the given size, randomly choosing the position of the image within
    the new, larger image. If the image is already larger than the given size in a direction,
    it is not padded in that direction."""

    def __init__(self, size: Tuple[int, int]):
        self._new_size = size  # (h, w)

    def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]:
        h_pad = max(self._new_size[0] - dst_size[0], 0)
        w_pad = max(self._new_size[1] - dst_size[1], 0)

        if h_pad == 0 and w_pad == 0:
            return matrix, dst_size, (self.__class__.__name__, None)
        else:
            # TODO: fix me
            # top = random.randint(0, h_pad)
            # left = random.randint(0, w_pad)
            top = 0
            left = 0

            matrix = self.translate(left, top) @ matrix
            dst_size = np.array(self._new_size, dtype=dst_size.dtype)
            return matrix, dst_size, (self.__class__.__name__, (top, left))

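
# How the custom transforms compose (sketch): each apply_transform() describes
# its effect as a 3x3 affine matrix plus an updated destination size, rather
# than touching pixels. MergeTransform (used below) chains a list of them so
# the pixel-level warp only happens once; conceptually:
#
#   matrix = translate(left, top) @ scale(s, s) @ identity
#   final_image = warp(image, matrix, dst_size)
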
def _get_ocr_document_visual_transform(IMG_H=1024, IMG_W=1024):
    document_visual_transform = T.Compose(
        [
            MergeTransform(
                [
                    # T.RandomResizedCrop(size=FINAL_SIZE, scale=(0.5, 1.0), ratio=(0.8, 1.2)),
                    RandomResizeLongEdge(960, 1008),  # Note: 1008 comes from list(range(960, 1024, 16))[-1]
                    T.RandomRotation(5, interpolation=T.InterpolationMode.BILINEAR),
                    T.RandomPerspective(distortion_scale=0.1, p=0.1),
                    RandomPad((IMG_H, IMG_W)),
                ]
            ),
            T.ColorJitter(brightness=(0.8, 1.2), contrast=(0.7, 1.0)),
            T.RandomGrayscale(p=0.5),
            T.RandomInvert(p=0.5),
            T.RandomAdjustSharpness(sharpness_factor=0.0, p=0.5),
            T.RandomAdjustSharpness(sharpness_factor=2.0, p=0.5),
            # LogImage(),
            # T.ToTensor(),
            # T.Normalize(IMAGE_MEAN, IMAGE_STD),
        ]
    )
    return document_visual_transform

def _get_ocr_document_identity_transform(IMG_H=1024, IMG_W=1024):
    long_edge = max(IMG_H, IMG_W)
    document_identity_transform = T.Compose(
        [
            MergeTransform(
                [
                    RandomResizeLongEdge(long_edge, long_edge),
                    RandomPad((long_edge, long_edge)),
                ]
            )
        ]
    )
    return document_identity_transform

def _get_ocr_paragraph_visual_transform(IMG_H=1024, IMG_W=1024):
    paragraph_visual_transform = T.Compose(
        [
            MergeTransform(
                [
                    # T.RandomResizedCrop(size=FINAL_SIZE, scale=(0.5, 1.0), ratio=(0.8, 1.2)),
                    RandomResize(0.5, 2.0, min(IMG_H, IMG_W)),  # FINAL_SIZE
                    T.RandomRotation(1, interpolation=T.InterpolationMode.BILINEAR),
                    T.RandomPerspective(distortion_scale=0.1, p=0.1),
                    RandomPad((IMG_H, IMG_W)),
                ]
            ),
            T.ColorJitter(brightness=(0.8, 1.2), contrast=(0.7, 1.0)),
            T.RandomGrayscale(p=0.5),
            T.RandomInvert(p=0.5),
            # T.RandomAdjustSharpness(sharpness_factor=0.0, p=0.5),
            # T.RandomAdjustSharpness(sharpness_factor=2.0, p=0.5),
            # LogImage(),
            # T.ToTensor(),
            # T.Normalize(IMAGE_MEAN, IMAGE_STD),
        ]
    )
    return paragraph_visual_transform

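
# Example (sketch): the OCR transforms consume a PIL image and return one that
# is already resized and padded to the target canvas; normalization to tensors
# happens later in TaskEncoder.encode_ocr.
#
#   from PIL import Image
#   page = Image.new("RGB", (800, 1100), "white")
#   out = _get_ocr_document_visual_transform(1024, 1024)(page)  # 1024x1024
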
# Type for a single encoded sample, produced by encode_sample() (before batching).
@dataclass
class ImageTaskSample:
    __key__: str
    __subflavors__: Dict
    # (c, h, w)
    img: torch.Tensor
    text: np.ndarray
    prompt_len: np.int64
    img_clip: Optional[torch.Tensor] = None

# Type for the batched data, produced by batch(); encode_batch() converts it
# to a plain dict.
@dataclass
class ImageTaskBatch(Batch):
    __keys__: List[str]
    __subflavors__: List[Dict]
    # (n, c, h, w)
    img: torch.Tensor
    # (n, seq_len)
    text: torch.Tensor
    # (n,)
    prompt_len: torch.Tensor
    # (n, c, h, w)
    img_clip: Optional[torch.Tensor] = None

class IdentitySplitter(object):
    def tokenize(self, *text):
        return text

class Tokenizer:
    def __init__(self):
        args = get_args()
        self.args = args
        self.IMAGE_TOKEN_INDEX = -200
        self.initializer()

    def initializer(self):
        # Use the Tokenizer class as a container for global data.
        Tokenizer.tokenizer = build_tokenizer(self.args)
        if hasattr(Tokenizer.tokenizer, 'eod'):
            self.eod_token = Tokenizer.tokenizer.eod
        elif hasattr(Tokenizer.tokenizer, 'eos_id'):
            self.eod_token = Tokenizer.tokenizer.eos_id
        else:
            raise AttributeError('No eod token found in Tokenizer')
        self.split_token = 313131

        if (
            hasattr(self.args, "split_sentences") and self.args.split_sentences
        ):  # defaults to False
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            library = "tokenizers/punkt/{}.pickle".format("english")
            # print("loading: " + library)
            splitter = nltk.load(library)
            if self.args.keep_newlines:
                # This prevents punkt from eating newlines after sentences.
                # CustomLanguageVars is assumed to be provided elsewhere (see
                # the definition in Megatron's tools/preprocess_data.py).
                Tokenizer.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                    train_text=splitter._params, lang_vars=CustomLanguageVars()
                )
            else:
                Tokenizer.splitter = splitter
        else:
            Tokenizer.splitter = IdentitySplitter()

    def __call__(self, text: str, padded: bool = True):
        # Tokenizes the first sentence of `text`; returns a list of token ids,
        # not a torch.Tensor.
        sentence = Tokenizer.splitter.tokenize(text)[0]
        sentence = Tokenizer.tokenizer.tokenize(sentence)
        return sentence

    def pad(self, content, seq_len=1024):
        out = np.pad(content, pad_width=(0, max(0, seq_len - len(content))),
                     mode='constant', constant_values=self.eod_token)
        return out

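
# Example (sketch): tokenize-and-pad round trip. Tokenizer() reads Megatron's
# global args via get_args(), so this only works inside an initialized
# Megatron run; the prompt string here is arbitrary.
#
#   tok = Tokenizer()
#   ids = tok("Describe the image.")   # token ids for the first sentence
#   padded = tok.pad(ids, seq_len=32)  # right-padded with the eod token
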
class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatch, dict]):
    """A task encoder for captioning, VQA and OCR samples."""

    def __init__(self):
        # Specify the batch_type for default batching (batching is performed here "manually" by
        # overwriting the `batch` method).
        super().__init__()

        self.args = get_args()

        self.tokenizer = Tokenizer()
        with open(self.args.prompt_path) as f:
            self.manual_prompts = json.load(f)
        self.seq_len = self.args.seq_length
        self.txt_to_token_dict = {}

        self.img_h, self.img_w = self.args.img_h, self.args.img_w

        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)

        self.ocr_document_visual_transform = _get_ocr_document_visual_transform(self.img_h, self.img_w)
        self.ocr_document_identity_transform = _get_ocr_document_identity_transform(self.img_h, self.img_w)
        self.ocr_paragraph_visual_transform = _get_ocr_paragraph_visual_transform(self.img_h, self.img_w)

    def get_visual_transform(self, img_sample, sample_augmentation=False):
        raw_h, raw_w = img_sample.shape[0], img_sample.shape[1]

        # Scale the longer edge to the target size, preserving the aspect ratio.
        ratio = float(max(self.img_h, self.img_w)) / max(raw_h, raw_w)
        scaled_h, scaled_w = int(raw_h * ratio + 0.5), int(raw_w * ratio + 0.5)

        # Check whether this sample should be augmented.
        if sample_augmentation:
            # Augment only if augmentation is also enabled globally in args.
            if self.args.aug:
                visual_transform = _transform_train_aug(scaled_h, scaled_w)
            else:
                visual_transform = _transform_train(scaled_h, scaled_w)
        else:
            visual_transform = _transform_test(scaled_h, scaled_w)

        img = visual_transform(img_sample)

        # Normalize pixel values.
        img = (torch.Tensor(np.array(img)).permute(2, 0, 1) - self.pixel_mean) / self.pixel_std

        # Pad to the target image size.
        delta_h, delta_w = self.img_h - scaled_h, self.img_w - scaled_w
        img = torch.nn.functional.pad(img, (0, delta_w, 0, delta_h))

        return img

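
    # Worked example (sketch): an 800x1200 (h x w) sample with img_h = img_w =
    # 1024 gives ratio = 1024 / 1200 ≈ 0.853, so (scaled_h, scaled_w) =
    # (683, 1024); after normalization the tensor is padded bottom/right to
    # (3, 1024, 1024).
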
    def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample]):
        if isinstance(sample, OCRSample):
            yield self.encode_ocr(sample)
        elif isinstance(sample, CaptioningSample):
            yield self.encode_captioning(sample)
        elif isinstance(sample, VQASample):
            yield self.encode_vqa(sample)
        else:
            raise NotImplementedError('Sample format not supported')

    def encode_captioning(self, sample: CaptioningSample):
        sample_augmentation = sample.__subflavors__["augmentation"] == True

        img = self.get_visual_transform(np.array(sample.image), sample_augmentation=sample_augmentation)

        # Randomly select a prompt.
        if 'CaptioningDetailed' in sample.__subflavors__["type"]:
            prompt_idx = np.random.randint(len(self.manual_prompts["CaptioningDetailed"]["raw"]))
            cur_prompt = self.manual_prompts["CaptioningDetailed"]["raw"][prompt_idx]
        else:
            prompt_idx = np.random.randint(len(self.manual_prompts["Captioning"]["raw"]))
            cur_prompt = self.manual_prompts["Captioning"]["raw"][prompt_idx]

        if cur_prompt not in self.txt_to_token_dict:
            self.txt_to_token_dict[cur_prompt] = self.tokenizer(cur_prompt)
        cur_prompt = self.txt_to_token_dict[cur_prompt]
        prompt_len = len(cur_prompt)

        caption = sample.caption
        if 'SplitByLine' in sample.__subflavors__["type"]:
            # caption = re.sub(r"\n+", "\n", caption)
            caption_list = caption.split('\n')
            caption_list = [caption for caption in caption_list if caption.strip() != '']
            caption = np.random.choice(caption_list)

        # Validate before tokenizing.
        if len(caption.strip()) == 0:
            raise RuntimeError('Empty string in caption!')
        caption_token = self.tokenizer(caption.strip())

        seq_len = self.seq_len + 4
        text_sample = np.concatenate([[self.tokenizer.IMAGE_TOKEN_INDEX], cur_prompt, caption_token])
        text_sample = self.tokenizer.pad(text_sample, seq_len)
        text_sample = text_sample[:seq_len]

        return ImageTaskSample(
            __key__=sample.__key__,
            __subflavors__=sample.__subflavors__,
            img=img,
            text=text_sample,
            prompt_len=prompt_len,
        )

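
    # Resulting token layout (sketch):
    #   [IMAGE_TOKEN_INDEX, prompt tokens..., caption tokens..., eod padding...]
    # truncated/padded to seq_length + 4; prompt_len counts only the prompt.
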
    def encode_vqa(self, sample: VQASample):
        no_image_flag = '-noimage' in sample.__key__

        if 'pretrain' in sample.__key__:
            task_name = 'pretrain'
        else:
            task_name = sample.__key__.split("/")[0]

        sample_augmentation = sample.__subflavors__["augmentation"] == True

        if no_image_flag:
            img = torch.from_numpy(np.array([0]).astype(np.float32))
        else:
            img = self.get_visual_transform(np.array(sample.image), sample_augmentation=sample_augmentation)

        if "<image>" in sample.context:
            sample.context = sample.context.replace("<image>", "")

        if task_name != 'pretrain' and sample.context[-1:] != "\n":
            sample.context = sample.context + "\n"

        question = sample.context

        if isinstance(sample.answers, list):
            answer_list = sample.answers
            weight_list = np.array(sample.answer_weights).astype(np.float32)
            weight_list = weight_list / np.sum(weight_list)
            answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0]
            answer = answer_list[answer_idx]
        else:
            answer = sample.answers

        question_token = self.tokenizer.tokenizer.instruct_tokenize(question)
        answer_token = self.tokenizer(answer)
        prompt_len = len(question_token)

        seq_len = self.seq_len + 4

        text_sample = np.concatenate([[self.tokenizer.IMAGE_TOKEN_INDEX], question_token, answer_token])
        text_sample = self.tokenizer.pad(text_sample, seq_len)
        # Truncate for consistency with the other encoders; without this,
        # over-long samples would make the stack in batch() ragged.
        text_sample = text_sample[:seq_len]

        return ImageTaskSample(
            __key__=sample.__key__,
            __subflavors__=sample.__subflavors__,
            img=img,
            text=text_sample,
            prompt_len=prompt_len,
        )

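
    # Example (sketch): weighted answer sampling above. With answers
    # ["cat", "dog"] and answer_weights [3, 1], the normalized weights are
    # [0.75, 0.25], so np.random.choice returns "cat" 3 times out of 4 on
    # average.
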
    def encode_ocr(self, sample: OCRSample) -> ImageTaskSample:
        if sample.__subflavors__["type"] == "document":
            visual_transform = self.ocr_document_visual_transform
        elif sample.__subflavors__["type"] == "paragraph":
            visual_transform = self.ocr_paragraph_visual_transform
        elif sample.__subflavors__["augmentation"] == False:
            visual_transform = self.ocr_document_identity_transform
        else:
            raise ValueError(f"Unknown subflavor {sample.__subflavors__}")

        if sample.words_boxes is not None and sample.words_boxes.shape[1] >= 5:
            # Boxes with confidence below 0.9 are blacked out in the image and
            # dropped from the target text.
            filter_words_mask = sample.words_boxes[:, 4] < 0.9
            filter_boxes = sample.words_boxes[filter_words_mask, :4]
            for x, y, x2, y2 in filter_boxes:
                if isinstance(sample.image, Image.Image):
                    draw = ImageDraw.Draw(sample.image)
                    draw.rectangle([int(x), int(y), int(x2), int(y2)], fill=0)
                else:
                    sample.image[:, int(y) : int(y2) + 1, int(x) : int(x2) + 1] = 0

            text = " ".join(
                text for skip, text in zip(filter_words_mask, sample.words_text) if not skip
            )
        else:
            text = " ".join(sample.text.splitlines())
            match = re.search(r'"text_sequence": "(.*?)"', text)
            if match:
                text = match.group(1)

        img = visual_transform(sample.image)
        img_clip = None

        img = (torch.Tensor(np.array(img)).permute(2, 0, 1) - self.pixel_mean) / self.pixel_std
        img = torch.nn.functional.pad(img, (0, self.img_w - img.shape[2], 0, self.img_h - img.shape[1]))

        # Randomly select a prompt.
        prompt_idx = np.random.randint(len(self.manual_prompts["OCR"]["raw"]))
        cur_prompt = self.manual_prompts["OCR"]["raw"][prompt_idx]

        if cur_prompt not in self.txt_to_token_dict:
            self.txt_to_token_dict[cur_prompt] = self.tokenizer(cur_prompt)
        cur_prompt = self.txt_to_token_dict[cur_prompt]

        text_sample = self.tokenizer(text)
        prompt_len = len(cur_prompt)
        seq_len = self.seq_len + 4

        text_sample = np.concatenate([cur_prompt, text_sample])
        text_sample = self.tokenizer.pad(text_sample, seq_len=seq_len)
        text_sample = text_sample[:seq_len]

        return ImageTaskSample(
            __key__=sample.__key__,
            __subflavors__=sample.__subflavors__,
            img=img,
            img_clip=img_clip,
            text=text_sample,
            prompt_len=prompt_len,
        )

    def batch(self, samples: List[ImageTaskSample]) -> ImageTaskBatch:
        batch = ImageTaskBatch(
            __keys__=[s.__key__ for s in samples],
            __subflavors__=[s.__subflavors__ for s in samples],
            img=torch.stack([s.img for s in samples]),
            text=torch.from_numpy(np.stack([s.text for s in samples], axis=0).astype(np.int64)),
            prompt_len=torch.from_numpy(np.array([s.prompt_len for s in samples], dtype=np.int64)),
        )
        return batch

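
    # Example (sketch): for 8 samples with 1024x1024 images and seq_length
    # 1024, batch() yields img (8, 3, 1024, 1024), text (8, 1028) int64 and
    # prompt_len (8,) int64; this assumes no '-noimage' VQA samples, whose
    # 1-element img tensors would not stack with full images.
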
    def encode_batch(self, batch: ImageTaskBatch) -> dict:
        raw = dataclasses.asdict(batch)
        del raw["__subflavors__"]
        return raw

def print_error_handler(exc: Exception, key: Optional[str]):
    print(
        f"The following exception occurred in the dataloader for sample {key} and is skipped",
        file=sys.stderr,
    )
    traceback.print_exc()
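
# Example (sketch, assuming megatron.energon's get_train_dataset interface;
# check the energon docs for the exact signature in your version):
#
#   from megatron.energon import WorkerConfig, get_train_dataset
#
#   ds = get_train_dataset(
#       "/path/to/dataset",
#       batch_size=8,
#       shuffle_buffer_size=100,
#       max_samples_per_sequence=100,
#       task_encoder=TaskEncoder(),
#       worker_config=WorkerConfig.default_worker_config(),
#       handler=print_error_handler,  # log and skip bad samples
#   )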