import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu

from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_DATA_RATIO,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token


@dataclass
class DataArguments:
    lazy_preprocess: bool = False
    is_multimodal: bool = True
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = field(default=None)
    dataset_use: str = field(default="temp")
    min_dynamic_patch: int = 1
    max_dynamic_patch: int = 12
    use_thumbnail: bool = True
def preprocess_multimodal( | |
sources: Sequence[str], | |
data_args: DataArguments, | |
image_token_num=1, | |
patch_num=[1], | |
audio_lens: int = 0, | |
inserted_id=None, | |
) -> Dict: | |
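    """Normalize multimodal placeholder tokens inside raw conversations.

    Image/video placeholders that trail a turn are moved to its front, each
    image token is expanded to `patch_num[k]` copies and each video token to
    `image_token_num` image-token copies, and every assistant turn is prefixed
    with a state marker: ☞ when the preceding turn contains <audio>, ☜
    otherwise, and ☟ for the turn at `inserted_id`.
    """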
is_multimodal = data_args.is_multimodal | |
if not is_multimodal: | |
return sources | |
k_img_ph = 0 | |
for source in sources: | |
if inserted_id is not None: | |
assert source[inserted_id]["from"] == "gpt" | |
for i, sentence in enumerate(source): | |
if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]: | |
sentence["value"] = ( | |
sentence["value"] | |
.replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN) | |
.strip() | |
) | |
sentence["value"] = ( | |
sentence["value"] | |
.replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN) | |
.strip() | |
) | |
if sentence["value"].endswith(DEFAULT_IMAGE_TOKEN): | |
IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN) | |
sentence["value"] = ( | |
sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip() | |
) | |
sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"] | |
sentence["value"] = sentence["value"].strip() | |
if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN): | |
VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN) | |
sentence["value"] = ( | |
sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip() | |
) | |
sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"] | |
sentence["value"] = sentence["value"].strip() | |
if "mmtag" in conversation_lib.default_conversation.version: | |
sentence["value"] = sentence["value"].replace( | |
DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>" | |
) | |
                IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
                    sentence["value"] = (
                        sentence["value"]
                        .replace(
                            DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
                            DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
                        )
                        .strip()
                    )
replace_token, vid_replace_token, audio_replace_token = ( | |
DEFAULT_IMAGE_TOKEN, | |
DEFAULT_IMAGE_TOKEN * image_token_num, | |
DEFAULT_AUDIO_TOKEN, | |
) # * audio_lens | |
if DEFAULT_IMAGE_TOKEN in sentence["value"]: | |
replace_token = DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph] | |
k_img_ph += 1 | |
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n") | |
sentence["value"] = sentence["value"].replace( | |
DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n" | |
) | |
sentence["value"] = sentence["value"].replace( | |
DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token | |
) | |
sentence["value"] = sentence["value"].replace("\n\n", "\n") | |
#if i == inserted_id: | |
# assert sentence["from"] == "gpt" | |
# sentence["value"] = "<2>" + sentence["value"] | |
#elif sentence["from"] == "gpt": | |
# if "<audio>" in source[i - 1]["value"]: | |
# sentence["value"] = "<1>" + sentence["value"] | |
# else: | |
# sentence["value"] = "<3>" + sentence["value"] | |
if i == inserted_id:# ☞ ☟ ☜ | |
assert sentence["from"] == "gpt" | |
sentence["value"] = "☟" + sentence["value"] | |
elif sentence["from"] == "gpt": | |
if "<audio>" in source[i - 1]["value"]: | |
sentence["value"] = "☞" + sentence["value"] | |
else: | |
sentence["value"] = "☜" + sentence["value"] | |
# print(patch_num) | |
# print(sum(patch_num)) | |
#print(sources) | |
#import pdb; pdb.set_trace() | |
return sources | |
def preprocess_mixtral_zh( | |
sources, | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False, | |
has_audio: bool = False, | |
end_tag: bool = True, | |
) -> Dict: | |
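    """Render conversations with the MixtralZh template, tokenize them, and
    build labels in which only the assistant replies are supervised; all other
    tokens (system/user text and separators) are set to IGNORE_INDEX."""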
conv = conversation_lib.default_conversation.copy() | |
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} | |
# Apply prompt templates | |
conversations = [] | |
for i, source in enumerate(sources): | |
if roles[source[0]["from"]] != conv.roles[0]: | |
# Skip the first one if it is not from human | |
source = source[1:] | |
conv.messages = [] | |
for j, sentence in enumerate(source): | |
role = roles[sentence["from"]] | |
assert role == conv.roles[j % 2], f"{i}" | |
conv.append_message(role, sentence["value"]) | |
conversations.append(conv.get_prompt()) | |
# Tokenize conversations | |
if not end_tag: | |
conversations[0] = conversations[0][:-4] | |
if has_image and not has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif not has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
else: | |
input_ids = tokenizer( | |
conversations, | |
return_tensors="pt", | |
padding="longest", | |
max_length=tokenizer.model_max_length, | |
truncation=True, | |
).input_ids | |
#print(f'end_tag: {end_tag}') | |
#print(conversations) | |
#print(input_ids) | |
#import pdb; pdb.set_trace() | |
targets = input_ids.clone() | |
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralZh | |
# Mask targets | |
sep = conv.sep + "\n" + conv.roles[1] + ":" | |
sep2_2 = "\n" + conv.roles[0] + ":" | |
sep2 = conv.sep2 + sep2_2 | |
for conversation, target in zip(conversations, targets): | |
total_len = int(target.ne(tokenizer.pad_token_id).sum()) | |
rounds = conversation.split(sep2) | |
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:] | |
cur_len = 1 | |
end_token_cnt = 0 | |
target[:cur_len] = IGNORE_INDEX | |
for i, rou in enumerate(rounds): | |
if rou == "": | |
break | |
if i > 0: | |
rou = sep2_2 + rou | |
parts = rou.split(sep) | |
if len(parts) != 2: | |
break | |
parts[0] += sep | |
if has_image and not has_audio: | |
round_len = len(tokenizer_image_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 | |
elif has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
elif not has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
else: | |
round_len = len(tokenizer(rou).input_ids) | |
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 | |
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX | |
end_token_cnt += 1 | |
cur_len += round_len | |
cur_len = cur_len - 1 | |
target[cur_len:] = IGNORE_INDEX | |
if tokenizer.pad_token_id == tokenizer.eos_token_id: | |
cur_len -= end_token_cnt | |
if cur_len < tokenizer.model_max_length: | |
if cur_len != total_len: | |
target[:] = IGNORE_INDEX | |
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") | |
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}") | |
return dict( | |
input_ids=input_ids, | |
labels=targets, | |
) | |
def preprocess_mixtral_two( | |
sources, | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False, | |
has_audio: bool = False, | |
end_tag: bool = True, | |
modality: str = "lang", | |
) -> Dict: | |
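    """Same masking scheme as `preprocess_mixtral_zh`, but for the MixtralTwo
    separator style; prompts are rendered with `conv.get_prompt(modality)`."""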
conv = conversation_lib.default_conversation.copy() | |
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} | |
# Apply prompt templates | |
conversations = [] | |
for i, source in enumerate(sources): | |
if roles[source[0]["from"]] != conv.roles[0]: | |
# Skip the first one if it is not from human | |
source = source[1:] | |
conv.messages = [] | |
for j, sentence in enumerate(source): | |
role = roles[sentence["from"]] | |
assert role == conv.roles[j % 2], f"{i}" | |
conv.append_message(role, sentence["value"]) | |
conversations.append(conv.get_prompt(modality)) | |
# print(conversations) | |
# import pdb; pdb.set_trace() | |
# Tokenize conversations | |
if not end_tag: | |
conversations[0] = conversations[0][:-4] | |
if has_image and not has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif not has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
else: | |
input_ids = tokenizer( | |
conversations, | |
return_tensors="pt", | |
padding="longest", | |
max_length=tokenizer.model_max_length, | |
truncation=True, | |
).input_ids | |
#print(f'end_tag: {end_tag}') | |
    #print(conversations)
    #print(input_ids)
    #import pdb; pdb.set_trace()
targets = input_ids.clone() | |
assert conv.sep_style == conversation_lib.SeparatorStyle.MixtralTwo | |
# Mask targets | |
sep = conv.sep + "\n" + conv.roles[1] + ":" | |
sep2_2 = "\n" + conv.roles[0] + ":" | |
sep2 = conv.sep2 + sep2_2 | |
for conversation, target in zip(conversations, targets): | |
total_len = int(target.ne(tokenizer.pad_token_id).sum()) | |
rounds = conversation.split(sep2) | |
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:] | |
cur_len = 1 | |
end_token_cnt = 0 | |
target[:cur_len] = IGNORE_INDEX | |
for i, rou in enumerate(rounds): | |
if rou == "": | |
break | |
if i > 0: | |
rou = sep2_2 + rou | |
parts = rou.split(sep) | |
if len(parts) != 2: | |
break | |
parts[0] += sep | |
            #import pdb; pdb.set_trace()
if has_image and not has_audio: | |
round_len = len(tokenizer_image_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 | |
elif has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
elif not has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
else: | |
round_len = len(tokenizer(rou).input_ids) | |
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 | |
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX | |
end_token_cnt += 1 | |
cur_len += round_len | |
cur_len = cur_len - 1 | |
target[cur_len:] = IGNORE_INDEX | |
if tokenizer.pad_token_id == tokenizer.eos_token_id: | |
cur_len -= end_token_cnt | |
if cur_len < tokenizer.model_max_length: | |
if cur_len != total_len: | |
target[:] = IGNORE_INDEX | |
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") | |
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}") | |
    #import pdb; pdb.set_trace()
return dict( | |
input_ids=input_ids, | |
labels=targets, | |
) | |
def preprocess_nemo( | |
sources, | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False, | |
has_audio: bool = False, | |
end_tag: bool = True, | |
modality: str = "lang", | |
) -> Dict: | |
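    """Tokenize conversations rendered with the Nemo template and mask all
    non-assistant tokens in the labels with IGNORE_INDEX."""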
conv = conversation_lib.default_conversation.copy() | |
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} | |
# Apply prompt templates | |
conversations = [] | |
for i, source in enumerate(sources): | |
if roles[source[0]["from"]] != conv.roles[0]: | |
# Skip the first one if it is not from human | |
source = source[1:] | |
conv.messages = [] | |
for j, sentence in enumerate(source): | |
role = roles[sentence["from"]] | |
assert role == conv.roles[j % 2], f"{i}" | |
conv.append_message(role, sentence["value"]) | |
conversations.append(conv.get_prompt(modality)) | |
#print(conversations) | |
#import pdb; pdb.set_trace() | |
if not end_tag: | |
conversations[0] = conversations[0][:-4] | |
# Tokenize conversations | |
if has_image and not has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif not has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
else: | |
input_ids = tokenizer( | |
conversations, | |
return_tensors="pt", | |
padding="longest", | |
max_length=tokenizer.model_max_length, | |
truncation=True, | |
).input_ids | |
    #print(f'end_tag: {end_tag}')
    #print(conversations)
    #print(input_ids)
    #import pdb; pdb.set_trace()
targets = input_ids.clone() | |
assert conv.sep_style == conversation_lib.SeparatorStyle.Nemo | |
# Mask targets | |
sep = conv.sep | |
sep2 = conv.sep2 | |
for conversation, target in zip(conversations, targets): | |
total_len = int(target.ne(tokenizer.pad_token_id).sum()) | |
rounds = conversation.split(sep2) | |
cur_len = 1 | |
end_token_cnt = 0 | |
target[:cur_len] = IGNORE_INDEX | |
for i, rou in enumerate(rounds): | |
if rou == "": | |
break | |
parts = rou.split(sep) | |
if len(parts) != 2: | |
break | |
parts[0] += sep | |
if has_image and not has_audio: | |
round_len = len(tokenizer_image_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 | |
elif has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
elif not has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) - 1 | |
else: | |
round_len = len(tokenizer(rou).input_ids) | |
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 | |
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX | |
end_token_cnt += 1 | |
cur_len += round_len | |
target[cur_len:] = IGNORE_INDEX | |
if tokenizer.pad_token_id == tokenizer.eos_token_id: | |
cur_len -= end_token_cnt | |
if cur_len < tokenizer.model_max_length: | |
if cur_len != total_len: | |
target[:] = IGNORE_INDEX | |
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") | |
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}") | |
#import pdb; pdb.set_trace() | |
return dict( | |
input_ids=input_ids, | |
labels=targets, | |
) | |
def preprocess_qwen2p5_instruct( | |
sources, | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False, | |
has_audio: bool = False, | |
end_tag: bool = True, | |
modality: str = "lang", | |
) -> Dict: | |
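    """Tokenize conversations rendered with the Qwen2.5-Instruct chat markup
    (<|im_start|>user / <|im_start|>assistant rounds) and mask everything but
    the assistant responses in the labels."""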
conv = conversation_lib.default_conversation.copy() | |
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} | |
# Apply prompt templates | |
conversations = [] | |
for i, source in enumerate(sources): | |
if roles[source[0]["from"]] != conv.roles[0]: | |
# Skip the first one if it is not from human | |
source = source[1:] | |
conv.messages = [] | |
for j, sentence in enumerate(source): | |
role = roles[sentence["from"]] | |
assert role == conv.roles[j % 2], f"{i}" | |
conv.append_message(role, sentence["value"]) | |
conversations.append(conv.get_prompt(modality)) | |
#print(conversations) | |
#import pdb; pdb.set_trace() | |
# Tokenize conversations | |
if not end_tag: | |
conversations[0] = conversations[0][:-10] | |
if has_image and not has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
elif not has_image and has_audio: | |
input_ids = torch.stack( | |
[ | |
tokenizer_image_audio_token(prompt, tokenizer, return_tensors="pt") | |
for prompt in conversations | |
], | |
dim=0, | |
) | |
else: | |
input_ids = tokenizer( | |
conversations, | |
return_tensors="pt", | |
padding="longest", | |
max_length=tokenizer.model_max_length, | |
truncation=True, | |
).input_ids | |
#print(f'end_tag: {end_tag}') | |
#print(conversations) | |
#print(input_ids) | |
#import pdb; pdb.set_trace() | |
targets = input_ids.clone() | |
assert conv.sep_style == conversation_lib.SeparatorStyle.Qwen2p5Instruct | |
# Mask targets | |
sep = '\n' + conv.sep + conv.roles[1] + "\n" #\n<|im_start|>assistant\n | |
sep2 = '\n' + conv.sep2 + conv.roles[0] + "\n" #\n<|im_start|>user\n | |
for conversation, target in zip(conversations, targets): | |
total_len = int(target.ne(tokenizer.pad_token_id).sum()) | |
rounds = conversation.split(sep2) | |
rounds = [rounds[0] + sep2 + rounds[1]] + rounds[2:] | |
cur_len = 0 | |
end_token_cnt = 0 | |
for i, rou in enumerate(rounds): | |
if rou == "": | |
break | |
if i > 0: | |
rou = sep2 + rou | |
parts = rou.split(sep) | |
if len(parts) != 2: | |
break | |
parts[0] += sep | |
#import pdb; pdb.set_trace() | |
if has_image and not has_audio: | |
round_len = len(tokenizer_image_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) | |
elif has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) | |
elif not has_image and has_audio: | |
round_len = len(tokenizer_image_audio_token(rou, tokenizer)) | |
instruction_len = len(tokenizer_image_audio_token(parts[0], tokenizer)) | |
else: | |
round_len = len(tokenizer(rou).input_ids) | |
instruction_len = len(tokenizer(parts[0]).input_ids) | |
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX | |
end_token_cnt += 1 | |
cur_len += round_len | |
target[cur_len:] = IGNORE_INDEX | |
if tokenizer.pad_token_id == tokenizer.eos_token_id: | |
cur_len -= end_token_cnt | |
if cur_len < tokenizer.model_max_length: | |
if cur_len != total_len: | |
target[:] = IGNORE_INDEX | |
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") | |
# print(f"YOU NEED GO TO DEBUG THIS DATA ITEM: {conversations}") | |
#print(targets) | |
#import pdb; pdb.set_trace() | |
return dict( | |
input_ids=input_ids, | |
labels=targets, | |
) | |
def preprocess_plain( | |
sources: Sequence[str], | |
tokenizer: transformers.PreTrainedTokenizer, | |
) -> Dict: | |
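    """Plain (pretraining) format: the human turn is reduced to a single image
    token, and only the caption plus separator is supervised in the labels."""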
# add end signal and concatenate together | |
conversations = [] | |
for source in sources: | |
assert len(source) == 2 | |
assert DEFAULT_IMAGE_TOKEN in source[0]["value"] | |
source[0]["value"] = DEFAULT_IMAGE_TOKEN | |
conversation = ( | |
source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep | |
) | |
conversations.append(conversation) | |
# tokenize conversations | |
input_ids = [ | |
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations | |
] | |
targets = copy.deepcopy(input_ids) | |
for target, source in zip(targets, sources): | |
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) | |
target[:tokenized_len] = IGNORE_INDEX | |
return dict(input_ids=input_ids, labels=targets) | |
def preprocess( | |
sources: Sequence[str], | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False, | |
has_audio: bool = False, | |
end_tag: bool = True, | |
modality: str = "lang", | |
) -> Dict: | |
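    """Dispatch to the template-specific preprocessor selected by the active
    default conversation (plain, nemo, qwen2p5_instruct, or mixtral_two)."""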
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: | |
return preprocess_plain(sources, tokenizer) | |
if conversation_lib.default_conversation.version == "nemo": | |
return preprocess_nemo( | |
sources, | |
tokenizer, | |
has_image=has_image, | |
has_audio=has_audio, | |
end_tag=end_tag, | |
modality=modality, | |
) | |
elif conversation_lib.default_conversation.version == "qwen2p5_instruct": | |
return preprocess_qwen2p5_instruct( | |
sources, | |
tokenizer, | |
has_image=has_image, | |
has_audio=has_audio, | |
end_tag=end_tag, | |
modality=modality, | |
) | |
elif conversation_lib.default_conversation.version == "mixtral_two": | |
return preprocess_mixtral_two( | |
sources, | |
tokenizer, | |
has_image=has_image, | |
has_audio=has_audio, | |
end_tag=end_tag, | |
modality=modality, | |
) | |
def _get_rawvideo_dec( | |
video_path, | |
image_processor, | |
max_frames=32, | |
min_frames=4, | |
image_resolution=384, | |
video_framerate=3, | |
s=None, | |
e=None, | |
image_aspect_ratio="pad", | |
): | |
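    """Decode a video with decord, uniformly sample frames at roughly
    `video_framerate` fps (clamped to [min_frames, max_frames]), optionally pad
    each frame to a square with the processor mean color, and return the list
    of preprocessed frame tensors together with the number of sampled frames.
    """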
# speed up video decode via decord. | |
video_mask = np.zeros(max_frames, dtype=np.int64) | |
max_video_length = 0 | |
# T x 3 x H x W | |
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64) | |
if s is None: | |
start_time, end_time = None, None | |
else: | |
start_time = int(s) | |
end_time = int(e) | |
start_time = start_time if start_time >= 0.0 else 0.0 | |
end_time = end_time if end_time >= 0.0 else 0.0 | |
if start_time > end_time: | |
start_time, end_time = end_time, start_time | |
elif start_time == end_time: | |
end_time = start_time + 1 | |
if os.path.exists(video_path): | |
vreader = VideoReader(video_path, ctx=cpu(0)) | |
else: | |
print(video_path) | |
raise FileNotFoundError | |
fps = vreader.get_avg_fps() | |
f_start = 0 if start_time is None else int(start_time * fps) | |
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1)) | |
num_frames = f_end - f_start + 1 | |
if num_frames > 0: | |
# T x 3 x H x W | |
sample_fps = int(video_framerate) | |
t_stride = int(round(float(fps) / sample_fps)) | |
all_pos = list(range(f_start, f_end + 1, t_stride)) | |
if len(all_pos) > max_frames: | |
sample_pos = [ | |
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int) | |
] | |
elif len(all_pos) < min_frames: | |
sample_pos = [ | |
all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int) | |
] | |
else: | |
sample_pos = all_pos | |
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()] | |
if image_aspect_ratio == "pad": | |
def expand2square(pil_img, background_color): | |
width, height = pil_img.size | |
if width == height: | |
return pil_img | |
elif width > height: | |
result = Image.new(pil_img.mode, (width, width), background_color) | |
result.paste(pil_img, (0, (width - height) // 2)) | |
return result | |
else: | |
result = Image.new(pil_img.mode, (height, height), background_color) | |
result.paste(pil_img, ((height - width) // 2, 0)) | |
return result | |
patch_images = [ | |
expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean)) | |
for i in patch_images | |
] | |
patch_images = [ | |
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in patch_images | |
] | |
else: | |
patch_images = [ | |
image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in patch_images | |
] | |
# patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images] | |
slice_len = len(patch_images) | |
return patch_images, slice_len | |
max_video_length = max_video_length if max_video_length > slice_len else slice_len | |
if slice_len < 1: | |
pass | |
else: | |
while len(patch_images) < max_frames: | |
patch_images.append(torch.zeros((3, image_resolution, image_resolution))) | |
# video[:slice_len, ...] = patch_images | |
else: | |
print("video path: {} error.".format(video_path)) | |
video_mask[:max_video_length] = [1] * max_video_length | |
return patch_images, video_mask | |
class LazySupervisedDataset(Dataset): | |
"""Dataset for supervised fine-tuning.""" | |
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): | |
super(LazySupervisedDataset, self).__init__() | |
dataset_list = DataConfig[str(data_args.dataset_use)] | |
print(dataset_list) | |
self.max_length = MAX_IMAGE_LENGTH | |
list_data_dict = [] | |
self.folder_dict = {} | |
for i in dataset_list: | |
# list_data_dict += json.load(open(i["chat_path"], "r")) | |
data_ratio = i.get("data_ratio", DEFAULT_DATA_RATIO) | |
data_i = json.load(open(i["chat_path"], "r")) | |
len_data_i = len(data_i) | |
data_i = random.sample(data_i, int(len_data_i * data_ratio)) | |
list_data_dict += data_i | |
            image_folder = [folder for folder in i if folder != "chat_path"]
for folder in image_folder: | |
if folder not in self.folder_dict: | |
self.folder_dict[folder] = i[folder] | |
for key in FolderDict.keys(): | |
if key not in self.folder_dict: | |
self.folder_dict[key] = FolderDict[key] | |
random.shuffle(list_data_dict) | |
self.tokenizer = tokenizer | |
self.list_data_dict = list_data_dict | |
self.data_args = data_args | |
def __len__(self): | |
return len(self.list_data_dict) | |
# @property | |
# def lengths(self): | |
# length_list = [] | |
# for sample in self.list_data_dict: | |
# img_tokens = 128 if 'image' in sample else 0 | |
# length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) | |
# return length_list | |
def modality_lengths(self): | |
length_list = [] | |
for sample in self.list_data_dict: | |
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) | |
cur_len = cur_len if ("image" in sample or "video" in sample) else -cur_len | |
length_list.append(cur_len) | |
return length_list | |
def __getitem__(self, i) -> Dict[str, torch.Tensor]: | |
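        """Load one sample and return `input_ids`/`labels` plus, depending on
        the sample, preprocessed image patches or video frames (`image`),
        audio features (`audio`), and their lengths; zero-filled placeholders
        are substituted when a modality is absent but the model is multimodal.
        """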
sources = self.list_data_dict[i] | |
if isinstance(i, int): | |
sources = [sources] | |
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME | |
if "image" in sources[0] and "audio" not in sources[0]: | |
image_file = self.list_data_dict[i]["image"] | |
set_id = self.list_data_dict[i].get("set", None) | |
file = image_file[0] if type(image_file) is list else image_file | |
processor = self.data_args.image_processor | |
if "height" in processor.size.keys(): | |
image_size = processor.size["height"] | |
elif "shortest_edge" in processor.size.keys(): | |
image_size = processor.size["shortest_edge"] | |
else: | |
                raise NotImplementedError("Image processor size must contain 'height' or 'shortest_edge'.")
if type(image_file) is list: | |
assert type(set_id) is list | |
if len(image_file) != len(set_id): | |
assert len(set(set_id)) == 1 | |
image = [ | |
Image.open( | |
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/")) | |
).convert("RGB") | |
for k, file in enumerate(image_file) | |
] | |
if self.data_args.image_aspect_ratio == "pad": | |
def expand2square(pil_img, background_color): | |
width, height = pil_img.size | |
if width == height: | |
return pil_img | |
elif width > height: | |
result = Image.new(pil_img.mode, (width, width), background_color) | |
result.paste(pil_img, (0, (width - height) // 2)) | |
return result | |
else: | |
result = Image.new(pil_img.mode, (height, height), background_color) | |
result.paste(pil_img, ((height - width) // 2, 0)) | |
return result | |
image = [ | |
expand2square(i, tuple(int(x * 255) for x in processor.image_mean)) | |
for i in image | |
] | |
image_patches, patch_num = [], [] | |
for k, img in enumerate(image): | |
if set_id[k] not in NoPatchSets: | |
img, p_num = dynamic_preprocess( | |
img, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
else: | |
img, p_num = [img], [1] | |
image_patches += img | |
patch_num += p_num | |
assert len(image_patches) == sum(patch_num) | |
image = image_patches | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image_patches, patch_num = [], [] | |
for k, img in enumerate(image): | |
if set_id[k] not in NoPatchSets: | |
img, p_num = dynamic_preprocess( | |
img, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
else: | |
img, p_num = [img], [1] | |
image_patches += img | |
patch_num += p_num | |
assert len(image_patches) == sum(patch_num) | |
image = image_patches | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image_folder = self.folder_dict[set_id] | |
image = Image.open( | |
os.path.join(image_folder, image_file.replace("\\", "/")) | |
).convert("RGB") | |
if self.data_args.image_aspect_ratio == "pad": | |
def expand2square(pil_img, background_color): | |
width, height = pil_img.size | |
if width == height: | |
return pil_img | |
elif width > height: | |
result = Image.new(pil_img.mode, (width, width), background_color) | |
result.paste(pil_img, (0, (width - height) // 2)) | |
return result | |
else: | |
result = Image.new(pil_img.mode, (height, height), background_color) | |
result.paste(pil_img, ((height - width) // 2, 0)) | |
return result | |
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) | |
image, patch_num = dynamic_preprocess( | |
image, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image, patch_num = dynamic_preprocess( | |
image, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
inserted_id = self.list_data_dict[i].get("inserted_id", None) | |
end_tag = self.list_data_dict[i].get("end_tag", True) | |
assert inserted_id is None | |
assert end_tag is True | |
sources = preprocess_multimodal( | |
copy.deepcopy([e["conversations"] for e in sources]), | |
self.data_args, | |
patch_num=patch_num, | |
inserted_id=inserted_id, | |
) | |
data_dict = preprocess( | |
sources, self.tokenizer, has_image=True, end_tag=end_tag, modality="image" | |
) | |
elif "image" in sources[0] and "audio" in sources[0]: | |
image_file = self.list_data_dict[i]["image"] | |
set_id = self.list_data_dict[i].get("set", None) | |
file = image_file[0] if type(image_file) is list else image_file | |
audio_file = self.list_data_dict[i]["audio"] | |
processor = self.data_args.image_processor | |
if "height" in processor.size.keys(): | |
image_size = processor.size["height"] | |
elif "shortest_edge" in processor.size.keys(): | |
image_size = processor.size["shortest_edge"] | |
else: | |
                raise NotImplementedError("Image processor size must contain 'height' or 'shortest_edge'.")
if type(image_file) is list: | |
assert type(set_id) is list | |
                if len(image_file) != len(set_id):  # multi-image sample
assert len(set(set_id)) == 1 | |
image = [ | |
Image.open( | |
os.path.join(self.folder_dict[set_id[k]], file.replace("\\", "/")) | |
).convert("RGB") | |
for k, file in enumerate(image_file) | |
] | |
if self.data_args.image_aspect_ratio == "pad": | |
def expand2square(pil_img, background_color): | |
width, height = pil_img.size | |
if width == height: | |
return pil_img | |
elif width > height: | |
result = Image.new(pil_img.mode, (width, width), background_color) | |
result.paste(pil_img, (0, (width - height) // 2)) | |
return result | |
else: | |
result = Image.new(pil_img.mode, (height, height), background_color) | |
result.paste(pil_img, ((height - width) // 2, 0)) | |
return result | |
image = [ | |
expand2square(i, tuple(int(x * 255) for x in processor.image_mean)) | |
for i in image | |
] | |
image_patches, patch_num = [], [] | |
for k, img in enumerate(image): | |
if set_id[k] not in NoPatchSets: | |
img, p_num = dynamic_preprocess( | |
img, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
else: | |
img, p_num = [img], [1] | |
image_patches += img | |
patch_num += p_num | |
assert len(image_patches) == sum(patch_num) | |
image = image_patches | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image_patches, patch_num = [], [] | |
for k, img in enumerate(image): | |
if set_id[k] not in NoPatchSets: | |
img, p_num = dynamic_preprocess( | |
img, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
else: | |
img, p_num = [img], [1] | |
image_patches += img | |
patch_num += p_num | |
assert len(image_patches) == sum(patch_num) | |
image = image_patches | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image_folder = self.folder_dict[set_id] | |
image = Image.open( | |
os.path.join(image_folder, image_file.replace("\\", "/")) | |
).convert("RGB") | |
if self.data_args.image_aspect_ratio == "pad": | |
def expand2square(pil_img, background_color): | |
width, height = pil_img.size | |
if width == height: | |
return pil_img | |
elif width > height: | |
result = Image.new(pil_img.mode, (width, width), background_color) | |
result.paste(pil_img, (0, (width - height) // 2)) | |
return result | |
else: | |
result = Image.new(pil_img.mode, (height, height), background_color) | |
result.paste(pil_img, ((height - width) // 2, 0)) | |
return result | |
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) | |
image, patch_num = dynamic_preprocess( | |
image, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
else: | |
image, patch_num = dynamic_preprocess( | |
image, | |
min_num=self.data_args.min_dynamic_patch, | |
max_num=self.data_args.max_dynamic_patch, | |
image_size=image_size, | |
use_thumbnail=self.data_args.use_thumbnail, | |
) | |
image = [ | |
processor.preprocess(i, return_tensors="pt")["pixel_values"][0] | |
for i in image | |
] | |
if type(audio_file) is list: | |
# if type(set_id) is list: | |
# audio_folder = self.folder_dict[set_id[0]+'_audio'] | |
# else: | |
# audio_folder = self.folder_dict[set_id+'_audio'] | |
audio_folder = AudioFolder | |
                assert len(audio_file) > 0, "audio_file must not be empty when it is a list"
audio = [] | |
audio_for_llm_lens = [] | |
audio_length = [] | |
for file in audio_file: | |
try: | |
a, a_llm = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", file) | |
) | |
                    except Exception:
                        print(f"File {os.path.join(audio_folder, 'audio', file)} not OK!!!!!")
                        raise
audio.append(a) | |
audio_for_llm_lens.append(a_llm) | |
audio_length.append(a.shape[0]) | |
else: | |
# audio_folder = self.folder_dict[set_id+'_audio'] | |
audio_folder = AudioFolder | |
                assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", audio_file) | |
) | |
audio_length = audio.shape[0] | |
inserted_id = self.list_data_dict[i].get("inserted_id", None) | |
end_tag = self.list_data_dict[i].get("end_tag", True) | |
sources = preprocess_multimodal( | |
copy.deepcopy([e["conversations"] for e in sources]), | |
self.data_args, | |
patch_num=patch_num, | |
audio_lens=audio_for_llm_lens, | |
inserted_id=inserted_id, | |
) | |
data_dict = preprocess( | |
sources, | |
self.tokenizer, | |
has_image=True, | |
has_audio=True, | |
end_tag=end_tag, | |
modality="image", | |
) | |
data_dict["audio_lengths"] = audio_length | |
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens | |
elif "video" in sources[0] and "audio" not in sources[0]: | |
video_file = self.list_data_dict[i]["video"] | |
video_id = self.list_data_dict[i]["id"] | |
set_id = self.list_data_dict[i].get("set", None) | |
processor = self.data_args.image_processor | |
if "height" in processor.size.keys(): | |
image_size = processor.size["height"] | |
elif "shortest_edge" in processor.size.keys(): | |
image_size = processor.size["shortest_edge"] | |
else: | |
                raise NotImplementedError("Image processor size must contain 'height' or 'shortest_edge'.")
video_folder = self.folder_dict[set_id] | |
image, image_token_num = _get_rawvideo_dec( | |
os.path.join(video_folder, video_file), | |
processor, | |
max_frames=MAX_IMAGE_LENGTH, | |
min_frames=MIN_IMAGE_LENGTH, | |
image_resolution=image_size, | |
image_aspect_ratio=self.data_args.image_aspect_ratio, | |
) | |
inserted_id = self.list_data_dict[i].get("inserted_id", None) | |
end_tag = self.list_data_dict[i].get("end_tag", True) | |
assert inserted_id is None | |
assert end_tag is True | |
sources = preprocess_multimodal( | |
copy.deepcopy([e["conversations"] for e in sources]), | |
self.data_args, | |
image_token_num=image_token_num, | |
inserted_id=inserted_id, | |
) | |
data_dict = preprocess( | |
sources, | |
self.tokenizer, | |
has_image=True, | |
has_audio=False, | |
end_tag=end_tag, | |
modality="video", | |
) | |
elif "video" in sources[0] and "audio" in sources[0]: | |
video_file = self.list_data_dict[i]["video"] | |
video_id = self.list_data_dict[i]["id"] | |
set_id = self.list_data_dict[i].get("set", None) | |
audio_file = self.list_data_dict[i]["audio"] | |
processor = self.data_args.image_processor | |
if "height" in processor.size.keys(): | |
image_size = processor.size["height"] | |
elif "shortest_edge" in processor.size.keys(): | |
image_size = processor.size["shortest_edge"] | |
else: | |
                raise NotImplementedError("Image processor size must contain 'height' or 'shortest_edge'.")
video_folder = self.folder_dict[set_id] | |
# audio_folder = self.folder_dict[set_id+'_audio'] | |
audio_folder = AudioFolder | |
image, image_token_num = _get_rawvideo_dec( | |
os.path.join(video_folder, video_file), | |
processor, | |
max_frames=MAX_IMAGE_LENGTH, | |
min_frames=MIN_IMAGE_LENGTH, | |
image_resolution=image_size, | |
image_aspect_ratio=self.data_args.image_aspect_ratio, | |
) | |
if type(audio_file) is list: | |
                assert len(audio_file) > 0, "audio_file must not be empty when it is a list"
audio = [] | |
audio_for_llm_lens = [] | |
audio_length = [] | |
for file in audio_file: | |
a, a_llm = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", file) | |
) | |
audio.append(a) | |
audio_for_llm_lens.append(a_llm) | |
audio_length.append(a.shape[0]) | |
else: | |
                assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", audio_file) | |
) | |
audio_length = audio.shape[0] | |
inserted_id = self.list_data_dict[i].get("inserted_id", None) | |
end_tag = self.list_data_dict[i].get("end_tag", True) | |
sources = preprocess_multimodal( | |
copy.deepcopy([e["conversations"] for e in sources]), | |
self.data_args, | |
image_token_num=image_token_num, | |
audio_lens=audio_for_llm_lens, | |
inserted_id=inserted_id, | |
) | |
data_dict = preprocess( | |
sources, | |
self.tokenizer, | |
has_image=True, | |
has_audio=True, | |
end_tag=end_tag, | |
modality="video", | |
) | |
data_dict["audio_lengths"] = audio_length | |
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens | |
elif "audio" in sources[0]: | |
audio_file = self.list_data_dict[i]["audio"] | |
audio_folder = AudioFolder | |
if type(audio_file) is list: | |
                assert len(audio_file) > 0, "audio_file must not be empty when it is a list"
audio = [] | |
audio_for_llm_lens = [] | |
audio_length = [] | |
for file in audio_file: | |
a, a_llm = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", file) | |
) | |
audio.append(a) | |
audio_for_llm_lens.append(a_llm) | |
audio_length.append(a.shape[0]) | |
else: | |
                assert audio_file, "audio_file must not be empty"
audio, audio_for_llm_lens = self.data_args.audio_processor.process( | |
os.path.join(audio_folder, "audio", audio_file) | |
) | |
audio_length = audio.shape[0] | |
inserted_id = self.list_data_dict[i].get("inserted_id", None) | |
end_tag = self.list_data_dict[i].get("end_tag", True) | |
sources = preprocess_multimodal( | |
copy.deepcopy([e["conversations"] for e in sources]), | |
self.data_args, | |
image_token_num=0, | |
audio_lens=audio_for_llm_lens, | |
inserted_id=inserted_id, | |
) | |
data_dict = preprocess( | |
sources, | |
self.tokenizer, | |
has_image=False, | |
has_audio=True, | |
end_tag=end_tag, | |
modality="lang", | |
) | |
data_dict["audio_lengths"] = audio_length | |
data_dict["audio_lengths_for_llm"] = audio_for_llm_lens | |
else: | |
sources = copy.deepcopy([e["conversations"] for e in sources]) | |
sources = preprocess_multimodal( | |
sources, | |
self.data_args, | |
image_token_num=0, | |
) | |
data_dict = preprocess(sources, self.tokenizer, has_image=False, modality="lang") | |
if isinstance(i, int): | |
if "audio" in self.list_data_dict[i]: | |
data_dict = dict( | |
input_ids=data_dict["input_ids"][0], | |
labels=data_dict["labels"][0], | |
audio_lengths=data_dict["audio_lengths"], | |
audio_lengths_for_llm=data_dict["audio_lengths_for_llm"], | |
) | |
else: | |
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) | |
# image exist in the data | |
if "image" in self.list_data_dict[i] or "video" in self.list_data_dict[i]: | |
data_dict["image"] = image | |
elif self.data_args.is_multimodal: | |
# image does not exist in the data, but the model is multimodal | |
crop_size = self.data_args.image_processor.crop_size | |
data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"]) | |
if "audio" in self.list_data_dict[i]: | |
data_dict["audio"] = audio | |
elif self.data_args.is_multimodal: | |
data_dict["audio"] = torch.zeros(400, 80) | |
data_dict["audio_lengths"] = 400 | |
data_dict["audio_lengths_for_llm"] = 60 | |
return data_dict | |
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: | |
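        """Right-pad `input_ids` and `labels` to the batch maximum, build the
        attention mask, and gather image/audio tensors. When the pad token
        equals the EOS token, EOS ids are temporarily remapped to -300 so that
        padding can be distinguished from genuine end-of-sequence tokens."""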
input_ids, labels = tuple( | |
[instance[key] for instance in instances] for key in ("input_ids", "labels") | |
) | |
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: | |
for input_id in input_ids: | |
input_id[input_id == self.tokenizer.eos_token_id] = -300 | |
input_ids = torch.nn.utils.rnn.pad_sequence( | |
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id | |
) | |
labels = torch.nn.utils.rnn.pad_sequence( | |
labels, batch_first=True, padding_value=IGNORE_INDEX | |
) | |
input_ids = input_ids[:, : self.tokenizer.model_max_length] | |
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) | |
labels = labels[:, : self.tokenizer.model_max_length] | |
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: | |
for input_id in input_ids: | |
input_id[input_id == -300] = self.tokenizer.eos_token_id | |
batch = dict( | |
input_ids=input_ids, | |
labels=labels, | |
attention_mask=attention_mask, | |
) | |
if "image" in instances[0]: | |
images = [instance["image"] for instance in instances] | |
new_images = [] | |
for image in images: | |
if type(image) is list: | |
for i in image: | |
new_images.append(i) | |
else: | |
new_images.append(image) | |
images = new_images | |
if all(x is not None and x.shape == images[0].shape for x in images): | |
batch["images"] = torch.stack(images) | |
else: | |
batch["images"] = images | |
batch["audios"] = {} | |
if "audio" in instances[0]: | |
audios = [instance["audio"] for instance in instances] | |
audio_lengths = [instance["audio_lengths"] for instance in instances] | |
audio_lengths_for_llm = [instance["audio_lengths_for_llm"] for instance in instances] | |
new_audios = [] | |
new_audio_lengths = [] | |
new_audio_lengths_for_llm = [] | |
for i, audio in enumerate(audios): | |
length = audio_lengths[i] | |
length_for_llm = audio_lengths_for_llm[i] | |
if type(audio) is list: | |
for j, a in enumerate(audio): | |
new_audios.append(a) | |
new_audio_lengths.append(length[j]) | |
new_audio_lengths_for_llm.append(length_for_llm[j]) | |
else: | |
new_audios.append(audio) | |
new_audio_lengths.append(length) | |
new_audio_lengths_for_llm.append(length_for_llm) | |
audios = new_audios | |
audios = pad_sequence(audios, batch_first=True, padding_value=0) | |
batch["audios"]["audios"] = audios | |
batch["audios"]["lengths"] = torch.tensor(new_audio_lengths) | |
batch["audios"]["lengths_for_llm"] = torch.tensor(new_audio_lengths_for_llm) | |
return batch | |
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: | |
"""Make dataset and collator for supervised fine-tuning.""" | |
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args) | |
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) | |
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) | |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): | |
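    """Return the (columns, rows) grid in `target_ratios` whose aspect ratio is
    closest to `aspect_ratio`; ties prefer the larger grid when the original
    image area is big enough to fill it."""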
best_ratio_diff = float("inf") | |
best_ratio = (1, 1) | |
area = width * height | |
for ratio in target_ratios: | |
target_aspect_ratio = ratio[0] / ratio[1] | |
ratio_diff = abs(aspect_ratio - target_aspect_ratio) | |
if ratio_diff < best_ratio_diff: | |
best_ratio_diff = ratio_diff | |
best_ratio = ratio | |
elif ratio_diff == best_ratio_diff: | |
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: | |
best_ratio = ratio | |
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') | |
return best_ratio | |
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): | |
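    """Tile `image` into square crops of `image_size` following the closest
    supported grid (between `min_num` and `max_num` tiles), optionally append a
    thumbnail of the full image, and return the crops together with a
    one-element list holding the crop count."""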
orig_width, orig_height = image.size | |
aspect_ratio = orig_width / orig_height | |
# calculate the existing image aspect ratio | |
target_ratios = set( | |
(i, j) | |
for n in range(min_num, max_num + 1) | |
for i in range(1, n + 1) | |
for j in range(1, n + 1) | |
if i * j <= max_num and i * j >= min_num | |
) | |
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) | |
# find the closest aspect ratio to the target | |
target_aspect_ratio = find_closest_aspect_ratio( | |
aspect_ratio, target_ratios, orig_width, orig_height, image_size | |
) | |
# calculate the target width and height | |
target_width = image_size * target_aspect_ratio[0] | |
target_height = image_size * target_aspect_ratio[1] | |
blocks = target_aspect_ratio[0] * target_aspect_ratio[1] | |
# resize the image | |
resized_img = image.resize((target_width, target_height)) | |
processed_images = [] | |
for i in range(blocks): | |
box = ( | |
(i % (target_width // image_size)) * image_size, | |
(i // (target_width // image_size)) * image_size, | |
((i % (target_width // image_size)) + 1) * image_size, | |
((i // (target_width // image_size)) + 1) * image_size, | |
) | |
# split the image | |
split_img = resized_img.crop(box) | |
processed_images.append(split_img) | |
assert len(processed_images) == blocks | |
if use_thumbnail and len(processed_images) != 1: | |
thumbnail_img = image.resize((image_size, image_size)) | |
processed_images.append(thumbnail_img) | |
return processed_images, [len(processed_images)] | |
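

# Minimal usage sketch (illustrative only): the tokenizer name below is an
# assumption, and `data_args` would still need the `image_processor` and
# `audio_processor` attributes that LazySupervisedDataset expects at runtime;
# this is not the actual training entry point.
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   data_args = DataArguments(dataset_use="temp")
#   data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
#   loader = torch.utils.data.DataLoader(
#       data_module["train_dataset"],
#       batch_size=2,
#       collate_fn=data_module["data_collator"],
#   )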