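"""Scene dimension evaluation: caption sampled video frames with Tag2Text and
check whether every word of the expected scene phrase appears in each caption."""
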
import os
import json
import logging

import torch
import numpy as np
from tqdm import tqdm

from vbench.utils import load_video, load_dimension_info, tag2text_transform
from vbench.third_party.tag2Text.tag2text import tag2text_caption

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def get_caption(model, image_arrays):
    """Generate a Tag2Text caption for each frame; the predicted tags are discarded."""
    caption, _ = model.generate(image_arrays, tag_input=None, return_tag_predict=True)
    return caption

def check_generate(key_info, predictions):
    """Count the captions that contain every word of the expected scene phrase."""
    cur_cnt = 0
    key = key_info['scene']
    for pred in predictions:
        # A caption counts as a success only if all scene words appear in it.
        if all(q in pred for q in key.split()):
            cur_cnt += 1
    return cur_cnt
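
# Example with illustrative values: for key_info = {'scene': 'botanical garden'}
# and predictions = ['a botanical garden with a pond', 'a busy city street'],
# check_generate returns 1, since only the first caption contains both words.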

def scene(model, video_dict, device):
    """Score each video by the fraction of its frames whose caption matches the expected scene."""
    success_frame_count, frame_count = 0, 0
    video_results = []
    transform = tag2text_transform(384)
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise ValueError("Auxiliary info is not in json, please check your json.")
        scene_info = info['auxiliary_info']['scene']
        for video_path in info['video_list']:
            # Sample 16 frames per video, resized to the 384x384 input expected by Tag2Text.
            video_array = load_video(video_path, num_frames=16, return_tensor=False, width=384, height=384)
            video_tensor_list = []
            for i in video_array:
                video_tensor_list.append(transform(i).to(device).unsqueeze(0))
            video_tensor = torch.cat(video_tensor_list)
            cur_video_pred = get_caption(model, video_tensor)
            cur_success_frame_count = check_generate(scene_info, cur_video_pred)
            cur_success_frame_rate = cur_success_frame_count / len(cur_video_pred)
            success_frame_count += cur_success_frame_count
            frame_count += len(cur_video_pred)
            video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate})
    success_rate = success_frame_count / frame_count
    return success_rate, video_results

def compute_scene(json_dir, device, submodules_dict):
    """Load the Tag2Text caption model and run the scene evaluation over the prompt list."""
    model = tag2text_caption(**submodules_dict)
    model.eval()
    model = model.to(device)
    logger.info("Initialized caption model successfully")
    _, prompt_dict_ls = load_dimension_info(json_dir, dimension='scene', lang='en')
    all_results, video_results = scene(model, prompt_dict_ls, device)
    return all_results, video_results
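
# Minimal usage sketch. The JSON path, checkpoint path, and submodule keys
# below are hypothetical placeholders (not defined by this module); adjust
# them to match your local VBench setup and Tag2Text checkpoint.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    submodules = {
        "pretrained": "pretrained/tag2text_swin_14m.pth",  # hypothetical checkpoint path
        "image_size": 384,
        "vit": "swin_b",
    }
    success_rate, per_video = compute_scene("vbench_info.json", device, submodules)  # hypothetical JSON path
    logger.info("scene success rate: %.4f", success_rate)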