import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
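
# Module overview (added note, inferred from the code below): the spatial_relationship
# dimension detects the two prompted objects with the GRiT dense-captioning/detection
# model and scores, frame by frame, whether their bounding boxes satisfy the prompted
# relation (left/right/top/bottom).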

def get_position_score(locality, obj1, obj2, iou_threshold=0.1):
    # obj1 and obj2 are bounding boxes in [x0, y0, x1, y1] format
    box1 = {
        'x_min': obj1[0],
        'y_min': obj1[1],
        'x_max': obj1[2],
        'y_max': obj1[3],
        'width': obj1[2] - obj1[0],
        'height': obj1[3] - obj1[1]
    }
    box2 = {
        'x_min': obj2[0],
        'y_min': obj2[1],
        'x_max': obj2[2],
        'y_max': obj2[3],
        'width': obj2[2] - obj2[0],
        'height': obj2[3] - obj2[1]
    }
    # Get the object centers
    box1_center = ((box1['x_min'] + box1['x_max']) / 2, (box1['y_min'] + box1['y_max']) / 2)
    box2_center = ((box2['x_min'] + box2['x_max']) / 2, (box2['y_min'] + box2['y_max']) / 2)
    # Calculate horizontal and vertical distances between the centers
    x_distance = box2_center[0] - box1_center[0]
    y_distance = box2_center[1] - box1_center[1]
    # Calculate IoU
    x_overlap = max(0, min(box1['x_max'], box2['x_max']) - max(box1['x_min'], box2['x_min']))
    y_overlap = max(0, min(box1['y_max'], box2['y_max']) - max(box1['y_min'], box2['y_min']))
    intersection = x_overlap * y_overlap
    box1_area = (box1['x_max'] - box1['x_min']) * (box1['y_max'] - box1['y_min'])
    box2_area = (box2['x_max'] - box2['x_min']) * (box2['y_max'] - box2['y_min'])
    union = box1_area + box2_area - intersection
    iou = intersection / union
    # Get the maximum object width and height
    max_width = max(box1['width'], box2['width'])
    max_height = max(box1['height'], box2['height'])
    score = 0
    # Note: `locality in '...'` is a substring check, so shorter phrasings such as
    # 'left of' or 'right of' also fall into the corresponding branch. Only the
    # dominant axis (horizontal vs. vertical) and the overlap are checked, not the
    # sign of the offset.
    if locality in 'on the right of' or locality in 'on the left of':
        if abs(x_distance) > abs(y_distance) and iou < iou_threshold:
            score = 1
        elif abs(x_distance) > abs(y_distance) and iou >= iou_threshold:
            # Boxes overlap too much: decay the score as the IoU grows
            score = iou_threshold / iou
        else:
            score = 0
    elif locality in 'on the bottom of' or locality in 'on the top of':
        if abs(y_distance) > abs(x_distance) and iou < iou_threshold:
            score = 1
        elif abs(y_distance) > abs(x_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
        else:
            score = 0
    return score
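
# Illustrative sanity check (added example; the boxes are hypothetical, not from the
# original file): two side-by-side, non-overlapping boxes satisfy the horizontal
# relations and score 1, while heavily overlapping boxes are penalised by
# iou_threshold / iou.
#
#     get_position_score('on the left of', [0, 0, 20, 20], [100, 0, 120, 20])  # -> 1
#     get_position_score('on the left of', [0, 0, 20, 20], [5, 0, 25, 20])     # -> 0.1 / 0.6 ~= 0.17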

def get_dect_from_grit(model, image_arrays):
    pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()
    with torch.no_grad():
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            pred_cur = []
            if len(ret[0]) > 0:
                for info in ret[0]:
                    # info[0] is the detected object label, info[1] its bounding box
                    pred_cur.append([info[0], info[1]])
            # Append an entry per frame, even when nothing was detected
            pred.append(pred_cur)
    return pred

def check_generate(key_info, predictions):
    key_a = key_info['object_a']
    key_b = key_info['object_b']
    relation = key_info['relationship']
    frame_score = []
    for frame_pred in predictions:
        # Keep only the detections whose label matches one of the target objects
        frame_obj_locats = []
        cur_score = [0]
        for item in frame_pred:
            if (key_a == item[0]) or (key_b == item[0]):
                frame_obj_locats.append(item[1])
        # Score every pair of candidate boxes and keep the best score for this frame
        for c_obj1 in range(len(frame_obj_locats) - 1):
            for c_obj2 in range(c_obj1 + 1, len(frame_obj_locats)):
                score_obj1_obj2 = get_position_score(relation, frame_obj_locats[c_obj1], frame_obj_locats[c_obj2])
                cur_score.append(score_obj1_obj2)
        frame_score.append(max(cur_score))
    return frame_score

def spatial_relationship(model, video_dict, device):
    video_results = []
    frame_score_overall = []
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise KeyError("Auxiliary info is not in json, please check your json.")
        object_info = info['auxiliary_info']['spatial_relationship']
        for video_path in info['video_list']:
            video_tensor = load_video(video_path, num_frames=16)
            # Permute the frame tensor to channel-last layout for the detector
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))
            cur_video_frame_score = check_generate(object_info, cur_video_pred)
            cur_success_frame_rate = np.mean(cur_video_frame_score)
            frame_score_overall.extend(cur_video_frame_score)
            video_results.append({
                'video_path': video_path,
                'video_results': cur_success_frame_rate,
                'frame_results': cur_video_frame_score
            })
    success_rate = np.mean(frame_score_overall)
    return success_rate, video_results

def compute_spatial_relationship(json_dir, device, submodules_dict):
    dense_caption_model = DenseCaptioning(device)
    dense_caption_model.initialize_model_det(**submodules_dict)
    logger.info("Initialize detection model success")
    _, prompt_dict_ls = load_dimension_info(json_dir, dimension='spatial_relationship', lang='en')
    all_results, video_results = spatial_relationship(dense_caption_model, prompt_dict_ls, device)
    return all_results, video_results
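
# Minimal usage sketch (added; not part of the original file). The JSON path, the
# checkpoint path, and the keys of `submodules_dict` below are placeholders — the
# exact keyword arguments expected by DenseCaptioning.initialize_model_det depend on
# the VBench configuration you use.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Placeholder arguments; replace with the detector weights/config for your setup.
    submodules_dict = {"model_weight": "path/to/grit_detection_checkpoint.pth"}
    score, per_video = compute_spatial_relationship("path/to/spatial_relationship_info.json", device, submodules_dict)
    logger.info("spatial_relationship success rate: %s", score)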