# Long-VITA / app.py
# shenyunhang's picture
# 6a0e740d98a0@2025-02-26_11-00-48:
# f9a86ee
# raw
# history blame
# 35.9 kB
##############################################################
# copy from cognitron_vl/constants.py
##############################################################
import logging
logger = logging.getLogger(__name__)

# Two spellings of the multimodal special tokens are kept side by side.
# The flag selects the plain "<image>"-style set; the "<|image|>"-style set
# in the else-branch is never selected while the flag stays True.
_USE_PLAIN_TAG_TOKENS = True

if _USE_PLAIN_TAG_TOKENS:
    # Placeholders that appear in the user-visible prompt text.
    IMG_TAG_TOKEN = "<image>"
    VID_TAG_TOKEN = "<video>"
    AUD_TAG_TOKEN = "<audio>"

    # start / context / end triplets, one per visual input kind.
    IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
    IMG_START_TOKEN = "<img>"
    IMG_END_TOKEN = "</img>"

    VID_CONTEXT_TOKEN = "<VID_CONTEXT>"
    VID_START_TOKEN = "<vid>"
    VID_END_TOKEN = "</vid>"

    PATCH_CONTEXT_TOKEN = "<PATCH_CONTEXT>"
    PATCH_START_TOKEN = "<patch>"
    PATCH_END_TOKEN = "</patch>"

    AUD_START_TOKEN = "<|begin_of_audio|>"
    AUD_END_TOKEN = "<|end_of_audio|>"

    QUAD_START_TOKEN = "<quad>"
    QUAD_END_TOKEN = "</quad>"
    REF_START_TOKEN = "<ref>"
    REF_END_TOKEN = "</ref>"
    BOX_START_TOKEN = "<box>"
    BOX_END_TOKEN = "</box>"
else:
    IMG_TAG_TOKEN = "<|image|>"
    VID_TAG_TOKEN = "<|video|>"
    AUD_TAG_TOKEN = "<|audio|>"

    IMG_CONTEXT_TOKEN = "<|context_of_image|>"
    IMG_START_TOKEN = "<|begin_of_image|>"
    IMG_END_TOKEN = "<|end_of_image|>"

    VID_CONTEXT_TOKEN = "<|context_of_video|>"
    VID_START_TOKEN = "<|begin_of_video|>"
    VID_END_TOKEN = "<|end_of_video|>"

    PATCH_CONTEXT_TOKEN = "<|context_of_patch|>"
    PATCH_START_TOKEN = "<|begin_of_patch|>"
    PATCH_END_TOKEN = "<|end_of_patch|>"

    AUD_START_TOKEN = "<|begin_of_audio|>"
    AUD_END_TOKEN = "<|end_of_audio|>"

    QUAD_START_TOKEN = "<|begin_of_quad|>"
    QUAD_END_TOKEN = "<|end_of_quad|>"
    REF_START_TOKEN = "<|begin_of_ref|>"
    REF_END_TOKEN = "<|end_of_ref|>"
    BOX_START_TOKEN = "<|begin_of_box|>"
    BOX_END_TOKEN = "<|end_of_box|>"

# Record the selected token spellings once at import time.
logger.info(f"IMG_TAG_TOKEN {IMG_TAG_TOKEN}")
logger.info(f"VID_TAG_TOKEN {VID_TAG_TOKEN}")
logger.info(f"AUD_TAG_TOKEN {AUD_TAG_TOKEN}")
logger.info(f"IMG_CONTEXT_TOKEN {IMG_CONTEXT_TOKEN}")
logger.info(f"IMG_START_TOKEN {IMG_START_TOKEN}")
logger.info(f"IMG_END_TOKEN {IMG_END_TOKEN}")
logger.info(f"VID_CONTEXT_TOKEN {VID_CONTEXT_TOKEN}")
logger.info(f"VID_START_TOKEN {VID_START_TOKEN}")
logger.info(f"VID_END_TOKEN {VID_END_TOKEN}")
logger.info(f"PATCH_CONTEXT_TOKEN {PATCH_CONTEXT_TOKEN}")
logger.info(f"PATCH_START_TOKEN {PATCH_START_TOKEN}")
logger.info(f"PATCH_END_TOKEN {PATCH_END_TOKEN}")
logger.info(f"AUD_START_TOKEN {AUD_START_TOKEN}")
logger.info(f"AUD_END_TOKEN {AUD_END_TOKEN}")

