import torch, os from scipy.optimize import linear_sum_assignment from torch import nn from .utils import OKSLoss import numpy as np from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou class HungarianMatcher(nn.Module): def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, focal_alpha=0.25, cost_keypoints=1.0, cost_kpvis=0.1, cost_oks=0.01, num_body_points=17): super().__init__() self.cost_class = cost_class self.cost_bbox = cost_bbox self.cost_giou = cost_giou assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, 'all costs cant be 0' self.cost_keypoints = cost_keypoints self.cost_kpvis = cost_kpvis self.cost_oks = cost_oks self.focal_alpha = focal_alpha self.num_body_points = num_body_points if num_body_points == 17: self.sigmas = np.array([ .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89 ], dtype=np.float32) / 10.0 elif num_body_points == 14: self.sigmas = np.array([ .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .79, .79 ]) / 10.0 else: raise ValueError(f'Unsupported keypoints number {num_keypoints}') @torch.no_grad() def forward(self, outputs, targets, data_batch=None): bs, num_queries = outputs['pred_logits'].shape[:2] out_prob = outputs['pred_logits'].flatten(0, 1).sigmoid() out_bbox = outputs['pred_boxes'].flatten(0, 1) out_keypoints = outputs['pred_keypoints'].flatten(0, 1) # Also concat the target labels and boxes tgt_ids = torch.cat([v['labels'] for v in targets]) tgt_bbox = torch.cat([v['boxes'] for v in targets]) tgt_keypoints = torch.cat([v['keypoints'] for v in targets]) tgt_area = torch.cat([v['area'] for v in targets]) # Compute the classification cost. alpha = self.focal_alpha gamma = 2.0 neg_cost_class = (1 - alpha) * (out_prob** gamma) * (-(1 - out_prob + 1e-8).log()) pos_cost_class = alpha * ( (1 - out_prob)**gamma) * (-(out_prob + 1e-8).log()) cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] # Compute the L1 cost between boxes cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # Compute the giou cost betwen boxes cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox), data_batch) # compute the keypoint costs Z_pred = out_keypoints[:, 0:(self.num_body_points * 2)] V_pred = out_keypoints[:, (self.num_body_points * 2):] Z_gt = tgt_keypoints[:, 0:(self.num_body_points * 2)] V_gt: torch.Tensor = tgt_keypoints[:, (self.num_body_points * 2):] if Z_pred.sum() > 0: sigmas = Z_pred.new_tensor(self.sigmas) variances = (sigmas * 2)**2 kpt_preds = Z_pred.reshape(-1, Z_pred.size(-1) // 2, 2) kpt_gts = Z_gt.reshape(-1, Z_gt.size(-1) // 2, 2) squared_distance = (kpt_preds[:, None, :, 0] - kpt_gts[None, :, :, 0]) ** 2 + \ (kpt_preds[:, None, :, 1] - kpt_gts[None, :, :, 1]) ** 2 squared_distance0 = squared_distance / (tgt_area[:, None] * variances[None, :] * 2) squared_distance1 = torch.exp(-squared_distance0) squared_distance1 = squared_distance1 * V_gt oks = squared_distance1.sum(dim=-1) / (V_gt.sum(dim=-1) + 1e-6) oks = oks.clamp(min=1e-6) cost_oks = 1 - oks # import pdb; pdb.set_trace() cost_keypoints = torch.abs(Z_pred[:, None, :] - Z_gt[None]) cost_keypoints = cost_keypoints * V_gt.repeat_interleave( 2, dim=1)[None] cost_keypoints = cost_keypoints.sum(-1) cost_bbox = torch.zeros_like(cost_keypoints) cost_giou = torch.zeros_like( cost_keypoints) # [bs*query, instance_num] C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + self.cost_keypoints * cost_keypoints + self.cost_oks * cost_oks C = C.view(bs, num_queries, -1).cpu() else: cost_oks = torch.zeros_like(cost_bbox) cost_keypoints = torch.zeros_like(cost_bbox) C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + self.cost_keypoints * cost_keypoints + self.cost_oks * cost_oks C = C.view(bs, num_queries, -1).cpu() sizes = [len(v['boxes']) for v in targets] indices = [ linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1)) ] # import mmcv # import numpy as np # import cv2 # from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d # img = mmcv.imdenormalize( # img=(data_batch['img'][0].cpu().numpy()).transpose(1, 2, 0), # mean=np.array([123.675, 116.28, 103.53]), # std=np.array([58.395, 57.12, 57.375]), # to_bgr=True).astype(np.uint8) # visualize_kp2d( # (gt_2d).reshape(-1,2)[None], # output_path='./figs/gt2d', # image_array=img.copy()[None], # # data_source='smplx_137', # disable_limbs = True, # overwrite=True) # from util import box_ops # idx = [0, 1, 83] # pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(outputs['pred_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) # pred_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_lhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) # pred_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_rhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) # pred_bbox_face = (box_ops.box_cxcywh_to_xyxy(outputs['pred_face_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) # pred_bbox = np.concatenate([pred_bbox_body,pred_bbox_face,pred_bbox_rhand,pred_bbox_lhand],axis=0) # img = mmcv.imshow_bboxes(img.copy(), pred_bbox, show=False) # cv2.imwrite('test1.png', img) if tgt_ids.shape[0] > 0: cost_mean_dict = { 'class': cost_class.mean(), 'bbox': cost_bbox.mean(), 'giou': cost_giou.mean(), 'keypoints': cost_keypoints.mean() } else: cost_mean_dict = { 'class': torch.zeros_like(cost_class.mean()), 'bbox': torch.zeros_like(cost_bbox.mean()), 'giou': torch.zeros_like(cost_giou.mean()), 'keypoints': torch.zeros_like(cost_keypoints.mean()), } return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices], cost_mean_dict def build_matcher(args): if args.matcher_type == 'HungarianMatcher': return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, focal_alpha=args.focal_alpha, cost_keypoints=args.set_cost_keypoints, cost_kpvis=args.set_cost_kpvis, cost_oks=args.set_cost_oks, num_body_points=args.num_body_points) elif args.matcher_type == 'HungarianMatcherBox': return HungarianMatcherBox(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, focal_alpha=args.focal_alpha) else: raise NotImplementedError('Unknown args.matcher_type: {}'.format( args.matcher_type)) class HungarianMatcherBox(nn.Module): def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, focal_alpha=0.25, cost_keypoints=1.0, cost_kpvis=0.1, cost_oks=0.01, num_body_points=17): super().__init__() self.cost_class = cost_class self.cost_bbox = cost_bbox self.cost_giou = cost_giou assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, 'all costs cant be 0' self.cost_keypoints = cost_keypoints self.cost_kpvis = cost_kpvis self.cost_oks = cost_oks self.focal_alpha = focal_alpha self.num_body_points = num_body_points if num_body_points == 17: self.sigmas = np.array([ .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89 ], dtype=np.float32) / 10.0 elif num_body_points == 14: self.sigmas = np.array([ .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .79, .79 ]) / 10.0 else: raise ValueError(f'Unsupported keypoints number {num_keypoints}') @torch.no_grad() def forward(self, outputs, targets): bs, num_queries = outputs['pred_logits'].shape[:2] out_prob = outputs['pred_logits'].flatten(0, 1).sigmoid() out_bbox = outputs['pred_boxes'].flatten(0, 1) # Also concat the target labels and boxes tgt_ids = torch.cat([v['labels'] for v in targets]) tgt_bbox = torch.cat([v['boxes'] for v in targets]) # Compute the classification cost. alpha = self.focal_alpha gamma = 2.0 neg_cost_class = (1 - alpha) * (out_prob** gamma) * (-(1 - out_prob + 1e-8).log()) pos_cost_class = alpha * ( (1 - out_prob)**gamma) * (-(out_prob + 1e-8).log()) cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] # Compute the L1 cost between boxes cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # Compute the giou cost betwen boxes cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) cost_oks = torch.zeros_like(cost_bbox) cost_keypoints = torch.zeros_like(cost_bbox) C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou C = C.view(bs, num_queries, -1).cpu() sizes = [len(v['boxes']) for v in targets] indices = [ linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1)) ] if tgt_ids.shape[0] > 0: cost_mean_dict = { 'class': cost_class.mean(), 'bbox': cost_bbox.mean(), 'giou': cost_giou.mean(), } else: cost_mean_dict = { 'class': torch.zeros_like(cost_class.mean()), 'bbox': torch.zeros_like(cost_bbox.mean()), 'giou': torch.zeros_like(cost_giou.mean()), } return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices], cost_mean_dict