# Source: Tarsier2-7b / dataset/utils.py
# Uploaded by omni-research; commit dcd4560 ("update to tarsier2-7b-0115")
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import os
from PIL import Image, ImageSequence
import decord
# Message printed by get_visual_type() when a file extension is not one of the
# recognised image / video / gif extensions.
VALID_DATA_FORMAT_STRING = "Input data must be {'.jpg', '.jpeg', '.png', '.tif'} for image; or {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv', '.gif'} for videos!"
# Uniformly sample frames; the first and last frames of the window are always included.
def sample_frame_indices(start_frame, total_frames: int, n_frames: int):
    """Return ``n_frames`` evenly spaced frame indices within a window.

    The window covers ``total_frames`` consecutive frames beginning at
    ``start_frame``. For ``n_frames >= 2`` the first returned index is
    ``start_frame`` and the last is ``start_frame + total_frames - 1``.

    Args:
        start_frame: Absolute index of the first frame of the window.
        total_frames: Number of frames in the window.
        n_frames: Number of indices to produce.

    Returns:
        A list of absolute frame indices. Indices may repeat when
        ``n_frames`` is larger than ``total_frames``.
    """
    if n_frames == 1:
        # A single sample is the first frame of the window. Returning a
        # hard-coded 0 would ignore the requested window start.
        return [start_frame]
    return [
        start_frame + round(i * (total_frames - 1) / (n_frames - 1))
        for i in range(n_frames)
    ]
def sample_video(
    video_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
    ) -> List[Image.Image]:
    """Decode ``n_frames`` evenly spaced RGB frames from a video file.

    Args:
        video_path: Path of the video; must exist on disk.
        n_frames: Number of frames to sample (forwarded to
            ``sample_frame_indices``).
        start_time: Start of the sampling window, multiplied by the video's
            average FPS to obtain a frame index. Values <= 0 mean "from the
            first frame".
        end_time: End of the sampling window, converted the same way.
            Values <= 0 mean "until the last frame".

    Returns:
        The sampled frames as a list of RGB PIL images.

    Raises:
        AssertionError: If ``video_path`` does not exist.
    """
    assert os.path.exists(video_path), f"File not found: {video_path}"

    reader = decord.VideoReader(video_path, num_threads=1, ctx=decord.cpu(0))
    reader.seek(0)
    last_frame = len(reader) - 1
    avg_fps = reader.get_avg_fps()

    # Translate the time window into frame indices clamped to the video length.
    first = min(last_frame, int(avg_fps * start_time)) if start_time > 0 else 0
    if end_time > 0:
        # The end never precedes the start and never passes the last frame.
        last = min(max(first, int(avg_fps * end_time)), last_frame)
    else:
        last = last_frame

    picked = sample_frame_indices(
        start_frame=first,
        total_frames=last - first + 1,
        n_frames=n_frames,
    )
    batch = reader.get_batch(picked).asnumpy()
    return [Image.fromarray(arr).convert('RGB') for arr in batch]
def sample_gif(
    gif_path: str,
    n_frames:int = None,
    start_time: int = 0,
    end_time: int = -1
    ) -> List[Image.Image]:
    """Return evenly spaced RGB frames sampled from a GIF file.

    Args:
        gif_path: Path of the GIF; must exist on disk.
        n_frames: Number of frames to sample (forwarded to
            ``sample_frame_indices``).
        start_time: Accepted for signature parity with ``sample_video``;
            not used by this implementation (the whole GIF is the window).
        end_time: Accepted for signature parity with ``sample_video``;
            not used by this implementation.

    Returns:
        The selected frames, in playback order, as RGB PIL images. Each
        frame appears at most once even if its index was sampled repeatedly.

    Raises:
        AssertionError: If ``gif_path`` does not exist.
    """
    assert os.path.exists(gif_path), f"File not found: {gif_path}"
    # Image.open is lazy and keeps the file open for multi-frame images; the
    # context manager releases the handle once all frames have been copied.
    with Image.open(gif_path) as gif_frames:
        start_frame = 0
        end_frame = gif_frames.n_frames - 1
        # A set gives O(1) membership tests while walking the frames.
        wanted = set(sample_frame_indices(
            start_frame=start_frame,
            total_frames=end_frame - start_frame + 1,
            n_frames=n_frames,
        ))
        # convert() returns an independent image, so the frames stay valid
        # after the source file is closed.
        return [
            frame.convert('RGB')
            for i, frame in enumerate(ImageSequence.Iterator(gif_frames))
            if i in wanted
        ]
def sample_image(
    image_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
    ):
    """Load one image and return it as a single-element list of RGB images.

    ``n_frames``, ``start_time`` and ``end_time`` are accepted so the
    signature matches the video/GIF samplers; they are not used here.

    Raises:
        AssertionError: If ``image_path`` does not exist.
    """
    assert os.path.exists(image_path), f"File not found: {image_path}"
    opened = Image.open(image_path)
    return [opened.convert('RGB')]