# IMAGENET_MEAN = (0.485, 0.456, 0.406)
# IMAGENET_STD = (0.229, 0.224, 0.225)
# CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
# CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
# SIGLIP_MEAN = (0.5, 0.5, 0.5)
# SIGLIP_STD = (0.5, 0.5, 0.5)

# Per-channel normalization statistics, as plain lists.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
# Aliases onto the token spellings selected above.
DEFAULT_IMAGE_TOKEN = IMG_CONTEXT_TOKEN
DEFAULT_IMAGE_PATCH_TOKEN = PATCH_CONTEXT_TOKEN
DEFAULT_IM_START_TOKEN = IMG_START_TOKEN
DEFAULT_IM_END_TOKEN = IMG_END_TOKEN
##############################################################
##############################################################
# copy from cognitron_vl/data/processor/image_processor.py
##############################################################
import math
import os
import cv2
import natsort
import numpy as np
import torch
from PIL import Image
import decord
# from cognitron_vl.constants import (
# IMAGENET_DEFAULT_MEAN,
# IMAGENET_DEFAULT_STD,
# IMAGENET_STANDARD_MEAN,
# IMAGENET_STANDARD_STD,
# OPENAI_CLIP_MEAN,
# OPENAI_CLIP_STD,
# )
class ImageProcessor:
    """Convert images and video frames into normalized square float tensors.

    ``process_type`` selects how a single image is split into sub-patches by
    :meth:`process_images_with_subpatch`:

    * ``"anyres"``  - pad to the best grid resolution, then cut into patches;
    * ``"dynamic"`` - resize to the closest aspect-ratio grid, then cut into
      tiles (a thumbnail is prepended when more than one tile is produced);
    * anything else - the image is processed as a single square patch.
    """

    def __init__(
        self,
        process_type,
        image_size=448,
        normalize_type="imagenet",
        min_patch_grid=1,
        max_patch_grid=6,
    ):
        """Store the configuration and pre-compute the candidate grids.

        Args:
            process_type: Sub-patch strategy, see the class docstring.
            image_size: Side length (pixels) of every output patch.
            normalize_type: ``"imagenet"``, ``"clip"`` or ``"siglip"``; selects
                the per-channel mean/std used for normalization.
            min_patch_grid: Lower bound of the patch grid (per side for
                ``"anyres"``, total tile count for ``"dynamic"``).
            max_patch_grid: Upper bound of the patch grid, same meaning.

        Raises:
            NotImplementedError: If ``normalize_type`` is not recognized.
        """
        self.process_type = process_type
        self.image_size = image_size

        if normalize_type == "imagenet":
            MEAN, STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        elif normalize_type == "clip":
            MEAN, STD = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
        elif normalize_type == "siglip":
            MEAN, STD = IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
        else:
            raise NotImplementedError
        self.mean = MEAN
        self.std = STD

        self.patch_size = image_size
        self.min_patch_grid = min_patch_grid
        self.max_patch_grid = max_patch_grid

        if self.process_type == "anyres":
            # Every (i, j) grid with both sides in [min_patch_grid, max_patch_grid].
            self.grid_pinpoints = [
                (i, j)
                for i in range(min_patch_grid, max_patch_grid + 1)
                for j in range(min_patch_grid, max_patch_grid + 1)
            ]
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.grid_pinpoints
            ]
            print(f"grid_pinpoints {self.grid_pinpoints}")
            print(f"possible_resolutions {self.possible_resolutions}")

        if self.process_type == "dynamic":
            max_num = self.max_patch_grid
            min_num = self.min_patch_grid
            # Every (i, j) grid whose tile count i * j lies in [min_num, max_num].
            target_ratios = set(
                (i, j)
                for n in range(min_num, max_num + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if i * j <= max_num and i * j >= min_num
            )
            self.target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.target_ratios
            ]
            print(f"target_ratios {self.target_ratios}")
            print(f"possible_resolutions {self.possible_resolutions}")

    @staticmethod
    def _sample_frame_indices(vid, max_fps, num_frames):
        """Return the frame indices to read from an opened video reader.

        The stride is the larger of ``fps / max_fps`` (so that no more than
        ``max_fps`` frames are taken per second of video) and
        ``len(vid) / (num_frames + 1)``.  At most ``num_frames`` indices are
        produced, and indices past the end of the video are dropped.
        """
        step_size = len(vid) / (num_frames + 1)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)
        indices = [int(i * step_size) for i in range(0, num_frames)]
        return [i for i in indices if i < len(vid)]

    @staticmethod
    def _load_rgb(img_or_path):
        """Open a path or convert a PIL image to RGB; other inputs pass through."""
        if isinstance(img_or_path, str):
            return Image.open(img_or_path).convert("RGB")
        if isinstance(img_or_path, Image.Image):
            return img_or_path.convert("RGB")
        return img_or_path

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Center ``pil_img`` on a square canvas filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def get_frame_paths(self, frame_root, num_frames=8):
        """Create ``frame_root`` if needed and return the frame file paths in it.

        Paths follow ``frame-<i>-of-<num_frames>.jpg`` for ``i`` in
        ``1..num_frames``.  The template is also stored on ``self.frame_tmpl``.
        """
        os.makedirs(frame_root, exist_ok=True)

        self.frame_tmpl = "frame-{}-of-{}.jpg"
        return [
            os.path.join(frame_root, self.frame_tmpl.format(i, num_frames))
            for i in range(1, num_frames + 1)
        ]

    def save_video_frames(self, vid_path, max_fps=1, num_frames=8):
        """Sample frames from ``vid_path`` and save them as JPEG files.

        Frames are written to the directory ``vid_path + ".saved_frames"``.
        When every expected file already exists nothing is decoded again.

        Returns:
            list[str]: Paths of the saved frames.
        """
        vid = decord.VideoReader(vid_path, num_threads=1)
        indices = self._sample_frame_indices(vid, max_fps, num_frames)
        # The file names embed the number of frames actually sampled.
        num_frames = len(indices)

        frame_paths = self.get_frame_paths(vid_path + ".saved_frames", num_frames)
        if all(os.path.exists(p) for p in frame_paths):
            return frame_paths

        images = [Image.fromarray(vid[i].asnumpy()) for i in indices]
        for im, pth in zip(images, frame_paths):
            im.save(pth)
        return frame_paths

    def get_video_frames(self, vid_path, max_fps=1, num_frames=8):
        """Sample frames from ``vid_path`` and return them as PIL images."""
        vid = decord.VideoReader(vid_path, num_threads=1)
        indices = self._sample_frame_indices(vid, max_fps, num_frames)
        return [Image.fromarray(vid[i].asnumpy()) for i in indices]

    def process_video(self, video_file_or_dir, max_num_frame=8, max_fps=1):
        """Load a video (file or directory of frames) as a tensor of frames.

        Args:
            video_file_or_dir: Path to a video file, or to a directory that is
                walked recursively for ``png`` / ``jpeg`` / ``jpg`` frames.
            max_num_frame: Upper bound on the number of frames returned.
            max_fps: Upper bound on the number of frames taken per second.

        Returns:
            ``(tensor, frames)`` where ``tensor`` is the output of
            :meth:`process_images` and ``frames`` is the list of selected file
            paths (directory input) or PIL images (file input).  ``None`` is
            returned when a directory contains no frame files.

        Raises:
            NotImplementedError: If the path is neither a file nor a directory.
        """
        if os.path.isdir(video_file_or_dir):
            all_filepath = []
            for root, dirs, files in os.walk(video_file_or_dir):
                for filename in files:
                    if filename.endswith(("png", "jpeg", "jpg")):
                        all_filepath.append(os.path.join(root, filename))

            if len(all_filepath) == 0:
                return None

            # Natural sort so that "frame-10" comes after "frame-9".
            all_filepath = natsort.natsorted(all_filepath)
            total_frame = len(all_filepath)
            # Frame rate the directory is taken to have been extracted at.
            if "ShareGPTVideo" in video_file_or_dir:
                fps = 2
            else:
                fps = 1
            # Keep at least one frame from the rate-based bound: a short clip
            # (total_frame / fps * max_fps < 1) would otherwise select nothing
            # and fail in process_images on an empty list.
            target_frame = int(min(max(total_frame / fps * max_fps, 1), max_num_frame))
            index = [int(1.0 * total_frame / target_frame) * x for x in range(target_frame)]

            img_or_path_list = [all_filepath[x] for x in index]
        elif os.path.isfile(video_file_or_dir):
            img_or_path_list = self.get_video_frames(
                video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps
            )
        else:
            raise NotImplementedError

        return self.process_images(img_or_path_list), img_or_path_list

    def process_images(self, img_or_path_list):
        """Pad, resize and normalize a non-empty list of images.

        The type of the first element decides how the whole list is read:
        ``str`` items are opened as files, PIL images are converted to RGB and
        anything else is used as is.  Each image is padded to a square with
        the mean color, resized to ``image_size`` with bicubic resampling,
        scaled to ``[0, 1]`` and normalized with ``self.mean`` / ``self.std``.

        Returns:
            torch.Tensor: float32 tensor of shape
            ``(len(img_or_path_list), 3, image_size, image_size)``.
        """
        if isinstance(img_or_path_list[0], str):
            images = [Image.open(x).convert("RGB") for x in img_or_path_list]
        elif isinstance(img_or_path_list[0], Image.Image):
            images = [x.convert("RGB") for x in img_or_path_list]
        else:
            images = img_or_path_list

        # Loop invariants: padding color and normalization statistics.
        background_color = tuple(int(x * 255) for x in self.mean)
        mean = np.array(self.mean, dtype=np.float32)
        std = np.array(self.std, dtype=np.float32)

        image_tensor = torch.ones([len(images), 3, self.image_size, self.image_size])

        for i, image in enumerate(images):
            image = self._expand2square(image, background_color)

            image = image.resize(
                (self.image_size, self.image_size), resample=Image.Resampling.BICUBIC
            )

            image = np.array(image, dtype=np.float32)
            image = image * 1.0 / 255.0
            image = (image - mean) / std

            image = torch.tensor(image, dtype=torch.float32)
            # HWC -> CHW
            image = image.permute(2, 0, 1)

            image_tensor[i] = image

        return image_tensor

    def process_images_with_subpatch(self, img_or_path):
        """Process one image according to ``self.process_type``.

        Returns:
            ``(patches, (width, height))``: the tensor of processed patches
            and the resolution the image was laid out on.  For a process type
            other than ``"anyres"`` / ``"dynamic"`` the image is a single
            patch and the resolution is ``(image_size, image_size)``.
        """
        if self.process_type == "anyres":
            return self.process_anyres(img_or_path)
        if self.process_type == "dynamic":
            return self.process_dynamic(img_or_path)

        image = self._load_rgb(img_or_path)

        # Previously this referenced an undefined name (``images``) and so
        # always raised NameError.  Return the same (patches, resolution) pair
        # as the other branches so callers can unpack the result uniformly.
        return self.process_images([image]), (self.image_size, self.image_size)

    def process_anyres(self, img_or_path):
        """Pad the image to the best-fitting grid resolution and cut it into patches.

        Returns:
            ``(patches, best_resolution)``.  ``patches`` holds the whole image
            first, followed by the grid patches unless the best resolution is
            a single patch.
        """
        image = self._load_rgb(img_or_path)

        best_resolution = select_best_resolution(image.size, self.possible_resolutions)
        image_padded = resize_and_pad_image(image, best_resolution)
        patches = divide_to_patches(image_padded, self.patch_size)

        if best_resolution == (self.patch_size, self.patch_size):
            image_patches = [image]
        else:
            image_patches = [image] + patches

        image_patches = self.process_images(image_patches)

        return image_patches, best_resolution

    def process_dynamic(self, img_or_path):
        """Tile the image with ``dynamic_preprocess`` and process every tile.

        Returns:
            ``(patches, best_resolution)`` where ``best_resolution`` is the
            ``(width, height)`` reported by ``dynamic_preprocess``.
        """
        image = self._load_rgb(img_or_path)

        image_patches, best_resolution = dynamic_preprocess(
            image,
            min_num=self.min_patch_grid,
            max_num=self.max_patch_grid,
            image_size=self.patch_size,
            use_thumbnail=True,
        )

        image_patches = self.process_images(image_patches)

        return image_patches, best_resolution
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that keeps the most of the original image.

    Each candidate is scored by the area the image covers after being scaled
    (aspect ratio preserved) to fit inside it, capped at the original area.
    The candidate with the largest such area wins; among equals, the one
    leaving the least unused area wins, and the earliest candidate is kept on
    a full tie.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions, each (width, height).

    Returns:
        tuple: The selected (width, height), or None if no candidate qualifies
        (for example when ``possible_resolutions`` is empty).
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    winner = None
    # Ranking key is (effective area, -wasted area); a larger key is better.
    winner_key = (0, -float("inf"))

    for cand_w, cand_h in possible_resolutions:
        # Largest scale at which the image still fits inside the candidate.
        fit = min(cand_w / src_w, cand_h / src_h)
        fitted_area = int(src_w * fit) * int(src_h * fit)

        effective = min(fitted_area, src_area)
        wasted = (cand_w * cand_h) - effective

        key = (effective, -wasted)
        if key > winner_key:
            winner_key = key
            winner = (cand_w, cand_h)

    return winner
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target resolution and center it on a black canvas.

    The image is scaled with its aspect ratio preserved so that one dimension
    fills the target completely; the remaining space is left black.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target (width, height).

    Returns:
        PIL.Image.Image: An RGB image of exactly the target resolution.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    if ratio_w < ratio_h:
        # Width is the limiting dimension: fill it, derive the height.
        fitted_size = (dst_w, min(math.ceil(src_h * ratio_w), dst_h))
    else:
        # Height is the limiting dimension: fill it, derive the width.
        fitted_size = (min(math.ceil(src_w * ratio_h), dst_w), dst_h)

    fitted = image.resize(fitted_size)

    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    # Center the resized image on the canvas.
    offset = ((dst_w - fitted_size[0]) // 2, (dst_h - fitted_size[1]) // 2)
    canvas.paste(fitted, offset)

    return canvas
def divide_to_patches(image, patch_size):
    """
    Cut an image into square patches, row by row from the top-left corner.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each patch.

    Returns:
        list: The cropped patches in row-major order. Crop boxes are not
        clipped to the image bounds.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """
    Return the grid from ``target_ratios`` whose aspect ratio is closest.

    Args:
        aspect_ratio (float): Width / height of the source image.
        target_ratios (iterable): Candidate grids; item[0] / item[1] is the
            aspect ratio of a candidate.
        width (int): Source image width.
        height (int): Source image height.
        image_size (int): Side length of one tile.

    Returns:
        The closest candidate, or (1, 1) when there are no candidates. When a
        later candidate is exactly as close as the current best, it replaces
        the best only if the source area exceeds half of that candidate's
        total tile area.
    """
    chosen = (1, 1)
    smallest_gap = float("inf")
    pixel_count = width * height

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])

        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
            continue

        # Exact tie: prefer the later grid only for sufficiently large images.
        if gap == smallest_gap and (
            pixel_count > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            chosen = candidate

    return chosen
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """
    Resize an image onto the closest-aspect-ratio tile grid and cut it into tiles.

    Args:
        image (PIL.Image.Image): The input image.
        min_num (int): Minimum number of tiles in the grid.
        max_num (int): Maximum number of tiles in the grid.
        image_size (int): Side length of each square tile.
        use_thumbnail (bool): When True and more than one tile is produced, a
            whole-image thumbnail of size (image_size, image_size) is placed
            in front of the tiles.

    Returns:
        tuple: (list of tiles in row-major order, (grid width, grid height) in pixels).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All (columns, rows) grids whose tile count lies within [min_num, max_num].
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if cols * rows <= max_num and cols * rows >= min_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    best_grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size
    )

    target_width = image_size * best_grid[0]
    target_height = image_size * best_grid[1]
    blocks = best_grid[0] * best_grid[1]
    tiles_per_row = target_width // image_size

    resized_img = image.resize((target_width, target_height))

    processed_images = []
    for tile_idx in range(blocks):
        col = tile_idx % tiles_per_row
        row = tile_idx // tiles_per_row
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images = [thumbnail_img] + processed_images

    return processed_images, (target_width, target_height)
