import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_position_score(locality, obj1, obj2, iou_threshold=0.1):
    # obj1 and obj2 are bounding boxes in [x0, y0, x1, y1] format
    box1 = {
        'x_min': obj1[0],
        'y_min': obj1[1],
        'x_max': obj1[2],
        'y_max': obj1[3],
        'width': obj1[2] - obj1[0],
        'height': obj1[3] - obj1[1]
    }
    box2 = {
        'x_min': obj2[0],
        'y_min': obj2[1],
        'x_max': obj2[2],
        'y_max': obj2[3],
        'width': obj2[2] - obj2[0],
        'height': obj2[3] - obj2[1]
    }
    # Get the object centers
    box1_center = ((box1['x_min'] + box1['x_max']) / 2, (box1['y_min'] + box1['y_max']) / 2)
    box2_center = ((box2['x_min'] + box2['x_max']) / 2, (box2['y_min'] + box2['y_max']) / 2)
    # Calculate horizontal and vertical distances between the centers
    x_distance = box2_center[0] - box1_center[0]
    y_distance = box2_center[1] - box1_center[1]
    # Calculate IoU
    x_overlap = max(0, min(box1['x_max'], box2['x_max']) - max(box1['x_min'], box2['x_min']))
    y_overlap = max(0, min(box1['y_max'], box2['y_max']) - max(box1['y_min'], box2['y_min']))
    intersection = x_overlap * y_overlap
    box1_area = (box1['x_max'] - box1['x_min']) * (box1['y_max'] - box1['y_min'])
    box2_area = (box2['x_max'] - box2['x_min']) * (box2['y_max'] - box2['y_min'])
    union = box1_area + box2_area - intersection
    iou = intersection / union if union > 0 else 0  # guard against degenerate zero-area boxes
    # Get max object width and max object height (not used in the scoring below)
    max_width = max(box1['width'], box2['width'])
    max_height = max(box1['height'], box2['height'])
    score = 0
    # Note: substring membership means shorter phrases such as 'left of' also match this branch
    if locality in 'on the right of' or locality in 'on the left of':
        if abs(x_distance) > abs(y_distance) and iou < iou_threshold:
            score = 1
        elif abs(x_distance) > abs(y_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
        else:
            score = 0
    elif locality in 'on the bottom of' or locality in 'on the top of':
        if abs(y_distance) > abs(x_distance) and iou < iou_threshold:
            score = 1
        elif abs(y_distance) > abs(x_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
        else:
            score = 0
    return score
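# Illustrative check with hypothetical boxes (not values from the VBench data):
# obj1 = [0, 0, 100, 100] and obj2 = [200, 0, 300, 100] do not overlap (iou = 0)
# and their centers differ mainly along x, so
# get_position_score('on the left of', obj1, obj2) returns 1, while heavily
# overlapping boxes would instead decay as iou_threshold / iou. Only the dominant
# axis of separation is checked, not its direction, since check_generate below
# does not track which detection matched which of the two named objects.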
def get_dect_from_grit(model, image_arrays):
    # Run GRiT dense captioning on each frame and collect [object_name, bbox] pairs
    pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()
    with torch.no_grad():
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            pred_cur = []
            if len(ret[0]) > 0:
                for info in ret[0]:
                    pred_cur.append([info[0], info[1]])
            pred.append(pred_cur)
    return pred
def check_generate(key_info, predictions):
    # Score each frame: collect detections matching the two target objects and
    # keep the best pairwise position score for the specified relationship
    key_a = key_info['object_a']
    key_b = key_info['object_b']
    relation = key_info['relationship']
    frame_score = []
    for frame_pred in predictions:
        # filter the target objects
        frame_obj_locats = []
        cur_score = [0]
        for item in frame_pred:
            if (key_a == item[0]) or (key_b == item[0]):
                frame_obj_locats.append(item[1])
        for c_obj1 in range(len(frame_obj_locats) - 1):
            for c_obj2 in range(c_obj1 + 1, len(frame_obj_locats)):
                score_obj1_obj2 = get_position_score(relation, frame_obj_locats[c_obj1], frame_obj_locats[c_obj2])
                cur_score.append(score_obj1_obj2)
        frame_score.append(max(cur_score))
    return frame_score
def spatial_relationship(model, video_dict, device):
    video_results = []
    frame_score_overall = []
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise Exception("Auxiliary info is not in json, please check your json.")
        object_info = info['auxiliary_info']['spatial_relationship']
        for video_path in info['video_list']:
            video_tensor = load_video(video_path, num_frames=16)
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))
            cur_video_frame_score = check_generate(object_info, cur_video_pred)
            cur_success_frame_rate = np.mean(cur_video_frame_score)
            frame_score_overall.extend(cur_video_frame_score)
            video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate, 'frame_results': cur_video_frame_score})
    success_rate = np.mean(frame_score_overall)
    return success_rate, video_results
def compute_spatial_relationship(json_dir, device, submodules_dict):
    dense_caption_model = DenseCaptioning(device)
    dense_caption_model.initialize_model_det(**submodules_dict)
    logger.info("Initialize detection model success")
    _, prompt_dict_ls = load_dimension_info(json_dir, dimension='spatial_relationship', lang='en')
    all_results, video_results = spatial_relationship(dense_caption_model, prompt_dict_ls, device)
    return all_results, video_results
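# The block below is a hypothetical usage sketch, not part of the VBench API:
# the json_dir path and the contents of submodules_dict are placeholders, and
# the keys expected by DenseCaptioning.initialize_model_det depend on the GRiT
# detector checkpoint/config used in your setup.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_results, video_results = compute_spatial_relationship(
        json_dir="path/to/vbench_full_info_dir",  # placeholder path
        device=device,
        submodules_dict={},  # fill with GRiT detector arguments
    )
    logger.info(f"spatial relationship score: {all_results:.4f}")
    logger.info(f"evaluated {len(video_results)} videos")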