import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu

from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_DATA_RATIO,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token


@dataclass
class DataArguments:
    lazy_preprocess: bool = False
    is_multimodal: bool = True
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: Optional[str] = field(default=None)
    dataset_use: str = field(default="temp")
    min_dynamic_patch: int = 1
    max_dynamic_patch: int = 12
    use_thumbnail: bool = True


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments,
    image_token_num=1,
    patch_num=[1],
    audio_lens: int = 0,
    inserted_id=None,
) -> Dict:
    """Normalize multimodal placeholder tokens in each conversation turn.

    Moves <image>/<video> placeholders to the front of a turn, caps the number
    of image tokens at MAX_IMAGE_LENGTH, and expands each placeholder to its
    per-modality token count (patch_num for images, image_token_num for video
    frames). Sentences are modified in place.
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    k_img_ph = 0  # index into patch_num; advances once per turn that contains image tokens
    for source in sources:
        if inserted_id is not None:
            assert source[inserted_id]["from"] == "gpt"
        for i, sentence in enumerate(source):
            if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
                # Drop newlines adjacent to image placeholders.
                sentence["value"] = (
                    sentence["value"]
                    .replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
                    .strip()
                )
                sentence["value"] = (
                    sentence["value"]
                    .replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
                    .strip()
                )
                # If the turn has exactly one video token and ends with image
                # tokens, move all image tokens to the front of the turn.
                VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
                if VIDEO_TOKEN_NUM == 1 and sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
                    IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()
                # Likewise move trailing video tokens to the front.
                if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
                    VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence["value"] = sentence["value"].replace(
                        DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
                    )
                # Cap the number of image placeholders per turn.
                IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
                    sentence["value"] = (
                        sentence["value"]
                        .replace(
                            DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
                            DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
                        )
                        .strip()
                    )
            # Expand each placeholder to its per-modality token count.
            replace_token, vid_replace_token, audio_replace_token = (
                DEFAULT_IMAGE_TOKEN,
                DEFAULT_IMAGE_TOKEN * image_token_num,
                DEFAULT_AUDIO_TOKEN,
            )  # * audio_lens
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                replace_token = DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph]
                k_img_ph += 1
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n")
            sentence["value"] = sentence["value"].replace(
                DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n"
            )
            sentence["value"] = sentence["value"].replace(
                DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token
            )
            sentence["value"] = sentence["value"].replace("\n\n", "\n")
            # if i == inserted_id:
            #     assert sentence["from"] == "gpt"
            #     sentence["value"] = "<2>" + sentence["value"]
            # elif sentence["from"] == "gpt":
            #     if "

    return sources