##############################################################
##############################################################
# modify from long_vita_megatron/tasks/inference/module.py
##############################################################
def _append_visual_span(token_ids, index_chunks, start_id, context_id, end_id, num_context_tokens):
    """Append ``start, context * n, end`` to ``token_ids`` and record the context positions.

    A ``(2, 1, num_context_tokens)`` int64 tensor is appended to
    ``index_chunks``: row 0 is a zero batch index, row 1 holds the positions
    of the context tokens inside ``token_ids``.
    """
    token_ids.append(start_id)
    first_context_pos = len(token_ids)

    batch_index = torch.zeros(1, num_context_tokens, dtype=torch.int64)  # This will change in collate_fn
    sequence_index = torch.arange(
        first_context_pos, first_context_pos + num_context_tokens
    ).unsqueeze(0)
    index_chunks.append(
        torch.stack([batch_index, sequence_index], dim=0)
    )  # 2, num_image, image_length

    token_ids.extend([context_id] * num_context_tokens)
    token_ids.append(end_id)


def get_external_inputs(tokens, image_list=None, image_path_list=None, video_path_list=None):
    """Expand ``<image>`` / ``<video>`` tags into visual token spans and load the pixels.

    Every image tag is replaced by an image span (start, 256 context tokens,
    end) followed, when the image yields more than one patch, by one patch
    span per grid cell with a newline token before each grid row.  Every video
    tag is replaced by one video span per frame.  Relies on the module-level
    ``tokenizer`` and ``image_processor``.

    Args:
        tokens: Token id tensor of the prompt(s); converted with ``tolist()``.
        image_list: Optional images, one per ``<image>`` tag.  Takes precedence
            over ``image_path_list`` when both are given.
        image_path_list: Optional image paths, one per ``<image>`` tag.
        video_path_list: Optional video paths, one per ``<video>`` tag.  When
            absent, ``<video>`` tags are filled from the image arguments with
            one frame each.

    Returns:
        ``(tokens, images, image_indices)``: the expanded ids as a long tensor
        on ``cuda``; the stacked pixel tensor in bfloat16 on the current CUDA
        device (or None); the concatenated index tensor (or None).

    Raises:
        ValueError: If a tag is present but no matching visual input was given.
    """
    print(f"get_external_inputs tokens {tokens.size()}")

    tokens = tokens.tolist()
    image_token_length = 256  # context tokens per image / patch / frame
    max_num_frame = 4096
    max_fps = 1

    def _encode(text):
        return tokenizer(text, add_special_tokens=False).input_ids

    def _single_token_id(text):
        # Span tokens must each map to exactly one id.
        ids = _encode(text)
        assert len(ids) == 1
        return ids[0]

    IMG_CONTEXT_ID = _single_token_id(IMG_CONTEXT_TOKEN)
    IMG_START_ID = _single_token_id(IMG_START_TOKEN)
    IMG_END_ID = _single_token_id(IMG_END_TOKEN)

    VID_CONTEXT_ID = _single_token_id(VID_CONTEXT_TOKEN)
    VID_START_ID = _single_token_id(VID_START_TOKEN)
    VID_END_ID = _single_token_id(VID_END_TOKEN)

    PATCH_CONTEXT_ID = _single_token_id(PATCH_CONTEXT_TOKEN)
    PATCH_START_ID = _single_token_id(PATCH_START_TOKEN)
    PATCH_END_ID = _single_token_id(PATCH_END_TOKEN)

    # The tag tokens are not length-checked; their first id is used.
    IMG_TAG_ID = _encode(IMG_TAG_TOKEN)[0]
    VID_TAG_ID = _encode(VID_TAG_TOKEN)[0]

    nl_tokens = _encode("\n")

    image_indices = []
    images = []

    # ----------------------------------------------------------------
    # image
    for batch_idx, input_ids in enumerate(tokens):
        img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID]
        if len(img_positions) == 0:
            continue
        if image_path_list is not None:
            assert len(img_positions) == len(image_path_list), f"{img_positions} {image_path_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}"
        if image_list is not None:
            assert len(img_positions) == len(image_list), f"{img_positions} {image_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}"

        if image_list is not None:
            image_sources = image_list
        elif image_path_list is not None:
            image_sources = image_path_list
        else:
            raise ValueError(
                f"prompt contains {len(img_positions)} {IMG_TAG_TOKEN} tag(s) "
                "but neither image_list nor image_path_list was given"
            )

        new_input_ids = []
        st = 0
        for img_idx, img_pos in enumerate(img_positions):
            image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(image_sources[img_idx])
            images.append(image_patches)
            print(f"get_external_inputs best_width {best_width} best_height {best_height}")

            new_input_ids += input_ids[st:img_pos]

            _append_visual_span(
                new_input_ids, image_indices,
                IMG_START_ID, IMG_CONTEXT_ID, IMG_END_ID, image_token_length,
            )

            if len(image_patches) > 1:
                for i in range(0, best_height, image_processor.patch_size):
                    # Each grid row of patches starts on a new line.
                    new_input_ids += nl_tokens

                    for j in range(0, best_width, image_processor.patch_size):
                        _append_visual_span(
                            new_input_ids, image_indices,
                            PATCH_START_ID, PATCH_CONTEXT_ID, PATCH_END_ID, image_token_length,
                        )

            st = img_pos + 1

        new_input_ids += input_ids[st:]

        tokens[batch_idx] = new_input_ids

    # ----------------------------------------------------------------
    # video
    for batch_idx, input_ids in enumerate(tokens):
        vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID]
        if len(vid_positions) == 0:
            continue

        # A real video list takes precedence.  Previously the image arguments
        # were also checked and applied here, so a prompt mixing images and a
        # video either failed the length assertion or had its decoded video
        # frames overwritten by an image.
        if video_path_list is not None:
            assert len(vid_positions) == len(video_path_list), f"{vid_positions} {video_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"
        else:
            if image_path_list is not None:
                assert len(vid_positions) == len(image_path_list), f"{vid_positions} {image_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"
            if image_list is not None:
                assert len(vid_positions) == len(image_list), f"{vid_positions} {image_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"
            if image_path_list is None and image_list is None:
                raise ValueError(
                    f"prompt contains {len(vid_positions)} {VID_TAG_TOKEN} tag(s) "
                    "but no video_path_list, image_list or image_path_list was given"
                )

        new_input_ids = []
        st = 0
        for vid_idx, vid_pos in enumerate(vid_positions):
            if video_path_list is not None:
                video_frames, _ = image_processor.process_video(video_path_list[vid_idx], max_num_frame, max_fps)
            elif image_list is not None:
                video_frames = image_processor.process_images([image_list[vid_idx]])
            else:
                video_frames = image_processor.process_images([image_path_list[vid_idx]])

            images.append(video_frames)

            new_input_ids += input_ids[st:vid_pos]

            # One span per frame.
            for _ in video_frames:
                _append_visual_span(
                    new_input_ids, image_indices,
                    VID_START_ID, VID_CONTEXT_ID, VID_END_ID, image_token_length,
                )

            st = vid_pos + 1

        new_input_ids += input_ids[st:]

        tokens[batch_idx] = new_input_ids

    if len(images) > 0:
        images = torch.cat(images, dim=0)
        image_indices = torch.cat(image_indices, dim=1)

        image_indices = image_indices.contiguous().to(torch.cuda.current_device())
        # ``images`` is already a tensor: cast it instead of re-wrapping it in
        # torch.tensor(), which copy-constructs and emits a warning.
        images = images.to(dtype=torch.bfloat16).contiguous().to(torch.cuda.current_device())

        print(f"get_external_inputs images {images.size()}")
        print(f"get_external_inputs image_indices {image_indices.size()}")

    else:
        images = None
        image_indices = None

        print(f"get_external_inputs images {images}")
        print(f"get_external_inputs image_indices {image_indices}")

    tokens = torch.tensor(tokens, dtype=torch.long, device='cuda')
    print(f"get_external_inputs tokens {tokens.size()}")

    return tokens, images, image_indices
