# ------------------------------------------------------------------------
# HOTR official code : hotr/models/post_process.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import time
import copy
import torch
import torch.nn.functional as F
from torch import nn

from hotr.util import box_ops


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api
    (see the usage sketch at the end of this file)."""

    def __init__(self, HOIDet):
        super().__init__()
        self.HOIDet = HOIDet

    def forward(self, outputs, target_sizes, threshold=0, dataset='coco', args=None):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        num_path = 1 + len(args.augpath_name)
        path_id = args.path_id
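        # NOTE (a reading of the reshapes below, not an official comment): the model appears
        # to emit one set of predictions per decoding path (the base path plus the paths in
        # args.augpath_name); `path_id` selects which path's outputs are post-processed here.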

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # convert normalized [cx, cy, w, h] boxes to absolute [x0, y0, x1, y1]
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        # Prediction Branch for HOI detection
        if self.HOIDet:
            if dataset == 'vcoco':
                """ Compute HOI triplet prediction score for V-COCO.
                Our scoring function follows the implementation details of UnionDet.
                """
                out_time = outputs['hoi_recognition_time']
                bss, q, hd = outputs['pred_hidx'].shape

                start_time = time.time()
                pair_actions = torch.sigmoid(outputs['pred_actions'][:, path_id, ...])
                h_prob = F.softmax(outputs['pred_hidx'].view(num_path, bss//num_path, q, hd)[path_id], -1)
                h_idx_score, h_indices = h_prob.max(-1)

                o_prob = F.softmax(outputs['pred_oidx'].view(num_path, bss//num_path, q, hd)[path_id], -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time

                results = []
                # iterate over the batch
                for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
                    h_inds = (l == 1) & (s > threshold)
                    o_inds = (s > threshold)

                    h_box, h_cat = b[h_inds], s[h_inds]
                    o_box, o_cat = b[o_inds], s[o_inds]

                    # for scenario 1 in v-coco dataset
                    o_inds = torch.cat((o_inds, torch.ones(1).type(torch.bool).to(o_inds.device)))
                    o_box = torch.cat((o_box, torch.Tensor([0, 0, 0, 0]).unsqueeze(0).to(o_box.device)))
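                    # The extra all-zero box appended above gives interactions without a visible
                    # target object something to point at: V-COCO scenario 1 scores such cases
                    # against a [0, 0, 0, 0] object box.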

                    result_dict = {
                        'h_box': h_box, 'h_cat': h_cat,
                        'o_box': o_box, 'o_cat': o_cat,
                        'scores': s, 'labels': l, 'boxes': b
                    }

                    h_inds_lst = (h_inds == True).nonzero(as_tuple=False).squeeze(-1)
                    o_inds_lst = (o_inds == True).nonzero(as_tuple=False).squeeze(-1)

                    K = boxes.shape[1]
                    n_act = pair_actions[batch_idx][:, :-1].shape[-1]
                    score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device)
                    sorted_score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device)
                    id_score = torch.zeros((K, K+1)).to(pair_actions[batch_idx].device)

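                    # Informal summary of the UnionDet-style score computed below (not an
                    # official comment): every interaction query that selects the pair (h, o)
                    # contributes
                    #     score[a, h, o] += (1 - p_no_interaction) * p_action[a]
                    # and the single best-matching query per pair is added once more via
                    # sorted_score, so confidently matched pairs are up-weighted.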
                    # Score function
                    for hs, h_idx, os, o_idx, pair_action in zip(h_idx_score[batch_idx], h_indices[batch_idx], o_idx_score[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]):
                        matching_score = (1 - pair_action[-1])  # one minus the no-interaction probability
                        if h_idx == o_idx: o_idx = -1
                        if matching_score > id_score[h_idx, o_idx]:
                            id_score[h_idx, o_idx] = matching_score
                            sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1]
                        score[:, h_idx, o_idx] += matching_score * pair_action[:-1]
                    score += sorted_score

                    score = score[:, h_inds, :]
                    score = score[:, :, o_inds]

                    result_dict.update({
                        'pair_score': score,
                        'hoi_recognition_time': hoi_recognition_time,
                    })

                    results.append(result_dict)

            elif dataset == 'hico-det':
                """ Compute HOI triplet prediction score for HICO-DET.
                For HICO-DET, we follow the same scoring function but do not accumulate the results.
                """
                bss, q, hd = outputs['pred_hidx'].shape
                out_time = outputs['hoi_recognition_time']
                a, b, c = outputs['pred_obj_logits'].shape

                start_time = time.time()
                out_obj_logits = outputs['pred_obj_logits'].view(-1, num_path, b, c)[:, path_id, ...]
                out_verb_logits = outputs['pred_actions'][:, path_id, ...]

                # actions
                matching_scores = (1 - out_verb_logits.sigmoid()[..., -1:])  # * (1 - out_verb_logits.sigmoid()[..., 57:58])
                verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores
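                # Informal summary (not an official comment): the per-triplet verb score is
                #     sigmoid(verb_logit) * (1 - sigmoid(no_interaction_logit)),
                # and the matched object-class score is multiplied in later, inside the
                # per-image loop below (vs = vs * os.unsqueeze(1)).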

                # hbox, obox
                outputs_hrepr = outputs['pred_hidx'].view(num_path, bss//num_path, q, hd)[path_id]
                outputs_orepr = outputs['pred_oidx'].view(num_path, bss//num_path, q, hd)[path_id]
                obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)

                h_prob = F.softmax(outputs_hrepr, -1)
                h_idx_score, h_indices = h_prob.max(-1)

                # targets
                o_prob = F.softmax(outputs_orepr, -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time

                # hidx, oidx
                sub_boxes, obj_boxes = [], []
                for batch_id, (box, h_idx, o_idx) in enumerate(zip(boxes, h_indices, o_indices)):
                    sub_boxes.append(box[h_idx, :])
                    obj_boxes.append(box[o_idx, :])
                sub_boxes = torch.stack(sub_boxes, dim=0)
                obj_boxes = torch.stack(obj_boxes, dim=0)

                # accumulate results (one entry per image in the batch)
                results = []
                for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes):
                    sl = torch.full_like(ol, 0)  # self.subject_category_id = 0 in HICO-DET
                    l = torch.cat((sl, ol))
                    b = torch.cat((sb, ob))
                    results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')})

                    vs = vs * os.unsqueeze(1)
                    ids = torch.arange(b.shape[0])

                    res_dict = {
                        'verb_scores': vs.to('cpu'),
                        'sub_ids': ids[:ids.shape[0] // 2],
                        'obj_ids': ids[ids.shape[0] // 2:],
                        'hoi_recognition_time': hoi_recognition_time
                    }
                    results[-1].update(res_dict)
        else:
            results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results
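
# --------------------------------------------------------------------------
# Usage sketch (not part of the official HOTR code): a minimal, hedged example
# of how this post-processor can be driven for the plain detection branch
# (HOIDet=False). All names, tensor shapes, and the dummy `args` namespace
# below are illustrative assumptions; in practice `outputs` comes from the
# HOTR model and `target_sizes` from the evaluation dataloader.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    postprocessor = PostProcess(HOIDet=False)            # detection-only branch
    dummy_args = Namespace(augpath_name=[], path_id=0)   # single decoding path

    dummy_outputs = {
        'pred_logits': torch.randn(2, 4, 81),  # [batch, queries, classes + no-object]
        'pred_boxes': torch.rand(2, 4, 4),     # normalized cxcywh boxes in [0, 1]
    }
    orig_sizes = torch.tensor([[480., 640.], [576., 768.]])  # [batch, (h, w)]

    results = postprocessor(dummy_outputs, orig_sizes, dataset='coco', args=dummy_args)
    print(results[0]['boxes'].shape)  # -> torch.Size([4, 4]), absolute xyxy boxes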