import copy
import json
import math
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from decord import VideoReader, cpu

from vita import conversation as conversation_lib
from vita.config import AudioFolder, DataConfig, FolderDict, NoPatchSets
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_DATA_RATIO,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.util.mm_utils import tokenizer_image_audio_token, tokenizer_image_token


@dataclass
class DataArguments:
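    """Arguments that control multimodal data loading and preprocessing."""
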
    lazy_preprocess: bool = False
    is_multimodal: bool = True
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: Optional[str] = field(default=None)
    dataset_use: str = field(default="temp")
    min_dynamic_patch: int = 1
    max_dynamic_patch: int = 12
    use_thumbnail: bool = True


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments,
    image_token_num: int = 1,
    patch_num: Sequence[int] = (1,),
    audio_lens: int = 0,
    inserted_id: Optional[int] = None,
) -> Dict:
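    """Normalize and expand multimodal placeholder tokens in conversation sources.

    Newlines abutting image tokens are stripped, trailing image/video token runs
    are moved to the sentence front, image tokens are capped at MAX_IMAGE_LENGTH,
    and placeholders are expanded: each image token becomes one copy per dynamic
    patch (patch_num) and each video token becomes image_token_num image tokens.
    """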
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

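    # k_img_ph indexes patch_num: one entry per image placeholder across all sources.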
    k_img_ph = 0
    for source in sources:
        if inserted_id is not None:
            assert source[inserted_id]["from"] == "gpt"
        for i, sentence in enumerate(source):
            if DEFAULT_IMAGE_TOKEN in sentence["value"] or DEFAULT_VIDEO_TOKEN in sentence["value"]:
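                # Drop newlines that directly precede or follow an image token.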
sentence["value"] = (
sentence["value"]
.replace(DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
.strip()
)
sentence["value"] = (
sentence["value"]
.replace("\n" + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN)
.strip()
)
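                # If the sentence ends with a contiguous run of image tokens (with
                # exactly one video token present) or a run of video tokens, move
                # that run to the front of the sentence.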
                VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
                if VIDEO_TOKEN_NUM == 1 and sentence["value"].endswith(DEFAULT_IMAGE_TOKEN):
                    IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()
                if sentence["value"].endswith(DEFAULT_VIDEO_TOKEN):
                    VIDEO_TOKEN_NUM = sentence["value"].count(DEFAULT_VIDEO_TOKEN)
                    sentence["value"] = (
                        sentence["value"].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, "").strip()
                    )
                    sentence["value"] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence["value"]
                    sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + ""
)
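                # Cap a contiguous run of image tokens at MAX_IMAGE_LENGTH.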
                IMAGE_TOKEN_NUM = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
                if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
                    sentence["value"] = (
                        sentence["value"]
                        .replace(
                            DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
                            DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH,
                        )
                        .strip()
                    )
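            # Default replacements: each video placeholder expands to
            # image_token_num image tokens; audio stays a single token.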
            replace_token, vid_replace_token, audio_replace_token = (
                DEFAULT_IMAGE_TOKEN,
                DEFAULT_IMAGE_TOKEN * image_token_num,
                DEFAULT_AUDIO_TOKEN,
            )  # * audio_lens
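            # Images expand to one token per dynamic patch for this placeholder.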
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                replace_token = DEFAULT_IMAGE_TOKEN * patch_num[k_img_ph]
                k_img_ph += 1
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + "\n")
sentence["value"] = sentence["value"].replace(
DEFAULT_VIDEO_TOKEN, vid_replace_token + "\n"
)
sentence["value"] = sentence["value"].replace(
DEFAULT_AUDIO_TOKEN + "\n", audio_replace_token
)
sentence["value"] = sentence["value"].replace("\n\n", "\n")
            # if i == inserted_id:
            #     assert sentence["from"] == "gpt"
            #     sentence["value"] = "<2>" + sentence["value"]
            # elif sentence["from"] == "gpt":
            #     if "