##############################################################
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
import importlib
# ``importlib.util`` is a submodule; import it explicitly rather than relying
# on another package having loaded it as a side effect.
import importlib.util

# torch_npu is optional: import it only when it is installed.
if importlib.util.find_spec("torch_npu") is not None:
    print("Loading torch_npu")
    import torch_npu
    # NOTE(review): presumably imported for its side effects - confirm.
    from torch_npu.contrib import transfer_to_npu
    # torch.npu.set_compile_mode(jit_compile=True)
import sys
import os
import natsort
import gradio as gr
import spaces
# Module-level setup: seed torch, then load the tokenizer, the model, its
# generation config and the image processor once at import time.
torch.manual_seed(1234)

model_name_or_path = "VITA-MLLM/Long-VITA-128K_HF"

device_map = "auto"
# device_map = "npu:0"

# torch_dtype=torch.float16
torch_dtype=torch.bfloat16
# torch_dtype=torch.float32

# trust_remote_code=True lets transformers execute code shipped with the
# checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True
)
print("tokenizer", tokenizer)

# Loaded in eval mode with the flash_attention_2 attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch_dtype,
    attn_implementation="flash_attention_2",
).eval()
# print("model", model)

# Start from the checkpoint's generation config, then override it for the demo.
model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

model.generation_config.max_new_tokens = 1024
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 1310720
# do_sample=False disables sampling during generation.
model.generation_config.do_sample = False
model.generation_config.use_cache = True
model.generation_config.pad_token_id = tokenizer.pad_token_id

