Spaces:
Runtime error
Runtime error
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr_matcher.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch | |
from scipy.optimize import linear_sum_assignment | |
from torch import nn | |
from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou | |
import hotr.util.misc as utils | |
import wandb | |
class HungarianPairMatcher(nn.Module):
    """Hungarian matcher between HOI pair predictions and ground-truth pairs.

    The action head is read as ``num_path`` groups of ``num_queries`` pair
    predictions per image; every path is matched independently against its
    own copy of the image's ground-truth pairs.
    """

    def __init__(self, args):
        """Creates the matcher

        Params (read from ``args``):
            set_cost_act: relative weight of the multi-label action classification error in the matching cost
            set_cost_idx: relative weight of the human / object index classification errors in the matching cost
            set_cost_tgt: relative weight of the object-class error (used for HICO-DET only)
            wandb: enables wandb logging of cost statistics (together with ``log=True`` in ``forward``)
            dataset_file: 'vcoco' or 'hico-det'
            valid_ids, invalid_ids: V-COCO only; action columns listed in ``invalid_ids`` are zeroed before matching
        """
        super().__init__()
        self.cost_action = args.set_cost_act
        self.cost_hbox = self.cost_obox = args.set_cost_idx
        self.cost_target = args.set_cost_tgt
        self.log_printer = args.wandb
        self.is_vcoco = (args.dataset_file == 'vcoco')
        self.is_hico = (args.dataset_file == 'hico-det')
        if self.is_vcoco:
            self.valid_ids = args.valid_ids
            self.invalid_ids = args.invalid_ids
        assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs cant be 0"

    def reduce_redundant_gt_box(self, tgt_bbox, indices):
        """Filters redundant Ground-Truth Bounding Boxes

        Due to random crop augmentation, there can be multiple redundant labels
        for the exact same bounding box and object class. This function
        collapses them for smoother HOTR training.

        Args:
            tgt_bbox: GT boxes with their class appended as the last column.
            indices: ``(k_idx, bbox_idx)`` from the DETR matcher, where
                ``k_idx[i]`` is the query assigned to GT box ``bbox_idx[i]``.

        Returns:
            ``(tgt_bbox_res, k_idx, bbox_idx)``. Without duplicates, the inputs
            are returned unchanged. With duplicates, ``tgt_bbox_res`` holds the
            unique rows (in ``torch.unique`` order), ``bbox_idx`` is re-expressed
            as row indices into it, and only the first query matched to each
            unique row is kept (both as int64). In every case, both index
            tensors are moved to ``tgt_bbox.device``.
        """
        k_idx, bbox_idx = indices
        tgt_bbox_unique, inverse = torch.unique(tgt_bbox, dim=0, return_inverse=True)
        if len(tgt_bbox) != len(tgt_bbox_unique):
            inverse_lst = inverse.tolist()
            bbox_ids = bbox_idx.tolist()
            query_of_bbox = dict(zip(bbox_ids, k_idx.tolist()))
            kept_bbox, kept_query, seen = [], [], set()
            for bbox_id in bbox_ids:
                unique_id = inverse_lst[bbox_id]
                if unique_id not in seen:
                    seen.add(unique_id)
                    kept_bbox.append(unique_id)
                    kept_query.append(query_of_bbox[bbox_id])
            bbox_idx = torch.as_tensor(kept_bbox, dtype=torch.int64)
            k_idx = torch.as_tensor(kept_query, dtype=torch.int64)
            tgt_bbox_res = tgt_bbox_unique
        else:
            tgt_bbox_res = tgt_bbox
        # k_idx is later indexed with positions computed from bbox_idx, so the
        # two must live on the same device.
        bbox_idx = bbox_idx.to(tgt_bbox.device)
        k_idx = k_idx.to(tgt_bbox.device)
        return tgt_bbox_res, k_idx, bbox_idx

    def _extract_pair_targets(self, target):
        """Return ``(tgt_act, tgt_tgt, tgt_hbox, tgt_obox)`` for one image's target dict.

        For V-COCO, invalid action columns are zeroed and pairs left without
        any action are dropped; ``target`` is updated in place accordingly.
        """
        if self.is_vcoco:
            target["pair_actions"][:, self.invalid_ids] = 0
            keep_idx = target["pair_actions"].sum(dim=-1) != 0
            for key in ("pair_boxes", "pair_actions", "pair_targets"):
                target[key] = target[key][keep_idx]
            tgt_pbox = target["pair_boxes"]  # (num_pairs, 8): human box, then object box
            return target["pair_actions"], target["pair_targets"], tgt_pbox[:, :4], tgt_pbox[:, 4:]
        if self.is_hico:
            return target["pair_actions"], target["pair_targets"], target["sub_boxes"], target["obj_boxes"]
        raise ValueError("HungarianPairMatcher supports only the 'vcoco' and 'hico-det' datasets")

    def _pair_query_labels(self, tgt_hbox, tgt_obox, tgt_tgt, tgt_bbox, tgt_cls, indices):
        """Map each GT pair's human / object box to the DETR query matched to that box.

        Returns two lists of index tensors (human, object), one entry per matched
        pair; each list is ``[tensor([-1])]`` when nothing matched (label -1 is
        ignored later).

        Raises:
            ValueError: if the number of human-box and object-box matches differ,
                which would misalign the human / object labels.
        """
        # human boxes are compared using the label 1 (V-COCO) or 0 (HICO-DET)
        human_cls = torch.ones if self.is_vcoco else torch.zeros
        hbox_with_cls = torch.cat([tgt_hbox, human_cls((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
        obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
        # object boxes whose coordinates sum to -4 (absent / occluded objects) get class -1
        obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1

        bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
        bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices)
        # append an all -1 sentinel row so absent objects still find an exact match
        bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)
        sentinel_row = len(bbox_with_cls) - 1

        # exact (L1 distance 0) box + class matches; each row is (pair_idx, gt_row)
        h_match_indices = torch.nonzero(torch.cdist(hbox_with_cls, bbox_with_cls, p=1) == 0, as_tuple=False)
        o_match_indices = torch.nonzero(torch.cdist(obox_with_cls, bbox_with_cls, p=1) == 0, as_tuple=False)
        if len(h_match_indices) != len(o_match_indices):
            raise ValueError(
                f"HOI matcher: {len(h_match_indices)} human-box matches but "
                f"{len(o_match_indices)} object-box matches; GT pair boxes are "
                "inconsistent with the GT boxes")

        tgt_hids, tgt_oids = [], []
        for (_, h_row), (_, o_row) in zip(h_match_indices, o_match_indices):
            if o_row == sentinel_row:
                # happens in V-COCO: the target object may not appear, use the human
                o_row = h_row
            tgt_hids.append(k_idx[(bbox_idx == h_row).nonzero(as_tuple=False).squeeze(-1)])
            tgt_oids.append(k_idx[(bbox_idx == o_row).nonzero(as_tuple=False).squeeze(-1)])

        if not tgt_hids:
            tgt_hids.append(torch.as_tensor([-1]))  # we later ignore the label -1
        if not tgt_oids:
            tgt_oids.append(torch.as_tensor([-1]))  # we later ignore the label -1
        return tgt_hids, tgt_oids

    # Only assignment indices are produced; no gradient flows through the costs,
    # and scipy needs tensors that do not require grad.
    @torch.no_grad()
    def forward(self, outputs, targets, indices, log=False):
        """Match multi-path HOI predictions to ground-truth pairs, image by image.

        Args:
            outputs: prediction dict. ``pred_actions`` is (bs, num_path, num_queries, A),
                where A is one more than the number of GT action columns.
                ``pred_hidx`` / ``pred_oidx`` are viewed as
                (num_path, bs, num_queries, -1). HICO-DET also needs ``pred_obj_logits``.
            targets: list of per-image target dicts. These are modified in place:
                V-COCO pair filtering, padding of images without pairs, and the
                new ``h_labels`` / ``o_labels`` entries.
            indices: per-image ``(query_idx, gt_box_idx)`` pairs from the DETR matcher.
            log: when set (and ``args.wandb`` was set), log cost statistics to
                wandb from rank 0.

        Returns:
            ``(return_list, targets)``. ``return_list[b]`` holds ``num_path``
            tuples ``(query_idx, target_idx)`` of int64 tensors, one per path.
        """
        assert "pred_actions" in outputs, "There is no action output for pair matching"
        bs, num_path, num_queries = outputs["pred_actions"].shape[:3]
        if bs == 0:
            return [], targets

        # reshape every head to (bs, num_path * num_queries, -1), path-major
        outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
        if self.is_hico:
            outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path, num_queries, -1).view(bs, num_path * num_queries, -1)

        return_list = []
        log_enabled = self.log_printer and log
        if log_enabled:
            log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
            if self.is_hico:
                log_dict['tgt_cost'] = []

        for batch_idx in range(bs):
            target = targets[batch_idx]
            tgt_act, tgt_tgt, tgt_hbox, tgt_obox = self._extract_pair_targets(target)
            tgt_hids, tgt_oids = self._pair_query_labels(
                tgt_hbox, tgt_obox, tgt_tgt, target["boxes"], target["labels"], indices[batch_idx])

            tgt_sum = tgt_act.sum(dim=-1).unsqueeze(0)
            if tgt_act.shape[0] == 0:
                # No GT pair: pad with one all-zero action row so the cost matrix
                # has a column. For HICO-DET the object class -1 selects the
                # no-object column, whose probability is zeroed below.
                device = target["pair_actions"].device
                tgt_act = torch.zeros((1, tgt_act.shape[1])).to(device)
                target["pair_actions"] = torch.zeros((1, target["pair_actions"].shape[1])).to(device)
                if self.is_hico:
                    pad_tgt = -1
                    tgt_tgt = torch.tensor([pad_tgt]).to(target["pair_targets"])
                    target["pair_targets"] = torch.tensor([pad_tgt]).to(target["pair_targets"].device)
                tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)  # +1 keeps the normaliser non-zero

            # one copy of the GT pairs per prediction path
            tgt_hids = torch.cat(tgt_hids).repeat(num_path)
            tgt_oids = torch.cat(tgt_oids).repeat(num_path)
            num_tgt = len(tgt_hids) // num_path

            out_hprob = outputs_hidx[batch_idx].softmax(-1)
            out_oprob = outputs_oidx[batch_idx].softmax(-1)
            out_act = outputs_action[batch_idx].clone()
            if self.is_vcoco:
                out_act[..., self.invalid_ids] = 0
            # the prediction has one extra action column; the targets get a zero column for it
            tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1).repeat(num_path, 1)

            cost_hclass = -out_hprob[:, tgt_hids]  # (num_path * num_queries, num_path * num_tgt)
            cost_oclass = -out_oprob[:, tgt_oids]
            neg_act = (~tgt_act.bool()).type(torch.int64)
            cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
            cost_neg_act = torch.matmul(out_act, neg_act.t().float()) / neg_act.sum(dim=-1).unsqueeze(0)
            cost_action = cost_pos_act + cost_neg_act

            h_cost = self.cost_hbox * cost_hclass
            o_cost = self.cost_obox * cost_oclass
            act_cost = self.cost_action * cost_action
            cost_matrix = h_cost + o_cost + act_cost
            if self.is_hico:
                out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
                out_tgt[..., -1] = 0  # don't get cost for no-object
                cost_target = -out_tgt[:, tgt_tgt.repeat(num_path)]
                tgt_cost = self.cost_target * cost_target
                cost_matrix += tgt_cost

            # (num_path, num_queries, num_path * num_tgt); path p is assigned only
            # against its own copy of the targets (the p-th chunk of columns)
            cost_matrix = cost_matrix.view(num_path, num_queries, -1).cpu()
            per_path = cost_matrix.split([num_tgt] * num_path, -1)
            hoi_indices = [linear_sum_assignment(chunk[path]) for path, chunk in enumerate(per_path)]
            return_list.append([(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                                for i, j in hoi_indices])

            target["h_labels"] = tgt_hids.to(tgt_hbox.device)
            target["o_labels"] = tgt_oids.to(tgt_obox.device)

            if log_enabled:
                # statistics over the first path's queries only
                log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
                if self.is_hico:
                    log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())

        if log_enabled:
            for key in log_dict:
                log_dict[key] = torch.stack(log_dict[key]).mean()
            if utils.get_rank() == 0:
                wandb.log(log_dict)
        return return_list, targets
def build_hoi_matcher(args):
    """Construct the HOI pair matcher configured by ``args``."""
    matcher = HungarianPairMatcher(args)
    return matcher