def get_visual_type(input_file):
    """Classify a media file as 'gif', 'video' or 'image' by its extension.

    The extension is compared case-insensitively, so ``clip.MP4`` and
    ``photo.JPG`` are recognised just like their lowercase forms.

    Args:
        input_file: Path or file name whose extension is inspected.

    Returns:
        'gif', 'video' or 'image' for a recognised extension; otherwise
        'unk', after printing the list of supported formats.
    """
    raw_ext = os.path.splitext(input_file)[-1]
    ext = raw_ext.lower()
    if ext == '.gif':
        return 'gif'
    if ext in {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv'}:
        return 'video'
    if ext in {'.jpg', '.jpeg', '.png', '.tif'}:
        return 'image'
    # Report the extension exactly as it appeared in the input.
    print(f"{VALID_DATA_FORMAT_STRING} But found {raw_ext}!")
    return 'unk'
def get_benchmarks(benchmarks):
    """Expand benchmark names and group aliases into a list of unique names.

    Each entry is lower-cased. A group alias ('dream', 'caption', 'mc_qa',
    'oe_qa') is replaced by the benchmarks it contains; any other name is
    kept as given. The alias 'all' adds every known benchmark and ends
    processing immediately, so entries listed after 'all' are ignored.

    Args:
        benchmarks: Iterable of benchmark names and/or group aliases.

    Returns:
        Benchmark names in first-seen order, without duplicates.
    """
    type2bm = {
        'dream': ['dream'],
        'caption': ['msvd-caption', 'msr-vtt-caption', 'vatex-caption'],
        'mc_qa': ['next-qa', 'egoschema', 'mvbench', 'video-mme'],
        'oe_qa': ['msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa'],
    }
    final_benchmarks = []
    seen = set()

    def _add(names):
        # Append only names not collected yet, so overlapping groups and
        # repeated entries never produce duplicates.
        for name in names:
            if name not in seen:
                seen.add(name)
                final_benchmarks.append(name)

    for bm in benchmarks:
        bm = bm.lower()
        if bm == 'all':
            for group in type2bm.values():
                _add(group)
            return final_benchmarks
        _add(type2bm.get(bm, [bm]))
    return final_benchmarks
def check_data_format(data):
    """Validate the ``messages`` of a sample, normalising it in place.

    A message whose ``content`` is a single dict is rewritten to a
    one-element list. Every content item must have a ``type`` of 'image',
    'video' or 'text'; for non-text items the path(s) stored under
    ``content[<type>]["<type>_file"]`` must exist on disk.

    Args:
        data: Sample dict with a ``messages`` list.

    Raises:
        AssertionError: On an unsupported content type or a missing file.
    """
    for message in data['messages']:
        # Accept a bare dict as shorthand for a single-item content list.
        if isinstance(message['content'], dict):
            message['content'] = [message['content']]

        for content in message['content']:
            assert content['type'] in {'image', 'video', 'text'}, f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']"
            if content['type'] == "text":
                continue

            path_key = f"{content['type']}_file"
            media_paths = content[content['type']][path_key]
            # A single path string is treated as a one-element list.
            if isinstance(media_paths, str):
                media_paths = [media_paths]
            for path in media_paths:
                assert os.path.exists(path), f"File not found: {path}"
def format_one_sample(media_file=None, prompt="Describe the video in detail."):
    """Build a single-turn sample dict for one optional media file and a prompt.

    The sample holds a user message (the media item, if any, followed by the
    text prompt) and an assistant message with empty content. GIF files are
    labelled as 'video'. The ``task`` field is ``"<media_type>/QA"`` when a
    media file is given and ``'text-only'`` otherwise. The finished sample is
    passed through ``check_data_format`` before being returned.

    Args:
        media_file: Path of the image/video/GIF, or None for a text-only sample.
        prompt: Text placed in the user message.

    Returns:
        The sample dict with ``messages`` and ``task`` keys.
    """
    has_media = media_file is not None

    user_parts = []
    if has_media:
        media_type = get_visual_type(media_file)
        if media_type == "gif":
            # GIFs are handled as videos from here on.
            media_type = "video"
        user_parts.append({
            "type": media_type,
            media_type: {f"{media_type}_file": media_file},
        })
    user_parts.append({"type": "text", "text": prompt})

    sample = {
        "messages": [
            {"role": "user", "content": user_parts},
            {"role": "assistant", "content": []},
        ],
        "task": f"{media_type}/QA" if has_media else 'text-only',
    }
    check_data_format(sample)
    return sample
class DictToObject(object):
    """Expose the entries of a dictionary as attributes of an object."""

    def __init__(self, dictionary):
        """Set one attribute on the instance per ``key``/``value`` pair."""
        for attr_name, attr_value in dictionary.items():
            setattr(self, attr_name, attr_value)