# Shared image/video preprocessor, read as a global by get_external_inputs.
# from cognitron_vl.data.processor.image_processor import ImageProcessor
image_processor = ImageProcessor(
    process_type="dynamic",
    image_size=448,
    normalize_type="imagenet",
    min_patch_grid=1,
    max_patch_grid=12,
)
@spaces.GPU(duration=120)
def inference_model(messages, image_path_list, video_path_list):
    """Generate the assistant reply for a chat history with optional media.

    A fixed system message is put in front of ``messages``, the chat template
    is applied, ``<image>`` / ``<video>`` tags are expanded by
    ``get_external_inputs`` and the model generates from the expanded prompt.

    Args:
        messages: List of ``{"role", "content"}`` dicts.
        image_path_list: Image paths for the ``<image>`` tags, or None.
        video_path_list: Video paths for the ``<video>`` tags, or None.

    Returns:
        str: The newly generated text, decoded without special tokens.
    """
    system_turn = {
        "role": "system",
        "content": "You are a helpful AI assistant.",
    }
    conversation = [system_turn] + messages

    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    print("input", tokenizer.decode(prompt_ids[0], skip_special_tokens=False), flush=True)

    expanded_ids, images, image_indices = get_external_inputs(
        prompt_ids,
        image_path_list=image_path_list,
        video_path_list=video_path_list,
    )

    generated = model.generate(inputs=expanded_ids, images=images, image_indices=image_indices)

    # Drop the prompt so that only the new tokens are decoded.
    prompt_length = len(expanded_ids[0])
    output = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
    print(f"output {output}", flush=True)

    return output
import time
import filetype
# Header markup shown at the top of the demo page; the title is the last path
# component of ``model_name_or_path``.
font_size = "2.5em"
# The Hugging Face link previously pointed at "huggingface.co." (stray dot
# after the host name); it now uses the canonical host.
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
<span style="display: inline-block; vertical-align: middle;">{model_name_or_path.split('/')[-1]}</span>
</p>
<center>
<font size=3>
<b>Long-VITA</b> has been fully open-sourced on <a href='https://huggingface.co/VITA-MLLM'>😊 Huggingface</a> and <a href='https://github.com/VITA-MLLM/Long-VITA'>🌟 GitHub</a>. If you find Long-VITA useful, a like❤️ or a star🌟 would be appreciated.
</font>
</center>
"""
def add_message(history, message):
    """Append the submitted files and text to the chat history.

    Each uploaded file becomes its own user entry whose content is
    ``{"path": <file>}``; the text (when not None) follows as a user entry.

    Returns:
        ``(history, textbox)``: the same history list, and a cleared,
        non-interactive ``gr.MultimodalTextbox``.
    """
    history.extend(
        {"role": "user", "content": {"path": uploaded}} for uploaded in message["files"]
    )

    text = message["text"]
    if text is not None:
        history.append({"role": "user", "content": text})

    return history, gr.MultimodalTextbox(value=None, interactive=False)
def bot(history: list):
    """Rebuild the prompt from the chat history, run the model and append its reply.

    Consecutive entries of the same role are merged into one message.  String
    content is appended to the message text; for non-string content each item
    is treated as a file path, and an ``<image>`` or ``<video>`` tag is put in
    front of the message text for every image or video file.  Files of any
    other kind are ignored.

    Returns:
        list: ``history`` with the assistant reply appended.
    """
    print("#" * 100)

    messages = []
    image_path_list = []
    video_path_list = []

    def _current_turn(role):
        # Reuse the last message while the role is unchanged, else open a new one.
        if len(messages) == 0 or messages[-1]["role"] != role:
            messages.append(
                {
                    "role": role,
                    "content": "",
                }
            )
        return messages[-1]

    for entry in history:
        role = entry["role"]
        content = entry["content"]

        if isinstance(content, str):
            turn = _current_turn(role)
            turn["content"] = turn["content"] + content
            continue

        for filepath in content:
            if filetype.is_image(filepath):
                tag, collected = "<image>", image_path_list
            elif filetype.is_video(filepath):
                tag, collected = "<video>", video_path_list
            else:
                continue

            turn = _current_turn(role)
            # Media tags go in front of whatever text the turn already holds.
            turn["content"] = tag + turn["content"]
            collected.append(filepath)

    print(f"messages {messages}")
    print(f"image_path_list {image_path_list}")
    print(f"video_path_list {video_path_list}")

    # Empty lists are passed on as None.
    if not image_path_list:
        image_path_list = None
    if not video_path_list:
        video_path_list = None

    output = inference_model(messages, image_path_list, video_path_list)

    history.append({"role": "assistant", "content": output})

    return history
# Build the Gradio page: header, chat window, multimodal input box, and the
# submit -> add_message -> bot -> re-enable-input event chain.
with gr.Blocks(title=model_name_or_path.split('/')[-1] + "🔥🚀🔥", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", bubble_full_width=False, height=600)

    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="multiple",
            file_types=['image', 'video'],
            placeholder="Enter message or upload file...",
            show_label=False,
            # sources=["microphone", "upload"],
            sources=["upload"],
        )

    # 1) Record the user's files/text in the history; add_message returns a
    #    non-interactive textbox, which locks the input.
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    # 2) Run the model and append its reply to the chat.
    bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
    # 3) Make the input box interactive again.
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

# Start the web server for the demo.
demo.launch()