import copy
import os
import math
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from torch import Tensor
from util import box_ops
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy,
                       get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)
from .utils import PoseProjector, sigmoid_focal_loss, MLP, OKSLoss
from typing import Optional, Union
from detrsmpl.core.conventions.keypoints_mapping import (get_keypoint_idx,
                                                         convert_kps)
from detrsmpl.utils.geometry import (batch_rodrigues, project_points_new)
from config.config import cfg
from util.human_models import smpl_x
from detrsmpl.utils.transforms import rotmat_to_aa


class SetCriterion(nn.Module):
    """Bundle of loss terms: classification, boxes, 2D keypoints and SMPL-X
    pose / shape / expression / joint losses.

    Most ``loss_*`` methods share the call shape
    ``(outputs, targets, indices, idx, num_boxes, data_batch, **kwargs)`` and
    return a dict mapping loss names to scalar tensors.  ``indices`` is
    indexed as ``indices[0]`` (per-sample ``(src, tgt)`` match pairs) and
    ``indices[1]`` (matching-cost statistics); ``idx`` is used to index the
    prediction tensors directly.
    """

    def __init__(self,
                 num_classes,
                 matcher,
                 weight_dict,
                 focal_alpha,
                 losses,
                 num_box_decoder_layers=2,
                 num_hand_face_decoder_layers=4,
                 num_body_points=17,
                 num_hand_points=6,
                 num_face_points=6,
                 smpl_loss_config=None,
                 convention='smplx_137'):
        """Store the configuration and build one OKS loss per body part.

        ``smpl_loss_config`` is accepted for interface compatibility but is
        not read in the visible code.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.vis = 0.1
        self.abs = 1
        self.num_body_points = num_body_points
        self.num_hand_points = num_hand_points
        self.num_face_points = num_face_points
        self.num_box_decoder_layers = num_box_decoder_layers
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.convention = convention
        self.body_oks = OKSLoss(linear=True,
                                num_keypoints=num_body_points,
                                eps=1e-6,
                                reduction='mean',
                                loss_weight=1.0)
        self.hand_oks = OKSLoss(linear=True,
                                num_keypoints=num_hand_points,
                                eps=1e-6,
                                reduction='mean',
                                loss_weight=1.0)
        self.face_oks = OKSLoss(linear=True,
                                num_keypoints=num_face_points,
                                eps=1e-6,
                                reduction='mean',
                                loss_weight=1.0)

    def loss_labels(self,
                    outputs,
                    targets,
                    indices,
                    idx,
                    num_boxes,
                    data_batch,
                    log=True):
        """Classification loss (binary focal loss).

        Each target dict must contain the key "labels" holding a tensor of
        dim [nb_target_boxes].  Returns ``loss_ce`` and, when ``log`` is
        true, ``class_error`` (100 minus the accuracy on matched queries).
        """
        indices = indices[0]
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        target_classes_o = torch.cat(
            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
        # Unmatched queries get the extra index ``num_classes``.
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        # One-hot with one extra channel that is dropped right after the
        # scatter, so unmatched queries end up with an all-zero target row.
        target_classes_onehot = torch.zeros([
            src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
        ],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[:, :, :-1]

        loss_ce = sigmoid_focal_loss(src_logits,
                                     target_classes_onehot,
                                     num_boxes,
                                     alpha=self.focal_alpha,
                                     gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes,
                         data_batch):
        """Compute the cardinality error, i.e. the absolute error in the
        number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only.
        It doesn't propagate gradients.
        """
        # NOTE(review): unlike the sibling loss_* methods this signature has
        # no ``idx`` parameter -- confirm against the dispatching caller.
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
                                      device=device)
        # ``if tgt_lengths == 0`` is ambiguous (raises) for more than one
        # sample; short-circuit only when no sample has any target.
        if (tgt_lengths == 0).all():
            return {'cardinality_error': pred_logits.sum() * 0}
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) !=
                     pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        return {'cardinality_error': card_err}

    def _part_keypoint_losses(self, src_keypoints, targets, indices,
                              data_batch, target_key, valid_key, num_points,
                              oks_loss_fn, kpt_loss_name, oks_loss_name):
        """L1 + OKS keypoint losses for one part (body / hand / face).

        ``src_keypoints`` holds the matched predictions; the first
        ``2 * num_points`` columns are read as coordinates.  Targets come
        from ``targets[*][target_key]`` (coordinates followed by per-keypoint
        visibility) and ``data_batch[valid_key]`` (per-instance box validity).
        Returns zero-valued losses that stay attached to ``src_keypoints``
        when there is nothing to supervise.
        """
        if len(src_keypoints) == 0:
            return {
                kpt_loss_name: src_keypoints.sum() * 0,
                oks_loss_name: src_keypoints.sum() * 0,
            }

        Z_pred = src_keypoints[:, 0:(num_points * 2)]
        tgt_keypoints = torch.cat(
            [t[target_key][i] for t, (_, i) in zip(targets, indices)], dim=0)
        tgt_area = torch.cat(
            [t['area'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        boxes_conf = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch[valid_key], indices)],
            dim=0)
        Z_gt = tgt_keypoints[:, 0:(num_points * 2)]
        V_gt = tgt_keypoints[:, (num_points * 2):]

        # An instance counts only if it has a visible keypoint and a valid box.
        kps_conf = V_gt.sum(-1) > 0
        part_num_boxes = (kps_conf * boxes_conf).sum()

        oks_loss = oks_loss_fn(Z_pred,
                               Z_gt,
                               V_gt,
                               tgt_area,
                               weight=None,
                               avg_factor=None,
                               reduction_override=None)
        oks_loss = oks_loss * kps_conf * boxes_conf

        pose_loss = F.l1_loss(Z_pred, Z_gt, reduction='none')
        # Each visibility flag weights both the x and the y residual.
        pose_loss = pose_loss * V_gt.repeat_interleave(2, dim=1)
        pose_loss = pose_loss.sum(-1) * boxes_conf

        if part_num_boxes > 0:
            return {
                kpt_loss_name: pose_loss.sum() / part_num_boxes,
                oks_loss_name: oks_loss.sum() / part_num_boxes,
            }
        return {
            kpt_loss_name: src_keypoints.sum() * 0,
            oks_loss_name: src_keypoints.sum() * 0,
        }

    def loss_keypoints(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       face_hand_kpt=False):
        """Compute the losses related to the keypoints.

        Body keypoints are always supervised.  Left-hand, right-hand and face
        keypoints are supervised only when ``face_hand_kpt`` is true and the
        matching ``pred_*_keypoints`` entry exists in ``outputs``.
        """
        indices = indices[0]
        losses = {}

        # ---- body ----
        src_body_keypoints = outputs['pred_keypoints'][idx]  # xyxyvv
        body_losses = self._part_keypoint_losses(
            src_body_keypoints, targets, indices, data_batch, 'keypoints',
            'body_bbox_valid', self.num_body_points, self.body_oks,
            'loss_keypoints', 'loss_oks')
        if len(src_body_keypoints) == 0:
            # The original empty-match path also added a zero term that
            # depends on the camera prediction; keep that term.
            body_losses['loss_keypoints'] = (
                body_losses['loss_keypoints'] +
                outputs['pred_smpl_cam'][idx].float().sum() * 0)
        losses.update(body_losses)

        # ---- left hand / right hand / face ----
        if 'pred_lhand_keypoints' in outputs and face_hand_kpt:
            losses.update(
                self._part_keypoint_losses(
                    outputs['pred_lhand_keypoints'][idx], targets, indices,
                    data_batch, 'lhand_keypoints', 'lhand_bbox_valid',
                    self.num_hand_points, self.hand_oks,
                    'loss_lhand_keypoints', 'loss_lhand_oks'))

        if 'pred_rhand_keypoints' in outputs and face_hand_kpt:
            losses.update(
                self._part_keypoint_losses(
                    outputs['pred_rhand_keypoints'][idx], targets, indices,
                    data_batch, 'rhand_keypoints', 'rhand_bbox_valid',
                    self.num_hand_points, self.hand_oks,
                    'loss_rhand_keypoints', 'loss_rhand_oks'))

        if 'pred_face_keypoints' in outputs and face_hand_kpt:
            # The face normaliser previously used the left-hand visibility
            # mask; the helper uses the face mask.
            losses.update(
                self._part_keypoint_losses(
                    outputs['pred_face_keypoints'][idx], targets, indices,
                    data_batch, 'face_keypoints', 'face_bbox_valid',
                    self.num_face_points, self.face_oks,
                    'loss_face_keypoints', 'loss_face_oks'))

        return losses

    def loss_smpl_pose(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       face_hand_kpt=False):
        """L1 loss between predicted and target SMPL-X rotation matrices.

        Targets (``data_batch['smplx_pose']``) are converted with
        ``batch_rodrigues`` and viewed as (-1, 53, 3, 3).  Joint slices used:
        0 root, 1:22 body, 22:37 left hand, 37:52 right hand, 52 jaw.  When
        ``face_hand_kpt`` is false the hand and jaw terms are multiplied by
        zero but still returned.
        """
        indices = indices[0]
        pred_smplx_pose = torch.cat(
            (
                outputs['pred_smpl_pose'][idx],  # 22
                outputs['pred_smpl_lhand_pose'][idx],  # 15
                outputs['pred_smpl_rhand_pose'][idx],  # 15
                outputs['pred_smpl_jaw_pose'][idx],
            ),
            dim=1)

        targets_smpl_pose = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['smplx_pose'], indices)],
            dim=0)
        targets_smpl_pose = batch_rodrigues(targets_smpl_pose.view(
            -1, 3)).view(-1, 53, 3, 3)

        conf = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['smplx_pose_valid'], indices)
        ],
                         dim=0)
        body_pose_valid = conf[:, :22].sum(-1) > 0
        lhand_pose_valid = conf[:, 22:37].sum(-1) > 0
        rhand_pose_valid = conf[:, 37:52].sum(-1) > 0
        face_pose_valid = conf[:, 52].sum(-1) > 0

        loss_smpl_pose = F.l1_loss(pred_smplx_pose,
                                   targets_smpl_pose,
                                   reduction='none')
        loss_smpl_pose = loss_smpl_pose.sum([-1, -2]) * conf

        # Hand / jaw terms stay in the graph but contribute nothing when
        # face / hand supervision is disabled.
        part_scale = 1.0 if face_hand_kpt else 0.0
        return {
            'loss_smpl_pose_root':
            loss_smpl_pose[:, 0].sum() / (body_pose_valid.sum() + 1e-6),
            'loss_smpl_pose_body':
            loss_smpl_pose[:, 1:22].sum() / (body_pose_valid.sum() + 1e-6),
            'loss_smpl_pose_lhand':
            part_scale * loss_smpl_pose[:, 22:37].sum() /
            (lhand_pose_valid.sum() + 1e-6),
            'loss_smpl_pose_rhand':
            part_scale * loss_smpl_pose[:, 37:52].sum() /
            (rhand_pose_valid.sum() + 1e-6),
            'loss_smpl_pose_jaw':
            part_scale * loss_smpl_pose[:, 52].sum() /
            (face_pose_valid.sum() + 1e-6),
        }

    def loss_smpl_beta(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       face_hand_kpt=False):
        """L1 loss on the SMPL-X shape parameters, averaged over instances
        flagged valid in ``data_batch['smplx_shape_valid']``."""
        indices = indices[0]
        pred_smpl_betas = outputs['pred_smpl_beta'][idx]
        targets_smpl_betas = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['smplx_shape'], indices)],
            dim=0)
        conf = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['smplx_shape_valid'], indices)
        ],
                         dim=0)
        if conf.sum() == 0:
            return {'loss_smpl_beta': pred_smpl_betas.sum() * 0}

        loss_smpl_betas = F.l1_loss(pred_smpl_betas,
                                    targets_smpl_betas,
                                    reduction='none')
        loss_smpl_betas = loss_smpl_betas.sum(-1) * conf
        return {
            'loss_smpl_beta': loss_smpl_betas.sum() / (conf.sum() + 1e-6)
        }

    def loss_smpl_expr(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       face_hand_kpt=False):
        """L1 loss on the SMPL-X expression parameters.

        The value is multiplied by zero when ``face_hand_kpt`` is false.
        """
        indices = indices[0]
        device = outputs['pred_logits'].device
        pred_smpl_expr = outputs['pred_smpl_expr'][idx]
        targets_smpl_expr = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['smplx_expr'], indices)],
            dim=0)
        conf = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['smplx_expr_valid'], indices)
        ],
                         dim=0)
        if conf.sum() == 0:
            return {
                'loss_smpl_expr':
                pred_smpl_expr.sum() * torch.as_tensor(0., device=device)
            }

        loss_smpl_expr = F.l1_loss(pred_smpl_expr,
                                   targets_smpl_expr,
                                   reduction='none')
        loss_smpl_expr = loss_smpl_expr.sum(-1) * conf

        part_scale = 1.0 if face_hand_kpt else 0.0
        return {
            'loss_smpl_expr':
            part_scale * loss_smpl_expr.sum() / (conf.sum() + 1e-6)
        }

    def loss_smpl_kp3d(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       has_keypoints3d=None,
                       face_hand_kpt=False):
        """Pelvis-aligned L1 loss on 3D joints (no per-part root alignment).

        Targets come from ``data_batch['joint_cam']``; channels 0:3 are read
        as coordinates and the remaining channel(s) as confidence, which is
        further multiplied by ``data_batch['is_3D']``.  Hand and face terms
        are multiplied by zero when ``face_hand_kpt`` is false.
        """
        indices = indices[0]
        pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
        targets_smpl_kp3d = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['joint_cam'], indices)],
            dim=0)

        targets_kp3d_conf = targets_smpl_kp3d[:, :, 3:].clone()
        targets_smpl_kp3d = targets_smpl_kp3d[:, :, :3]
        targets_is_3d = torch.cat([
            t[None, None].repeat(len(i), 1, 1)
            for t, (_, i) in zip(data_batch['is_3D'], indices)
        ],
                                  dim=0)
        targets_kp3d_conf = targets_kp3d_conf * targets_is_3d

        # Express both joint sets relative to their own pelvis.
        pelvis_idx = get_keypoint_idx('pelvis', self.convention)
        targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :]
        pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :]
        targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :]
        pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :]

        body_idx = smpl_x.joint_part['body']
        face_idx = smpl_x.joint_part['face']
        lhand_idx = smpl_x.joint_part['lhand']
        rhand_idx = smpl_x.joint_part['rhand']

        loss_smpl_kp3d = F.l1_loss(pred_smpl_kp3d,
                                   targets_smpl_kp3d,
                                   reduction='none')

        body_kp3d_valid = targets_kp3d_conf[:, body_idx].sum([-1, -2]) > 0
        lhand_kp3d_valid = targets_kp3d_conf[:, lhand_idx].sum([-1, -2]) > 0
        rhand_kp3d_valid = targets_kp3d_conf[:, rhand_idx].sum([-1, -2]) > 0
        face_kp3d_valid = targets_kp3d_conf[:, face_idx].sum([-1, -2]) > 0

        loss_smpl_kp3d = loss_smpl_kp3d * targets_kp3d_conf

        part_scale = 1.0 if face_hand_kpt else 0.0
        losses = {}
        losses['loss_smpl_body_kp3d'] = torch.sum(
            loss_smpl_kp3d[:, body_idx, :]) / (body_kp3d_valid.sum() + 1e-6)
        losses['loss_smpl_lhand_kp3d'] = part_scale * torch.sum(
            loss_smpl_kp3d[:, lhand_idx, :]) / (lhand_kp3d_valid.sum() + 1e-6)
        losses['loss_smpl_rhand_kp3d'] = part_scale * torch.sum(
            loss_smpl_kp3d[:, rhand_idx, :]) / (rhand_kp3d_valid.sum() + 1e-6)
        losses['loss_smpl_face_kp3d'] = part_scale * torch.sum(
            loss_smpl_kp3d[:, face_idx, :]) / (face_kp3d_valid.sum() + 1e-6)
        return losses

    def loss_smpl_kp3d_ra(self,
                          outputs,
                          targets,
                          indices,
                          idx,
                          num_boxes,
                          data_batch,
                          has_keypoints3d=None,
                          face_hand_kpt=False):
        """3D joint L1 loss with per-part root alignment of the prediction.

        Targets come from ``data_batch['smplx_joint_cam']``.  After pelvis
        alignment, predicted face / hand joints are additionally shifted by
        the predicted neck / wrist joint before comparison with the targets.
        Hand and face terms are multiplied by zero when ``face_hand_kpt`` is
        false.
        """
        # NOTE(review): only the prediction is re-rooted at the neck/wrist;
        # presumably the targets are already part-root-relative -- confirm.
        indices = indices[0]
        pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
        targets_smpl_kp3d = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['smplx_joint_cam'], indices)
        ],
                                      dim=0)

        targets_kp3d_conf = targets_smpl_kp3d[:, :, 3:].clone()
        targets_smpl_kp3d = targets_smpl_kp3d[:, :, :3]
        targets_is_3d = torch.cat([
            t[None, None].repeat(len(i), 1, 1)
            for t, (_, i) in zip(data_batch['is_3D'], indices)
        ],
                                  dim=0)
        targets_kp3d_conf = (targets_kp3d_conf * targets_is_3d).repeat(
            1, 1, 3)

        pelvis_idx = get_keypoint_idx('pelvis', self.convention)
        targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :]
        pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :]
        targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :]
        pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :]

        # calculate body, face and hand loss separately:
        losses = {}
        body_idx = smpl_x.joint_part['body']
        face_idx = smpl_x.joint_part['face']
        lhand_idx = smpl_x.joint_part['lhand']
        rhand_idx = smpl_x.joint_part['rhand']

        body_kp3d_valid = targets_kp3d_conf[:, body_idx].sum([-1, -2]) > 0
        lhand_kp3d_valid = targets_kp3d_conf[:, lhand_idx].sum([-1, -2]) > 0
        rhand_kp3d_valid = targets_kp3d_conf[:, rhand_idx].sum([-1, -2]) > 0
        face_kp3d_valid = targets_kp3d_conf[:, face_idx].sum([-1, -2]) > 0

        part_scale = 1 if face_hand_kpt else 0

        loss_smpl_body_kp3d = F.l1_loss(pred_smpl_kp3d[:, body_idx, :],
                                        targets_smpl_kp3d[:, body_idx, :],
                                        reduction='none')
        loss_smpl_body_kp3d = torch.sum(loss_smpl_body_kp3d *
                                        targets_kp3d_conf[:, body_idx, :])
        losses['loss_smpl_body_kp3d_ra'] = loss_smpl_body_kp3d / (
            body_kp3d_valid.sum() + 1e-6)

        face_cam = pred_smpl_kp3d[:, face_idx, :]
        neck_cam = pred_smpl_kp3d[:, smpl_x.neck_idx, None, :]
        face_cam = face_cam - neck_cam
        loss_smpl_face_kp3d = F.l1_loss(face_cam,
                                        targets_smpl_kp3d[:, face_idx, :],
                                        reduction='none')
        loss_smpl_face_kp3d = torch.sum(loss_smpl_face_kp3d *
                                        targets_kp3d_conf[:, face_idx, :])
        losses['loss_smpl_face_kp3d_ra'] = part_scale * (
            loss_smpl_face_kp3d / (face_kp3d_valid.sum() + 1e-6))

        lhand_cam = pred_smpl_kp3d[:, lhand_idx, :]
        lwrist_cam = pred_smpl_kp3d[:, smpl_x.lwrist_idx, None, :]
        lhand_cam = lhand_cam - lwrist_cam
        loss_smpl_lhand_kp3d = F.l1_loss(lhand_cam,
                                         targets_smpl_kp3d[:, lhand_idx, :],
                                         reduction='none')
        loss_smpl_lhand_kp3d = torch.sum(loss_smpl_lhand_kp3d *
                                         targets_kp3d_conf[:, lhand_idx, :])
        losses['loss_smpl_lhand_kp3d_ra'] = part_scale * (
            loss_smpl_lhand_kp3d / (lhand_kp3d_valid.sum() + 1e-6))

        rhand_cam = pred_smpl_kp3d[:, rhand_idx, :]
        rwrist_cam = pred_smpl_kp3d[:, smpl_x.rwrist_idx, None, :]
        rhand_cam = rhand_cam - rwrist_cam
        loss_smpl_rhand_kp3d = F.l1_loss(rhand_cam,
                                         targets_smpl_kp3d[:, rhand_idx, :],
                                         reduction='none')
        loss_smpl_rhand_kp3d = torch.sum(loss_smpl_rhand_kp3d *
                                         targets_kp3d_conf[:, rhand_idx, :])
        losses['loss_smpl_rhand_kp3d_ra'] = part_scale * (
            loss_smpl_rhand_kp3d / (rhand_kp3d_valid.sum() + 1e-6))

        return losses

    def loss_smpl_kp2d(self,
                       outputs,
                       targets,
                       indices,
                       idx,
                       num_boxes,
                       data_batch,
                       focal_length=5000.,
                       has_keypoints2d=None,
                       face_hand_kpt=False):
        """Compute loss for 2d keypoints.

        Predicted 3D joints are projected with ``project_points_new`` and
        divided by the image width/height; targets
        (``data_batch['joint_img']``) are divided by ``cfg.output_hm_shape``.
        Hand and face terms are multiplied by zero when ``face_hand_kpt`` is
        false.
        """
        indices = indices[0]
        pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
        pred_cam = outputs['pred_smpl_cam'][idx].float()

        targets_kp2d = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)],
            dim=0)
        keypoints2d_conf = targets_kp2d[:, :, 2:].clone()
        targets_kp2d = targets_kp2d[:, :, :2].float()
        targets_kp2d[:, :, 0] = targets_kp2d[:, :, 0] / cfg.output_hm_shape[2]
        targets_kp2d[:, :, 1] = targets_kp2d[:, :, 1] / cfg.output_hm_shape[1]

        # ``idx[0]`` holds the batch index of every matched query; img_shape
        # is flipped so the last axis is read as (w, h).
        img_wh = torch.cat(
            [data_batch['img_shape'][i][None] for i in idx[0]],
            dim=0).flip(-1)
        pred_smpl_kp2d = project_points_new(points_3d=pred_smpl_kp3d,
                                            pred_cam=pred_cam,
                                            focal_length=focal_length,
                                            camera_center=img_wh / 2)
        pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None]

        body_idx = smpl_x.joint_part['body']
        face_idx = smpl_x.joint_part['face']
        lhand_idx = smpl_x.joint_part['lhand']
        rhand_idx = smpl_x.joint_part['rhand']
        body_kp2d_valid = keypoints2d_conf[:, body_idx].sum([-1, -2]) > 0
        lhand_kp2d_valid = keypoints2d_conf[:, lhand_idx].sum([-1, -2]) > 0
        rhand_kp2d_valid = keypoints2d_conf[:, rhand_idx].sum([-1, -2]) > 0
        face_kp2d_valid = keypoints2d_conf[:, face_idx].sum([-1, -2]) > 0

        loss_smpl_kp2d = F.l1_loss(pred_smpl_kp2d,
                                   targets_kp2d,
                                   reduction='none')
        loss_smpl_kp2d = loss_smpl_kp2d * keypoints2d_conf

        part_scale = 1 if face_hand_kpt else 0
        losses = {}
        losses['loss_smpl_body_kp2d'] = torch.sum(
            loss_smpl_kp2d[:, body_idx, :]) / (body_kp2d_valid.sum() + 1e-6)
        losses['loss_smpl_lhand_kp2d'] = part_scale * torch.sum(
            loss_smpl_kp2d[:, lhand_idx, :]) / (lhand_kp2d_valid.sum() + 1e-6)
        losses['loss_smpl_rhand_kp2d'] = part_scale * torch.sum(
            loss_smpl_kp2d[:, rhand_idx, :]) / (rhand_kp2d_valid.sum() + 1e-6)
        losses['loss_smpl_face_kp2d'] = part_scale * torch.sum(
            loss_smpl_kp2d[:, face_idx, :]) / (face_kp2d_valid.sum() + 1e-6)
        return losses

    def loss_smpl_kp2d_ba(self,
                          outputs,
                          targets,
                          indices,
                          idx,
                          num_boxes,
                          data_batch,
                          focal_length=5000.,
                          has_keypoints2d=None,
                          face_hand_kpt=False):
        """Compute loss for 2d keypoints with per-part box alignment.

        Hand / face joints of both prediction and target are re-expressed in
        the coordinate frame of the ground-truth part box, and the predicted
        part is translated so that its mean offset to the target over
        in-box, confident joints is zero.  Body joints are normalised by
        ``num_boxes``; part joints by the number of valid part boxes.
        """
        device = outputs['pred_logits'].device
        indices = indices[0]
        pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
        pred_cam = outputs['pred_smpl_cam'][idx].float()

        # Count matches over the whole batch and bail out *before* building
        # img_wh: torch.cat on an empty list would raise.
        valid_num = sum(len(tgt) for _, tgt in indices)
        if valid_num == 0:
            return {
                'loss_smpl_body_kp2d_ba':
                pred_smpl_kp3d.sum() * 0 + pred_cam.sum() * 0,
                'loss_smpl_lhand_kp2d_ba':
                pred_smpl_kp3d.sum() * 0 + pred_cam.sum() * 0,
                'loss_smpl_rhand_kp2d_ba':
                pred_smpl_kp3d.sum() * 0 + pred_cam.sum() * 0,
                'loss_smpl_face_kp2d_ba':
                pred_smpl_kp3d.sum() * 0 + pred_cam.sum() * 0,
            }

        targets_kp2d = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)],
            dim=0)
        keypoints2d_conf = targets_kp2d[:, :, 2:].clone()
        keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
        targets_kp2d = targets_kp2d[:, :, :2].float()
        targets_kp2d[:, :, 0] = targets_kp2d[:, :, 0] / cfg.output_hm_shape[2]
        targets_kp2d[:, :, 1] = targets_kp2d[:, :, 1] / cfg.output_hm_shape[1]

        img_wh = torch.cat(
            [data_batch['img_shape'][i][None] for i in idx[0]],
            dim=0).flip(-1)
        pred_smpl_kp2d = project_points_new(points_3d=pred_smpl_kp3d,
                                            pred_cam=pred_cam,
                                            focal_length=focal_length,
                                            camera_center=img_wh / 2)
        pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None]

        # rhand bbox
        rhand_bbox_valid = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['rhand_bbox_valid'], indices)
        ],
                                     dim=0)
        rhand_bbox_gt = torch.cat(
            [t['rhand_boxes'][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        rhand_bbox_gt = (box_ops.box_cxcywh_to_xyxy(rhand_bbox_gt).reshape(
            -1, 2, 2) * img_wh[:, None]).reshape(-1, 4)
        num_rhand_bbox = rhand_bbox_valid.sum()

        # lhand bbox
        lhand_bbox_valid = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['lhand_bbox_valid'], indices)
        ],
                                     dim=0)
        lhand_bbox_gt = torch.cat(
            [t['lhand_boxes'][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        lhand_bbox_gt = (box_ops.box_cxcywh_to_xyxy(lhand_bbox_gt).reshape(
            -1, 2, 2) * img_wh[:, None]).reshape(-1, 4)
        num_lhand_bbox = lhand_bbox_valid.sum()

        # face bbox
        face_bbox_valid = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['face_bbox_valid'], indices)
        ],
                                    dim=0)
        face_bbox_gt = torch.cat(
            [t['face_boxes'][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        face_bbox_gt = (box_ops.box_cxcywh_to_xyxy(face_bbox_gt).reshape(
            -1, 2, 2) * img_wh[:, None]).reshape(-1, 4)
        num_face_bbox = face_bbox_valid.sum()

        img_shape = torch.cat([
            t[None].repeat(len(i), 1)
            for t, (_, i) in zip(data_batch['img_shape'], indices)
        ],
                              dim=0)

        if not (lhand_bbox_valid + rhand_bbox_valid +
                face_bbox_valid == 0).all():
            for part_name, bbox in (('lhand', lhand_bbox_gt),
                                    ('rhand', rhand_bbox_gt),
                                    ('face', face_bbox_gt)):
                part_joints = smpl_x.joint_part[part_name]

                x = targets_kp2d[:, part_joints, 0]
                y = targets_kp2d[:, part_joints, 1]
                trunc = keypoints2d_conf[:, part_joints, 0].clone()

                # bbox: xyxy in pixels, img_shape: (h, w).  Map the
                # normalised coordinates into the part-box frame.
                x -= (bbox[:, None, 0] / img_shape[:, 1:])
                x *= (img_shape[:, 1:] /
                      (bbox[:, None, 2] - bbox[:, None, 0] + 1e-6))
                y -= (bbox[:, None, 1] / img_shape[:, :1])
                y *= (img_shape[:, :1] /
                      (bbox[:, None, 3] - bbox[:, None, 1] + 1e-6))

                # Joints falling outside the box are ignored for alignment.
                trunc *= ((x >= 0) * (x <= 1) * (y >= 0) * (y <= 1))

                coord = torch.stack((x, y), 2)
                targets_kp2d = torch.cat(
                    (targets_kp2d[:, :part_joints[0], :], coord,
                     targets_kp2d[:, part_joints[-1] + 1:, :]), 1)

                x_pred = pred_smpl_kp2d[:, part_joints, 0]
                y_pred = pred_smpl_kp2d[:, part_joints, 1]
                x_pred -= (bbox[:, None, 0] / img_shape[:, 1:])
                x_pred *= (img_shape[:, 1:] /
                           (bbox[:, None, 2] - bbox[:, None, 0] + 1e-6))
                y_pred -= (bbox[:, None, 1] / img_shape[:, :1])
                y_pred *= (img_shape[:, :1] /
                           (bbox[:, None, 3] - bbox[:, None, 1] + 1e-6))
                coord_pred = torch.stack((x_pred, y_pred), 2)

                trans = []
                for bid in range(coord_pred.shape[0]):
                    mask = trunc[bid] == 1
                    if torch.sum(mask) == 0:
                        # Allocate on the prediction's device instead of a
                        # hard-coded ``.cuda()``.
                        trans.append(coord_pred.new_zeros(2))
                    else:
                        trans.append(
                            (-coord_pred[bid, mask, :2] +
                             targets_kp2d[:, part_joints, :][bid, mask, :2]
                             ).mean(0))
                trans = torch.stack(trans)[:, None, :]
                coord_pred = coord_pred + trans  # global translation alignment

                pred_smpl_kp2d = torch.cat(
                    (pred_smpl_kp2d[:, :part_joints[0], :], coord_pred,
                     pred_smpl_kp2d[:, part_joints[-1] + 1:, :]), 1)

        loss_smpl_kp2d_ba = F.l1_loss(pred_smpl_kp2d,
                                      targets_kp2d[:, :, :2],
                                      reduction='none')

        valid_pos = keypoints2d_conf > 0
        if keypoints2d_conf[valid_pos].numel() == 0:
            return {
                'loss_smpl_body_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
                'loss_smpl_lhand_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
                'loss_smpl_rhand_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
                'loss_smpl_face_kp2d_ba': loss_smpl_kp2d_ba.sum() * 0,
            }

        # TODO (translated from the original note "needs to be changed"):
        # normalising by the number of valid positions was left disabled.
        loss_smpl_kp2d_ba = loss_smpl_kp2d_ba * keypoints2d_conf

        losses = {}
        losses['loss_smpl_body_kp2d_ba'] = torch.sum(
            loss_smpl_kp2d_ba[:, smpl_x.joint_part['body'], :]) / num_boxes

        if face_hand_kpt:
            if num_lhand_bbox > 0:
                losses['loss_smpl_lhand_kp2d_ba'] = torch.sum(
                    loss_smpl_kp2d_ba[:, smpl_x.joint_part['lhand'], :]
                ) / num_lhand_bbox
            else:
                losses['loss_smpl_lhand_kp2d_ba'] = torch.as_tensor(
                    0., device=device) + loss_smpl_kp2d_ba.sum() * 0
            if num_rhand_bbox > 0:
                losses['loss_smpl_rhand_kp2d_ba'] = torch.sum(
                    loss_smpl_kp2d_ba[:, smpl_x.joint_part['rhand'], :]
                ) / num_rhand_bbox
            else:
                losses['loss_smpl_rhand_kp2d_ba'] = torch.as_tensor(
                    0., device=device) + loss_smpl_kp2d_ba.sum() * 0
            if num_face_bbox > 0:
                losses['loss_smpl_face_kp2d_ba'] = torch.sum(
                    loss_smpl_kp2d_ba[:, smpl_x.joint_part['face'], :]
                ) / num_face_bbox
            else:
                losses['loss_smpl_face_kp2d_ba'] = torch.as_tensor(
                    0., device=device) + loss_smpl_kp2d_ba.sum() * 0
        else:
            # Part supervision disabled: emit plain zeros.  The previous
            # ``0 * sum / num_*_bbox`` produced NaN whenever a part had no
            # valid box (0 / 0).
            losses['loss_smpl_lhand_kp2d_ba'] = loss_smpl_kp2d_ba.sum() * 0
            losses['loss_smpl_rhand_kp2d_ba'] = loss_smpl_kp2d_ba.sum() * 0
            losses['loss_smpl_face_kp2d_ba'] = loss_smpl_kp2d_ba.sum() * 0

        return losses

    def _part_box_losses(self, src_boxes, targets, indices, data_batch,
                         target_key, valid_key, bbox_loss_name,
                         giou_loss_name):
        """L1 + GIoU box losses for one part (hand / face), normalised by the
        number of instances whose ``data_batch[valid_key]`` flag is > 0.

        Returns zero-valued losses (still attached to the graph) when no
        instance has a valid part box.
        """
        target_boxes = torch.cat(
            [t[target_key][i] for t, (_, i) in zip(targets, indices)], dim=0)
        boxes_conf = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch[valid_key], indices)],
            dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        loss_bbox = loss_bbox * boxes_conf[:, None]
        num_part_boxes = (boxes_conf > 0).sum()

        # The diagonal pairs each prediction with its own matched target.
        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes)))
        loss_giou = loss_giou * boxes_conf

        if num_part_boxes > 0:
            return {
                bbox_loss_name: loss_bbox.sum() / num_part_boxes,
                giou_loss_name: loss_giou.sum() / num_part_boxes,
            }
        return {
            bbox_loss_name: loss_bbox.sum() * 0,
            giou_loss_name: loss_giou.sum() * 0,
        }

    def loss_boxes(self,
                   outputs,
                   targets,
                   indices,
                   idx,
                   num_boxes,
                   data_batch,
                   face_hand_box=False):
        """Compute the losses related to the bounding boxes, the L1
        regression loss and the GIoU loss.

        Target dicts must contain the key "boxes" containing a tensor of dim
        [nb_target_boxes, 4].  The target boxes are expected in format
        (center_x, center_y, w, h), normalized by the image size.  Hand and
        face box losses are added only when ``face_hand_box`` is true and the
        matching ``pred_*_boxes`` entry exists in ``outputs``.
        """
        indices = indices[0]
        assert 'pred_boxes' in outputs
        src_body_boxes = outputs['pred_boxes'][idx]
        target_body_boxes = torch.cat(
            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        target_body_boxes_conf = torch.cat([
            t[i]
            for t, (_, i) in zip(data_batch['body_bbox_valid'], indices)
        ],
                                           dim=0)

        losses = {}
        loss_body_bbox = F.l1_loss(src_body_boxes,
                                   target_body_boxes,
                                   reduction='none')
        loss_body_bbox = loss_body_bbox * target_body_boxes_conf[:, None]
        losses['loss_body_bbox'] = loss_body_bbox.sum() / num_boxes

        loss_body_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_body_boxes),
                box_ops.box_cxcywh_to_xyxy(target_body_boxes)))
        loss_body_giou = loss_body_giou * target_body_boxes_conf
        losses['loss_body_giou'] = loss_body_giou.sum() / num_boxes

        if 'pred_lhand_boxes' in outputs and face_hand_box:
            losses.update(
                self._part_box_losses(outputs['pred_lhand_boxes'][idx],
                                      targets, indices, data_batch,
                                      'lhand_boxes', 'lhand_bbox_valid',
                                      'loss_lhand_bbox', 'loss_lhand_giou'))

        if 'pred_rhand_boxes' in outputs and face_hand_box:
            losses.update(
                self._part_box_losses(outputs['pred_rhand_boxes'][idx],
                                      targets, indices, data_batch,
                                      'rhand_boxes', 'rhand_bbox_valid',
                                      'loss_rhand_bbox', 'loss_rhand_giou'))

        if 'pred_face_boxes' in outputs and face_hand_box:
            losses.update(
                self._part_box_losses(outputs['pred_face_boxes'][idx],
                                      targets, indices, data_batch,
                                      'face_boxes', 'face_bbox_valid',
                                      'loss_face_bbox', 'loss_face_giou'))

        return losses

    def loss_dn_boxes(self, outputs, targets, indices, idx, num_boxes,
                      data_batch):
        """
        Input:
            - src_boxes: bs, num_dn, 4
            - tgt_boxes: bs, num_dn, 4
        """
        # Check for the denoising outputs *before* reading them; the old
        # code indexed ``outputs['num_tgt']`` first and raised KeyError.
        if 'num_tgt' not in outputs:
            return {
                'dn_loss_bbox': outputs['pred_logits'].sum() * 0,
                'dn_loss_giou': outputs['pred_logits'].sum() * 0,
            }

        num_tgt = outputs['num_tgt']
        src_boxes = outputs['dn_bbox_pred']
        tgt_boxes = outputs['dn_bbox_input']
        return self.tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt)

    def loss_dn_labels(self, outputs, targets, indices, idx, num_boxes,
                       data_batch):
        """
        Input:
            - src_logits: bs, num_dn, num_classes
            - tgt_labels: bs, num_dn
        """
        if 'num_tgt' not in outputs:
            return {'dn_loss_ce': outputs['pred_logits'].sum() * 0}

        num_tgt = outputs['num_tgt']
        src_logits = outputs['dn_class_pred']  # bs, num_dn, text_len
        tgt_labels = outputs['dn_class_input']
        return self.tgt_loss_labels(src_logits, tgt_labels, num_tgt)

    @torch.no_grad()
    def loss_matching_cost(self, outputs, targets, indices, idx, num_boxes,
                           data_batch):
        """Expose the entries of ``indices[1]`` under ``set_<name>`` keys.

        No gradients are propagated.
        """
        cost_mean_dict = indices[1]
        return {'set_{}'.format(k): v for k, v in cost_mean_dict.items()}

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        # NOTE(review): the visible chunk ends in the middle of this return
        # statement; the returned expression lies past the end of the chunk.
        return
batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat( [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, data_batch, indices, num_boxes, **kwargs): loss_map = { 'smpl_pose': self.loss_smpl_pose, 'smpl_beta': self.loss_smpl_beta, 'smpl_expr': self.loss_smpl_expr, 'smpl_kp2d': self.loss_smpl_kp2d, 'smpl_kp2d_ba': self.loss_smpl_kp2d_ba, 'smpl_kp3d_ra': self.loss_smpl_kp3d_ra, 'smpl_kp3d': self.loss_smpl_kp3d, 'labels': self.loss_labels, 'cardinality': self.loss_cardinality, 'keypoints': self.loss_keypoints, 'boxes': self.loss_boxes, 'dn_label': self.loss_dn_labels, 'dn_bbox': self.loss_dn_boxes, 'matching': self.loss_matching_cost, } idx = self._get_src_permutation_idx(indices[0]) # pdb.set_trace() assert loss in loss_map, f'do you really want to compute {loss} loss?' return loss_map[loss](outputs, targets, indices, idx, num_boxes, data_batch, **kwargs) def prep_for_dn2(self, mask_dict): known_bboxs = mask_dict['known_bboxs'] known_labels = mask_dict['known_labels'] output_known_coord = mask_dict['output_known_coord'] output_known_class = mask_dict['output_known_class'] num_tgt = mask_dict['pad_size'] return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt ## SMPL losses def forward(self, outputs, targets, data_batch, return_indices=False): """ This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc return_indices: used for vis. if True, the layer0-5 indices will be returned as well. 
""" # import pdb; pdb.set_trace() outputs_without_aux = { k: v for k, v in outputs.items() if k != 'aux_outputs' } device = next(iter(outputs.values())).device # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes = sum(len(t['boxes']) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_boxes) num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() # loss for final layer # pdb.set_trace() indices = self.matcher(outputs_without_aux, targets, data_batch) if return_indices: indices0_copy = indices indices_list = [] losses = {} smpl_loss = ['smpl_pose', 'smpl_beta', 'smpl_expr', 'smpl_kp2d', 'smpl_kp2d_ba', 'smpl_kp3d', 'smpl_kp3d_ra'] # import pdb; pdb.set_trace() for loss in self.losses: # print(loss) # print(self.get_loss(loss, outputs, targets, indices, num_boxes)) kwargs = {} if loss == 'keypoints' or loss in smpl_loss: kwargs.update({'face_hand_kpt': True}) if loss == 'boxes': kwargs.update({'face_hand_box': True}) losses.update( self.get_loss( loss, outputs, targets, data_batch, indices, num_boxes, **kwargs )) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
if 'aux_outputs' in outputs: for idx, aux_outputs in enumerate(outputs['aux_outputs']): indices = self.matcher(aux_outputs, targets, data_batch) if return_indices: indices_list.append(indices) for loss in self.losses: kwargs = {} if loss == 'boxes': kwargs.update({'face_hand_box': False}) if idx >= self.num_box_decoder_layers: kwargs.update({'face_hand_box': True}) if loss == 'masks': continue if loss == 'keypoints': if idx < self.num_box_decoder_layers: continue elif idx < self.num_hand_face_decoder_layers: kwargs.update({'face_hand_kpt': False}) else: kwargs.update({'face_hand_kpt': True}) if loss in smpl_loss: if idx < self.num_box_decoder_layers: continue elif idx < self.num_hand_face_decoder_layers: kwargs.update({'face_hand_kpt': False}) else: kwargs.update({'face_hand_kpt': True}) if loss == 'labels': # Logging is enabled only for the last layer kwargs = {'log': False} # if loss == 'smpl_expr' and idx < self.num_box_decoder_layers: # continue # import pdb;pdb.set_trace() l_dict = self.get_loss(loss, aux_outputs, targets, data_batch, indices, num_boxes, **kwargs) l_dict = {k + f'_{idx}': v for k, v in l_dict.items()} losses.update(l_dict) # interm_outputs loss if 'interm_outputs' in outputs: interm_outputs = outputs['interm_outputs'] indices = self.matcher(interm_outputs, targets) if return_indices: indices_list.append(indices) for loss in self.losses: if loss in ['dn_bbox', 'dn_label', 'keypoints']: continue if loss in [ 'smpl_pose', 'smpl_beta', 'smpl_kp2d_ba', 'smpl_kp2d', 'smpl_kp3d_ra', 'smpl_kp3d', 'smpl_expr' ]: continue kwargs = {} if loss == 'labels': kwargs = {'log': False} l_dict = self.get_loss(loss, interm_outputs, targets, data_batch, indices, num_boxes, **kwargs) l_dict = {k + f'_interm': v for k, v in l_dict.items()} losses.update(l_dict) # aux_init loss if 'query_expand' in outputs: interm_outputs = outputs['query_expand'] indices = self.matcher(interm_outputs, targets) if return_indices: indices_list.append(indices) for loss in self.losses: 
if loss in ['dn_bbox', 'dn_label']: continue kwargs = {} if loss == 'labels': kwargs = {'log': False} l_dict = self.get_loss(loss, interm_outputs, targets, data_batch, indices, num_boxes, **kwargs) l_dict = {k + f'_query_expand': v for k, v in l_dict.items()} losses.update(l_dict) if return_indices: indices_list.append(indices0_copy) return losses, indices_list return losses def tgt_loss_boxes( self, src_boxes, tgt_boxes, num_tgt, ): """ Input: - src_boxes: bs, num_dn, 4 - tgt_boxes: bs, num_dn, 4 """ loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none') losses = {} losses['dn_loss_bbox'] = loss_bbox.sum() / num_tgt loss_giou = 1 - torch.diag( box_ops.generalized_box_iou( box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1)), box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1)))) losses['dn_loss_giou'] = loss_giou.sum() / num_tgt return losses def tgt_loss_labels(self, src_logits: Tensor, tgt_labels: Tensor, num_tgt: int, log: bool = True): """ Input: - src_logits: bs, num_dn, num_classes - tgt_labels: bs, num_dn """ target_classes_onehot = torch.zeros([ src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1 ], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1) target_classes_onehot = target_classes_onehot[:, :, :-1] loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_tgt, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'dn_loss_ce': loss_ce} return losses class SetCriterion_Box(nn.Module): def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses, num_box_decoder_layers=2, num_hand_face_decoder_layers=4, num_body_points=17, num_hand_points=6, num_face_points=6, smpl_loss_config=None, convention='smplx_137'): super().__init__() self.num_classes = num_classes self.matcher = matcher self.weight_dict = weight_dict self.losses = losses self.focal_alpha = focal_alpha self.vis = 0.1 self.abs = 1 self.num_body_points = 0 
self.num_hand_points = 0 self.num_face_points = 0 self.num_box_decoder_layers = num_box_decoder_layers self.num_hand_face_decoder_layers = num_hand_face_decoder_layers self.convention = convention def loss_labels(self, outputs, targets, indices, idx, num_boxes, data_batch, log=True): """Classification loss (Binary focal loss) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]""" indices = indices[0] valid_num = 0 for indice in indices[0]: valid_num+=len(indice) assert 'pred_logits' in outputs src_logits = outputs['pred_logits'] target_classes_o = torch.cat( [t['labels'][J] for t, (_, J) in zip(targets, indices)]) target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device) if valid_num == 0: return {'loss_ce': src_logits.sum()*0} target_classes[idx] = target_classes_o target_classes_onehot = torch.zeros([ src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1 ], dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) target_classes_onehot = target_classes_onehot[:, :, :-1] loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] losses = {'loss_ce': loss_ce} if log: # TODO this should probably be a separate loss, not hacked in this one here losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] return losses @torch.no_grad() def loss_cardinality(self, outputs, targets, indices, num_boxes, data_batch): """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes This is not really a loss, it is intended for logging purposes only. 
It doesn't propagate gradients """ pred_logits = outputs['pred_logits'] device = pred_logits.device tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets], device=device) if tgt_lengths == 0: return {'cardinality_error': pred_logits.sum()*0} # Count the number of predictions that are NOT "no-object" (which is the last class) card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) losses = {'cardinality_error': card_err} return losses def loss_smpl_pose(self, outputs, targets, indices, idx, num_boxes, data_batch, face_hand_kpt=False): indices = indices[0] device = outputs['pred_logits'].device # import pdb # pdb.set_trace() valid_num=0 for indice in indices[0]: valid_num+=len(indice) pred_smpl_body_pose = outputs['pred_smpl_pose'][idx] # 22 pred_smpl_lhand_pose = outputs['pred_smpl_lhand_pose'][idx] # 15 pred_smpl_rhand_pose = outputs['pred_smpl_rhand_pose'][idx] # 15 pred_smpl_jaw_pose = outputs['pred_smpl_jaw_pose'][idx] pred_smplx_pose = torch.cat((pred_smpl_body_pose, pred_smpl_lhand_pose, pred_smpl_rhand_pose, pred_smpl_jaw_pose), dim=1) targets_smpl_pose = torch.cat( [t[i] for t, (_, i) in zip(data_batch['smplx_pose'], indices)], dim=0) targets_smpl_pose = batch_rodrigues(targets_smpl_pose.view( -1, 3)).view(-1, 53, 3, 3) conf = torch.cat([ t[i] for t, (_, i) in zip(data_batch['smplx_pose_valid'], indices) ], dim=0) conf = (conf.reshape(-1,53,3)[:,:,:,None]).repeat(1,1,1,3) losses = {} if valid_num == 0: losses['loss_smpl_pose_root'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_body'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_lhand'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_rhand'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_jaw'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 
return losses # valid_pos = conf > 0 if conf.sum() == 0: losses['loss_smpl_pose_root'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_body'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_lhand'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_rhand'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 losses['loss_smpl_pose_jaw'] = torch.as_tensor(0., device=device) + pred_smplx_pose.sum() * 0 return losses loss_smpl_pose = \ F.l1_loss( pred_smplx_pose, targets_smpl_pose, reduction='none' ) # pdb.set_trace() loss_smpl_pose = loss_smpl_pose * conf loss_smpl_pose = loss_smpl_pose.sum([-1,-2]) # loss_smpl_pose[:,0] = loss_smpl_pose[:,0]*5 if face_hand_kpt: losses = { 'loss_smpl_pose_root': loss_smpl_pose[:, 0].sum() / num_boxes, 'loss_smpl_pose_body': loss_smpl_pose[:, 1:22].sum() / num_boxes, 'loss_smpl_pose_lhand': loss_smpl_pose[:, 22:37].sum() / num_boxes, 'loss_smpl_pose_rhand': loss_smpl_pose[:, 37:52].sum() / num_boxes, 'loss_smpl_pose_jaw': loss_smpl_pose[:, 52].sum() / num_boxes, } else: losses = { 'loss_smpl_pose_root': loss_smpl_pose[:, 0].sum() / num_boxes, 'loss_smpl_pose_body': loss_smpl_pose[:, 1:22].sum() / num_boxes, 'loss_smpl_pose_lhand': 0 * loss_smpl_pose[:, 22:37].sum() / num_boxes, 'loss_smpl_pose_rhand': 0 * loss_smpl_pose[:, 37:52].sum() / num_boxes, 'loss_smpl_pose_jaw': loss_smpl_pose[:, 52].sum() / num_boxes, } # losses = {'loss_smpl_pose': loss_smpl_pose.sum() / num_boxes} return losses def loss_smpl_beta(self, outputs, targets, indices, idx, num_boxes, data_batch, face_hand_kpt=False): indices = indices[0] device = outputs['pred_logits'].device # import pdb # pdb.set_trace() pred_smpl_betas = outputs['pred_smpl_beta'][idx] targets_smpl_betas = torch.cat( [t[i] for t, (_, i) in zip(data_batch['smplx_shape'], indices)], dim=0) # import pdb # pdb.set_trace() valid_num=0 for indice in indices[0]: 
valid_num+=len(indice) losses = {} if valid_num == 0: losses['loss_smpl_beta'] = torch.as_tensor(0., device=device) + pred_smpl_betas.sum() * 0 return losses conf = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_shape_valid'], indices)], dim=0) # valid_pos = conf > 0 if conf.sum() == 0: return { 'loss_smpl_beta': torch.as_tensor(0., device=device) + pred_smpl_betas.sum() * 0 } loss_smpl_betas = \ F.l1_loss( pred_smpl_betas, targets_smpl_betas, reduction='none' ) # pdb.set_trace() loss_smpl_betas = loss_smpl_betas.sum(-1) * conf losses = {'loss_smpl_beta': loss_smpl_betas.sum() / num_boxes} return losses def loss_smpl_expr(self, outputs, targets, indices, idx, num_boxes, data_batch, face_hand_kpt=False): indices = indices[0] device = outputs['pred_logits'].device pred_smpl_expr = outputs['pred_smpl_expr'][idx] # import pdb # pdb.set_trace() targets_smpl_expr = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_expr'], indices)], dim=0) valid_num=0 for indice in indices[0]: valid_num+=len(indice) losses = {} if valid_num == 0: losses['loss_smpl_expr'] = torch.as_tensor(0., device=device) + pred_smpl_expr.sum() * 0 return losses conf = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_expr_valid'], indices)], dim=0) # valid_pos = conf > 0 if conf.sum() == 0: return { 'loss_smpl_expr': torch.as_tensor(0., device=device) + pred_smpl_expr.sum() * 0 } loss_smpl_expr = \ F.l1_loss( pred_smpl_expr, targets_smpl_expr, reduction='none' ) # pdb.set_trace() loss_smpl_expr = loss_smpl_expr.sum(-1) * conf if face_hand_kpt: losses = {'loss_smpl_expr': loss_smpl_expr.sum() / (conf.sum() + 1e-6)} else: losses = {'loss_smpl_expr': 0*loss_smpl_expr.sum() / (conf.sum() + 1e-6) } return losses def loss_smpl_kp3d(self, outputs, targets, indices, idx, num_boxes, data_batch, has_keypoints3d=None, face_hand_kpt=False): # supervision for keypoints3d wo/ ra device = outputs['pred_logits'].device indices = indices[0] valid_num=0 for indice in indices[0]: valid_num+=len(indice) 
pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() # meta_info['joint_valid'] * meta_info['is_3D'][:, None, None]) targets_smpl_kp3d = torch.cat( [t[i] for t, (_, i) in zip(data_batch['joint_cam'], indices)], dim=0) losses = {} if valid_num == 0: losses['loss_smpl_body_kp3d'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_lhand_kp3d'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_rhand_kp3d'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_face_kp3d'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 return losses targets_kp3d_conf = targets_smpl_kp3d[:,:,3:].clone() targets_smpl_kp3d = targets_smpl_kp3d[:,:,:3] targets_is_3d = torch.cat([ t[None, None].repeat(len(i), 1, 1) for t, (_, i) in zip(data_batch['is_3D'], indices) ], dim=0) targets_kp3d_conf = (targets_kp3d_conf * targets_is_3d).repeat(1, 1, 3) pelvis_idx = get_keypoint_idx('pelvis', self.convention) targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :] pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :] targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :] pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :] losses = {} body_idx = smpl_x.joint_part['body'] face_idx = smpl_x.joint_part['face'] lhand_idx = smpl_x.joint_part['lhand'] rhand_idx = smpl_x.joint_part['rhand'] # currently, only mpi_inf_3dhp and h36m have 3d keypoints # both datasets have right_hip_extra and left_hip_extra loss_smpl_kp3d = F.l1_loss(pred_smpl_kp3d, targets_smpl_kp3d, reduction='none') # If has_keypoints3d is not None, then computes the losses on the # instances that have ground-truth keypoints3d. # But the zero confidence keypoints will be included in mean. # Otherwise, only compute the keypoints3d # which have positive confidence. 
# has_keypoints3d is None when the key has_keypoints3d # is not in the datasets valid_pos = targets_kp3d_conf > 0 if targets_kp3d_conf[valid_pos].numel() == 0: return { 'loss_smpl_body_kp3d': torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0, 'loss_smpl_lhand_kp3d': torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0, 'loss_smpl_rhand_kp3d': torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0, 'loss_smpl_face_kp3d': torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0, } loss_smpl_kp3d = loss_smpl_kp3d * targets_kp3d_conf if face_hand_kpt: losses['loss_smpl_body_kp3d'] = torch.sum(loss_smpl_kp3d[:, body_idx, :]) / num_boxes losses['loss_smpl_lhand_kp3d'] = torch.sum(loss_smpl_kp3d[:, lhand_idx, :]) / num_boxes losses['loss_smpl_rhand_kp3d'] = torch.sum(loss_smpl_kp3d[:, rhand_idx, :]) / num_boxes losses['loss_smpl_face_kp3d'] = torch.sum(loss_smpl_kp3d[:, face_idx, :]) / num_boxes else: losses['loss_smpl_body_kp3d'] = torch.sum(loss_smpl_kp3d[:, body_idx, :]) / num_boxes losses['loss_smpl_lhand_kp3d'] = 0*torch.sum(loss_smpl_kp3d[:, lhand_idx, :]) / num_boxes losses['loss_smpl_rhand_kp3d'] = 0*torch.sum(loss_smpl_kp3d[:, rhand_idx, :]) /num_boxes losses['loss_smpl_face_kp3d'] = 0*torch.sum(loss_smpl_kp3d[:, face_idx, :]) / num_boxes return losses def loss_smpl_kp3d_ra(self, outputs, targets, indices, idx, num_boxes, data_batch, has_keypoints3d=None, face_hand_kpt=False): # supervision for keypoints3d w/ ra device = outputs['pred_logits'].device indices = indices[0] valid_num=0 for indice in indices[0]: valid_num+=len(indice) pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() # meta_info['joint_valid'] * meta_info['is_3D'][:, None, None]) targets_smpl_kp3d = torch.cat([ t[i] for t, (_, i) in zip(data_batch['smplx_joint_cam'], indices) ], dim=0) losses = {} if valid_num == 0: losses['loss_smpl_rhand_kp3d_ra'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_body_kp3d_ra'] = 
torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_face_kp3d_ra'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 losses['loss_smpl_lhand_kp3d_ra'] = torch.as_tensor(0., device=device) + pred_smpl_kp3d.sum() * 0 return losses targets_kp3d_conf = targets_smpl_kp3d[:,:,3:].clone() targets_smpl_kp3d = targets_smpl_kp3d[:,:,:3] targets_is_3d = torch.cat([ t[None, None].repeat(len(i), 1, 1) for t, (_, i) in zip(data_batch['is_3D'], indices) ], dim=0) targets_kp3d_conf = (targets_kp3d_conf * targets_is_3d).repeat(1, 1, 3) targets_smpl_kp3d = targets_smpl_kp3d[..., :3].float() pelvis_idx = get_keypoint_idx('pelvis', self.convention) targets_pelvis = targets_smpl_kp3d[..., pelvis_idx, :] pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :] targets_smpl_kp3d = targets_smpl_kp3d - targets_pelvis[:, None, :] pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :] # calculate body, face and hand loss separately: losses = {} body_idx = smpl_x.joint_part['body'] face_idx = smpl_x.joint_part['face'] lhand_idx = smpl_x.joint_part['lhand'] rhand_idx = smpl_x.joint_part['rhand'] loss_smpl_body_kp3d = F.l1_loss(pred_smpl_kp3d[:, body_idx, :], targets_smpl_kp3d[:, body_idx, :], reduction='none') loss_smpl_body_kp3d = torch.sum( loss_smpl_body_kp3d * targets_kp3d_conf[:, body_idx, :]) losses['loss_smpl_body_kp3d_ra'] = loss_smpl_body_kp3d / num_boxes # if face_hand_kpt: face_cam = pred_smpl_kp3d[:, face_idx, :] neck_cam = pred_smpl_kp3d[:, smpl_x.neck_idx, None, :] face_cam = face_cam - neck_cam loss_smpl_face_kp3d = F.l1_loss(face_cam, targets_smpl_kp3d[:, face_idx, :], reduction='none') loss_smpl_face_kp3d = torch.sum( loss_smpl_face_kp3d * targets_kp3d_conf[:, face_idx, :]) if face_hand_kpt: losses['loss_smpl_face_kp3d_ra'] = (loss_smpl_face_kp3d / num_boxes) else: losses['loss_smpl_face_kp3d_ra'] = 0*(loss_smpl_face_kp3d / num_boxes) lhand_cam = pred_smpl_kp3d[:, lhand_idx, :] lwrist_cam = pred_smpl_kp3d[:, smpl_x.lwrist_idx, None, :] 
lhand_cam = lhand_cam - lwrist_cam loss_smpl_lhand_kp3d = F.l1_loss(lhand_cam, targets_smpl_kp3d[:, lhand_idx, :], reduction='none') loss_smpl_lhand_kp3d = torch.sum( loss_smpl_lhand_kp3d * targets_kp3d_conf[:, lhand_idx, :]) if face_hand_kpt: losses['loss_smpl_lhand_kp3d_ra'] = (loss_smpl_lhand_kp3d / num_boxes) else: losses['loss_smpl_lhand_kp3d_ra'] = 0*(loss_smpl_lhand_kp3d /num_boxes) rhand_cam = pred_smpl_kp3d[:, rhand_idx, :] rwrist_cam = pred_smpl_kp3d[:, smpl_x.rwrist_idx, None, :] rhand_cam = rhand_cam - rwrist_cam loss_smpl_rhand_kp3d = F.l1_loss(rhand_cam, targets_smpl_kp3d[:, rhand_idx, :], reduction='none') loss_smpl_rhand_kp3d = torch.sum( loss_smpl_rhand_kp3d * targets_kp3d_conf[:, rhand_idx, :]) if face_hand_kpt: losses['loss_smpl_rhand_kp3d_ra'] = (loss_smpl_rhand_kp3d / num_boxes) else: losses['loss_smpl_rhand_kp3d_ra'] = 0*(loss_smpl_rhand_kp3d / num_boxes) return losses def loss_smpl_kp2d(self, outputs, targets, indices, idx, num_boxes, data_batch, focal_length=5000., has_keypoints2d=None, face_hand_kpt=False): """Compute loss for 2d keypoints.""" device = outputs['pred_logits'].device indices = indices[0] valid_num=0 for indice in indices[0]: valid_num+=len(indice) # pdb.set_trace() pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()#.detach() # pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() # pelvis_idx = get_keypoint_idx('pelvis', self.convention) # pred_pelvis = pred_smpl_kp3d[..., pelvis_idx, :] # pred_smpl_kp3d = pred_smpl_kp3d - pred_pelvis[:, None, :] +1e-7 pred_cam = outputs['pred_smpl_cam'][idx].float() targets_kp2d = torch.cat([t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)], dim=0) keypoints2d_conf = targets_kp2d[:,:,2:].clone() targets_kp2d = targets_kp2d[:,:,:2] target_lhand_boxes_conf = torch.cat( [t[i] for t, (_, i) in zip(data_batch['lhand_bbox_valid'], indices)], dim=0) lhand_num_boxes = target_lhand_boxes_conf.sum() target_rhand_boxes_conf = torch.cat( [t[i] for t, (_, i) in 
zip(data_batch['rhand_bbox_valid'], indices)], dim=0) rhand_num_boxes = target_rhand_boxes_conf.sum() target_face_boxes_conf = torch.cat( [t[i] for t, (_, i) in zip(data_batch['face_bbox_valid'], indices)], dim=0) face_num_boxes = target_face_boxes_conf.sum() # t_pose = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_pose'], indices)], dim=0) # t_shape = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_shape'], indices)], dim=0) # t_expr = torch.cat([t[i] for t, (_, i) in zip(data_batch['smplx_expr'], indices)], dim=0) keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2) targets_kp2d = targets_kp2d[:, :, :2].float() targets_kp2d[:,:,0] = targets_kp2d[:,:,0]/cfg.output_hm_shape[2] targets_kp2d[:,:,1] = targets_kp2d[:,:,1]/cfg.output_hm_shape[1] # targets_kp2d = targets_kp2d*2-1 img_wh = torch.cat([data_batch['img_shape'][i][None] for i in idx[0]], dim=0).flip(-1) # pred_smpl_kp2d = weak_perspective_projection(pred_smpl_kp3d, scale=pred_cam[:, 0], translation=pred_cam[:, 1:3]) # If kp2ds is normalized to [-1, 1], the center should be the center of the image; # if normalized to 0-1, it should be at the top left corner (0, 0)? 
pred_smpl_kp2d = project_points_new( points_3d=pred_smpl_kp3d, pred_cam=pred_cam, focal_length=focal_length, camera_center=img_wh/2 ) pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None] vis=False # if 'vis' in cfg: # vis=cfg['vis'] # vis = True if vis: import mmcv import cv2 import numpy as np from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl from detrsmpl.models.body_models.builder import build_body_model from pytorch3d.io import save_obj from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d img = mmcv.imdenormalize( img=(data_batch['img'][0].cpu().numpy()).transpose(1, 2, 0), mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_bgr=True).astype(np.uint8) cv2.imwrite('test.png', img) device = outputs['pred_smpl_kp3d'].device body_model = dict( type='smplx', keypoint_src='smplx', num_expression_coeffs=10, num_betas=10, keypoint_dst='smplx_137', model_path='data/body_models/smplx', use_pca=False, use_face_contour=True) bm = build_body_model(body_model).to(device) pred_smpl_body_pose = rotmat_to_aa(outputs['pred_smpl_pose'][idx]) pred_smpl_lhand_pose = rotmat_to_aa(outputs['pred_smpl_lhand_pose'][idx]) pred_smpl_rhand_pose = rotmat_to_aa(outputs['pred_smpl_rhand_pose'][idx]) pred_smpl_jaw_pose = rotmat_to_aa(outputs['pred_smpl_jaw_pose'][idx]) pred_smpl_shape = outputs['pred_smpl_beta'][idx] pred_output = bm( betas=pred_smpl_shape.reshape(-1, 10), body_pose=pred_smpl_body_pose[:,1:].reshape(-1, 21*3), global_orient=pred_smpl_body_pose[:,:1].reshape(-1, 3), left_hand_pose=pred_smpl_lhand_pose.reshape(-1, 15*3), right_hand_pose=pred_smpl_rhand_pose.reshape(-1, 15*3), leye_pose=torch.zeros_like(pred_smpl_jaw_pose).reshape(-1, 3), reye_pose=torch.zeros_like(pred_smpl_jaw_pose).reshape(-1, 3), expression=torch.zeros_like(pred_smpl_shape).reshape(-1, 10), jaw_pose=pred_smpl_jaw_pose.reshape(-1, 3)) verts = 
pred_output['vertices'] # for i_obj,v in enumerate(verts): # save_obj('./figs/pred_smpl_%d.obj'%i_obj,verts = v,faces=torch.tensor([])) pred_cam = outputs['pred_smpl_cam'][idx] targets_smpl_pose = data_batch['smplx_pose'][0] targets_shape = data_batch['smplx_shape'][0] gt_kp3d = data_batch['joint_cam'][0] gt_kp2d = data_batch['joint_img'][0] gt_body_boxes = torch.cat( [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) # gt kp3d pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float() visualize_kp3d(gt_kp3d.detach().cpu().numpy(), output_path='./figs/gt3d', data_source='smplx_137') # visualize_kp3d(pred_smpl_kp3d.detach().cpu().numpy(), # output_path='./figs/pred3d', # data_source='smplx_137') # gt kp2d img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() gt_2d= gt_kp2d.detach().cpu().numpy()[...,:2]*data_batch['img_shape'].cpu().numpy()[0,None,None,::-1] gt_2d[...,0] = gt_2d[...,0]/12 gt_2d[...,1] = gt_2d[...,1]/16 import mmcv batch_id = 0 gt_bbox = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) gt_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['lhand_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) gt_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['rhand_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) gt_bbox_face = (box_ops.box_cxcywh_to_xyxy(targets[batch_id]['face_boxes']).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[batch_id, ::-1]).reshape(-1,4) gt_bbox = np.concatenate([gt_bbox,gt_bbox_face,gt_bbox_rhand,gt_bbox_lhand],axis=0) # gt_bbox = (box_ops.box_cxcywh_to_xyxy(gt_body_boxes).reshape(-1,2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1][None,None,:]).reshape(-1,4) img = mmcv.imshow_bboxes(img.copy(), gt_bbox, 
show=False) gt_2d = data_batch['joint_img'][0][:,:,:2].cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0,None,None,::-1]# *data_batch['joint_img'][0][:,:,2:].cpu().numpy() gt_2d[...,0] = gt_2d[...,0]/12 gt_2d[...,1] = gt_2d[...,1]/16 # data_batch['joint_img'] # gt_kp2d = gt_2d[0][keypoints2d_conf[0]!=0] visualize_kp2d( (gt_2d).reshape(-1,2)[None], output_path='./figs/gt2d', image_array=img.copy()[None], # data_source='smplx_137', disable_limbs = True, overwrite=True) img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() # pred_smpl_kp2d = project_points_new( # points_3d=outputs['pred_smpl_kp3d'][:,:2].reshape(-1,137,3), # pred_cam=pred_cam, # focal_length=focal_length, # camera_center=img_wh/2 # ) img_shape = data_batch['img_shape'][0] # pred_kp2d = pred_kp2d.cpu().detach().numpy()*img_shape.cpu().numpy()[None,None ::-1] # pred_bbox_all = [] # for i in idx[0]: # pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(outputs['pred_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) # pred_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_lhand_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) # pred_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_rhand_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) # pred_bbox_face = (box_ops.box_cxcywh_to_xyxy(outputs['pred_face_boxes'][0,i]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) # pred_bbox = np.concatenate([pred_bbox_body,pred_bbox_face,pred_bbox_rhand,pred_bbox_lhand],axis=0) # pred_bbox_all.append(pred_bbox) # src_body_boxes = outputs['pred_boxes'][idx] # pred_bbox_all = np.concatenate(pred_bbox_all,axis=0) pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(outputs['pred_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, 
::-1]).reshape(-1,4) pred_bbox_lhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_lhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) pred_bbox_rhand = (box_ops.box_cxcywh_to_xyxy(outputs['pred_rhand_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) pred_bbox_face = (box_ops.box_cxcywh_to_xyxy(outputs['pred_face_boxes'][idx]).reshape(-1,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[1, ::-1]).reshape(-1,4) pred_bbox = np.concatenate([pred_bbox_body,pred_bbox_face,pred_bbox_rhand,pred_bbox_lhand],axis=0) # pred_bbox_body = (box_ops.box_cxcywh_to_xyxy(src_body_boxes).reshape(-1,2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1][None,None,:]).reshape(-1,4) # import ipdb;ipdb.set_trace() img = mmcv.imshow_bboxes(img.copy(), pred_bbox, show=False) # cv2.imwrite('test.png',img) visualize_kp2d( (pred_smpl_kp2d*img_wh[:, None])[None].detach().cpu().numpy(), output_path='./figs/pred2d', image_array=img.copy()[None], data_source='smplx_137', overwrite=True) # visualize_kp2d( # (pred_smpl_kp2d*img_wh[:, None])[None].detach().cpu().numpy(), # output_path='./figs/pred2d', # image_array=img.copy()[None], # data_source='smplx_137', # overwrite=True) vis_smpl=True if vis_smpl: gt_output = bm( betas=targets_shape.reshape(-1, 10), body_pose=targets_smpl_pose[:,3:66].reshape(-1, 21*3), global_orient=targets_smpl_pose[:,:3].reshape(-1, 3), left_hand_pose=targets_smpl_pose[:,66:111].reshape(-1, 15*3), right_hand_pose=targets_smpl_pose[:,111:156].reshape(-1, 15*3), leye_pose=torch.zeros_like(targets_smpl_pose[:,:3]).reshape(-1, 3), reye_pose=torch.zeros_like(targets_smpl_pose[:,:3]).reshape(-1, 3), expression=torch.zeros_like(targets_shape).reshape(-1, 10), jaw_pose=targets_smpl_pose[:,156:].reshape(-1, 3)) verts = gt_output['vertices'] for i_obj,v in enumerate(verts): save_obj('./figs/gt_smpl_%d.obj'%i_obj,verts = 
v,faces=torch.tensor([])) import ipdb;ipdb.set_trace() losses = {} if valid_num == 0: losses['loss_smpl_body_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 losses['loss_smpl_lhand_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 losses['loss_smpl_rhand_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 losses['loss_smpl_face_kp2d'] = torch.as_tensor(0., device=device) + pred_smpl_kp2d.sum()*0 return losses body_idx = smpl_x.joint_part['body'] face_idx = smpl_x.joint_part['face'] lhand_idx = smpl_x.joint_part['lhand'] rhand_idx = smpl_x.joint_part['rhand'] loss_smpl_kp2d = F.l1_loss(pred_smpl_kp2d, targets_kp2d, reduction='none') # If has_keypoints2d is not None, then computes the losses on the # instances that have ground-truth keypoints2d. # But the zero confidence keypoints will be included in mean. # Otherwise, only compute the keypoints2d # which have positive confidence. # has_keypoints2d is None when the key has_keypoints2d # is not in the datasets # import pdb; pdb.set_trace() valid_pos = keypoints2d_conf > 0 if keypoints2d_conf[valid_pos].numel() == 0: return { 'loss_smpl_body_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, 'loss_smpl_lhand_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, 'loss_smpl_rhand_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, 'loss_smpl_face_kp2d': torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0, } loss_smpl_kp2d = loss_smpl_kp2d * keypoints2d_conf # loss /= keypoints2d_conf[valid_pos].numel() if face_hand_kpt: losses['loss_smpl_body_kp2d'] = torch.sum(loss_smpl_kp2d[:, body_idx, :]) / num_boxes if lhand_num_boxes>0: losses['loss_smpl_lhand_kp2d'] = torch.sum(loss_smpl_kp2d[:, lhand_idx, :]) / lhand_num_boxes else: losses['loss_smpl_lhand_kp2d'] =torch.as_tensor(0., device=device) + loss_smpl_kp2d.sum()*0 if rhand_num_boxes>0: losses['loss_smpl_rhand_kp2d'] = torch.sum(loss_smpl_kp2d[:, rhand_idx, :]) 
def loss_smpl_kp2d_ba(self,
                      outputs,
                      targets,
                      indices,
                      idx,
                      num_boxes,
                      data_batch,
                      focal_length=5000.,
                      has_keypoints2d=None,
                      face_hand_kpt=False):
    """Compute the L1 loss on projected 2D keypoints after box alignment.

    The matched 3D keypoints are projected with ``project_points_new`` and
    divided by the image width/height. For the left hand, right hand and
    face, the ground-truth and projected keypoints are both re-expressed
    relative to the matched ground-truth part box; the projected part is
    then shifted by the mean offset to the ground truth, taken over the
    keypoints whose confidence is 1 and that fall inside the box. The
    confidence-weighted L1 error is finally summed per body part.

    Args:
        outputs: dict of model outputs. Reads ``'pred_logits'`` (device
            only), ``'pred_smpl_kp3d'`` and ``'pred_smpl_cam'``.
        targets: list of per-image dicts. Reads ``'lhand_boxes'``,
            ``'rhand_boxes'`` and ``'face_boxes'`` (cxcywh format, as they
            are passed to ``box_ops.box_cxcywh_to_xyxy``).
        indices: matcher output; element 0 is the per-image list of
            ``(src, tgt)`` index pairs.
        idx: ``(batch_idx, src_idx)`` selecting the matched predictions.
        num_boxes: normaliser of the body term.
        data_batch: dict of batched ground truth. Reads ``'joint_img'``
            (channels 0-1 are coordinates, channels 2+ are used as
            confidence), ``'img_shape'`` and ``'{lhand,rhand,face}_bbox_valid'``.
        focal_length: forwarded to ``project_points_new``.
        has_keypoints2d: unused; kept for interface compatibility.
        face_hand_kpt: when False the hand/face terms are returned as
            zeros; when True they are normalised by the number of valid
            part boxes.

    Returns:
        dict with the keys ``loss_smpl_{body,lhand,rhand,face}_kp2d_ba``.
    """
    device = outputs['pred_logits'].device
    indices = indices[0]
    part_names = ('lhand', 'rhand', 'face')
    loss_keys = ('loss_smpl_body_kp2d_ba', 'loss_smpl_lhand_kp2d_ba',
                 'loss_smpl_rhand_kp2d_ba', 'loss_smpl_face_kp2d_ba')

    def zero_like_loss(ref):
        # A zero that stays attached to ``ref``'s autograd graph.
        return torch.as_tensor(0., device=device) + ref.sum() * 0

    pred_smpl_kp3d = outputs['pred_smpl_kp3d'][idx].float()
    pred_cam = outputs['pred_smpl_cam'][idx].float()

    # NOTE(review): after ``indices = indices[0]`` this iterates the
    # (src, tgt) pair of the FIRST image only, so the guard below fires
    # whenever image 0 has no match. Behaviour kept as-is; confirm intent.
    valid_num = sum(len(matched) for matched in indices[0])

    targets_kp2d = torch.cat(
        [t[i] for t, (_, i) in zip(data_batch['joint_img'], indices)], dim=0)
    # The confidence is duplicated so it can weight the x and y channels.
    keypoints2d_conf = targets_kp2d[:, :, 2:].clone().repeat(1, 1, 2)
    targets_kp2d = targets_kp2d[:, :, :2].float()
    targets_kp2d[:, :, 0] = targets_kp2d[:, :, 0] / cfg.output_hm_shape[2]
    targets_kp2d[:, :, 1] = targets_kp2d[:, :, 1] / cfg.output_hm_shape[1]

    # ``img_shape`` rows are flipped so that img_wh is (width, height),
    # assuming img_shape is stored as (height, width) -- TODO confirm.
    img_wh = torch.cat(
        [data_batch['img_shape'][i][None] for i in idx[0]], dim=0).flip(-1)
    pred_smpl_kp2d = project_points_new(points_3d=pred_smpl_kp3d,
                                        pred_cam=pred_cam,
                                        focal_length=focal_length,
                                        camera_center=img_wh / 2)
    pred_smpl_kp2d = pred_smpl_kp2d / img_wh[:, None]

    if valid_num == 0:
        return {key: zero_like_loss(pred_smpl_kp2d) for key in loss_keys}

    # Gather, per part, the validity flags and the matched ground-truth
    # boxes scaled by the image width/height (xyxy).
    part_valid = {}
    part_bbox = {}
    part_count = {}
    for part in part_names:
        valid = torch.cat([
            t[i] for t, (_, i) in zip(data_batch[f'{part}_bbox_valid'],
                                      indices)
        ], dim=0)
        boxes = torch.cat(
            [t[f'{part}_boxes'][i] for t, (_, i) in zip(targets, indices)],
            dim=0)
        part_bbox[part] = (box_ops.box_cxcywh_to_xyxy(boxes).reshape(
            -1, 2, 2) * img_wh[:, None]).reshape(-1, 4)
        part_valid[part] = valid
        part_count[part] = valid.sum()

    img_shape = torch.cat([
        t[None].repeat(len(i), 1)
        for t, (_, i) in zip(data_batch['img_shape'], indices)
    ], dim=0)

    if not (part_valid['lhand'] + part_valid['rhand'] +
            part_valid['face'] == 0).all():
        for part in part_names:
            bbox = part_bbox[part]
            joints = smpl_x.joint_part[part]

            # Ground truth: move into the 0-1 space of the part box.
            x = targets_kp2d[:, joints, 0]
            y = targets_kp2d[:, joints, 1]
            trunc = keypoints2d_conf[:, joints, 0].clone()
            x -= (bbox[:, None, 0] / img_shape[:, 1:])
            x *= (img_shape[:, 1:] /
                  (bbox[:, None, 2] - bbox[:, None, 0] + 1e-6))
            y -= (bbox[:, None, 1] / img_shape[:, :1])
            y *= (img_shape[:, :1] /
                  (bbox[:, None, 3] - bbox[:, None, 1] + 1e-6))
            # Keypoints outside the box do not take part in the alignment.
            trunc *= ((x >= 0) * (x <= 1) * (y >= 0) * (y <= 1))
            coord = torch.stack((x, y), 2)
            # The splice relies on the part's joint indices being one
            # contiguous run (first .. last).
            targets_kp2d = torch.cat(
                (targets_kp2d[:, :joints[0], :], coord,
                 targets_kp2d[:, joints[-1] + 1:, :]), 1)

            # Prediction: same box normalisation.
            x_pred = pred_smpl_kp2d[:, joints, 0]
            y_pred = pred_smpl_kp2d[:, joints, 1]
            x_pred -= (bbox[:, None, 0] / img_shape[:, 1:])
            x_pred *= (img_shape[:, 1:] /
                       (bbox[:, None, 2] - bbox[:, None, 0] + 1e-6))
            y_pred -= (bbox[:, None, 1] / img_shape[:, :1])
            y_pred *= (img_shape[:, :1] /
                       (bbox[:, None, 3] - bbox[:, None, 1] + 1e-6))
            coord_pred = torch.stack((x_pred, y_pred), 2)

            # Per-instance translation aligning prediction to ground truth.
            target_part = targets_kp2d[:, joints, :]
            trans = []
            for bid in range(coord_pred.shape[0]):
                mask = trunc[bid] == 1
                if torch.sum(mask) == 0:
                    # Allocate on the prediction's device instead of
                    # forcing CUDA, so CPU runs do not crash.
                    trans.append(coord_pred.new_zeros(2))
                else:
                    trans.append((-coord_pred[bid, mask, :2] +
                                  target_part[bid, mask, :2]).mean(0))
            trans = torch.stack(trans)[:, None, :]
            coord_pred = coord_pred + trans  # global translation alignment
            pred_smpl_kp2d = torch.cat(
                (pred_smpl_kp2d[:, :joints[0], :], coord_pred,
                 pred_smpl_kp2d[:, joints[-1] + 1:, :]), 1)

    loss_smpl_kp2d_ba = F.l1_loss(pred_smpl_kp2d,
                                  targets_kp2d[:, :, :2],
                                  reduction='none')

    valid_pos = keypoints2d_conf > 0
    if keypoints2d_conf[valid_pos].numel() == 0:
        return {key: zero_like_loss(loss_smpl_kp2d_ba) for key in loss_keys}

    loss_smpl_kp2d_ba = loss_smpl_kp2d_ba * keypoints2d_conf
    losses = {
        'loss_smpl_body_kp2d_ba':
        torch.sum(loss_smpl_kp2d_ba[:, smpl_x.joint_part['body'], :]) /
        num_boxes
    }
    for part in part_names:
        key = f'loss_smpl_{part}_kp2d_ba'
        count = part_count[part]
        if face_hand_kpt and count > 0:
            losses[key] = torch.sum(
                loss_smpl_kp2d_ba[:, smpl_x.joint_part[part], :]) / count
        else:
            # Either the part terms are disabled or there is no valid box.
            # Return a clean zero: dividing a zero sum by a zero count
            # would produce NaN.
            losses[key] = zero_like_loss(loss_smpl_kp2d_ba)
    return losses
def loss_boxes(self,
               outputs,
               targets,
               indices,
               idx,
               num_boxes,
               data_batch,
               face_hand_box=False):
    """Compute the L1 and GIoU box losses for the body and, optionally,
    the left-hand, right-hand and face boxes.

    Target dicts must contain the key ``"boxes"`` holding a tensor of dim
    [nb_target_boxes, 4]; boxes are expected in (center_x, center_y, w, h)
    format, normalized by the image size. Each term is weighted by the
    matching ``*_bbox_valid`` flag from ``data_batch`` and divided by
    ``num_boxes``.

    Args:
        outputs: dict of model outputs. Reads ``'pred_boxes'`` and, when
            present and ``face_hand_box`` is True, ``'pred_lhand_boxes'``,
            ``'pred_rhand_boxes'`` and ``'pred_face_boxes'``.
        targets: list of per-image dicts with ``'boxes'`` and the
            ``'{lhand,rhand,face}_boxes'`` entries.
        indices: matcher output; element 0 is the per-image list of
            ``(src, tgt)`` index pairs.
        idx: ``(batch_idx, src_idx)`` selecting the matched predictions.
        num_boxes: normaliser applied to every term.
        data_batch: dict providing ``'{body,lhand,rhand,face}_bbox_valid'``.
        face_hand_box: enables the hand/face terms.

    Returns:
        dict of ``loss_<part>_bbox`` / ``loss_<part>_giou`` tensors. When the
        emptiness guard fires, all eight keys are returned as zeros.
    """
    indices = indices[0]
    assert 'pred_boxes' in outputs

    # NOTE(review): after ``indices = indices[0]`` this counts the matches
    # of the FIRST image only. Behaviour kept as-is; confirm intent.
    valid_num = sum(len(matched) for matched in indices[0])

    # (part, prediction key, target key, validity key)
    part_specs = (
        ('body', 'pred_boxes', 'boxes', 'body_bbox_valid'),
        ('lhand', 'pred_lhand_boxes', 'lhand_boxes', 'lhand_bbox_valid'),
        ('rhand', 'pred_rhand_boxes', 'rhand_boxes', 'rhand_bbox_valid'),
        ('face', 'pred_face_boxes', 'face_boxes', 'face_bbox_valid'),
    )

    losses = {}
    l1_terms = {}
    for part, pred_key, gt_key, valid_key in part_specs:
        # The body term is always computed; the others are opt-in.
        if part != 'body' and not (pred_key in outputs and face_hand_box):
            continue
        src_boxes = outputs[pred_key][idx]
        target_boxes = torch.cat(
            [t[gt_key][i] for t, (_, i) in zip(targets, indices)], dim=0)
        target_conf = torch.cat(
            [t[i] for t, (_, i) in zip(data_batch[valid_key], indices)],
            dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        loss_bbox = loss_bbox * target_conf[:, None]
        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes)))
        loss_giou = loss_giou * target_conf

        losses[f'loss_{part}_bbox'] = loss_bbox.sum() / num_boxes
        losses[f'loss_{part}_giou'] = loss_giou.sum() / num_boxes
        l1_terms[part] = loss_bbox

    if valid_num == 0:
        # Return every key as a zero that is still attached to the graph of
        # the corresponding prediction. A part whose term was not computed
        # (disabled, or its prediction key is missing) falls back to the
        # body term instead of referencing an undefined variable.
        losses = {}
        for part, _, _, _ in part_specs:
            reference = l1_terms.get(part, l1_terms['body'])
            losses[f'loss_{part}_bbox'] = reference.sum() * 0
            losses[f'loss_{part}_giou'] = reference.sum() * 0
        return losses

    return losses
losses return losses def loss_dn_boxes(self, outputs, targets, indices, idx, num_boxes, data_batch): """ Input: - src_boxes: bs, num_dn, 4 - tgt_boxes: bs, num_dn, 4 """ indices = indices[0] num_tgt = outputs['num_tgt'] src_boxes = outputs['dn_bbox_pred'] tgt_boxes = outputs['dn_bbox_input'] valid_num=0 for indice in indices[0]: valid_num+=len(indice) if valid_num == 0: device = outputs['pred_logits'].device losses = { 'dn_loss_bbox': src_boxes.sum()*0, 'dn_loss_giou': src_boxes.sum()*0, } return losses if 'num_tgt' not in outputs: device = outputs['pred_logits'].device losses = { 'dn_loss_bbox': src_boxes.sum()*0, 'dn_loss_giou': src_boxes.sum()*0, } return losses if 'num_tgt' not in outputs: device = outputs['pred_logits'].device losses = { 'dn_loss_bbox': src_boxes.sum()*0, 'dn_loss_giou': src_boxes.sum()*0, } return losses return self.tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt) def loss_dn_labels(self, outputs, targets, indices, idx, num_boxes, data_batch): """ Input: - src_logits: bs, num_dn, num_classes - tgt_labels: bs, num_dn """ indices = indices[0] if 'num_tgt' not in outputs: device = outputs['pred_logits'].device losses = { 'dn_loss_ce': outputs['pred_logits'].sum()*0, } return losses valid_num = 0 for indice in indices[0]: valid_num+=len(indice) if valid_num == 0: device = outputs['pred_logits'].device losses = { 'dn_loss_ce': outputs['pred_logits'].sum()*0, } return losses num_tgt = outputs['num_tgt'] src_logits = outputs['dn_class_pred'] # bs, num_dn, text_len tgt_labels = outputs['dn_class_input'] return self.tgt_loss_labels(src_logits, tgt_labels, num_tgt) @torch.no_grad() def loss_matching_cost(self, outputs, targets, indices, idx, num_boxes, data_batch): """ Input: - src_logits: bs, num_dn, num_classes - tgt_labels: bs, num_dn """ cost_mean_dict = indices[1] losses = {'set_{}'.format(k): v for k, v in cost_mean_dict.items()} return losses def _get_src_permutation_idx(self, indices): # permute predictions following indices batch_idx = torch.cat( 
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(self, indices): # permute targets following indices batch_idx = torch.cat( [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx def get_loss(self, loss, outputs, targets, data_batch, indices, num_boxes, **kwargs): loss_map = { 'smpl_pose': self.loss_smpl_pose, 'smpl_beta': self.loss_smpl_beta, 'smpl_expr': self.loss_smpl_expr, 'smpl_kp2d': self.loss_smpl_kp2d, 'smpl_kp2d_ba': self.loss_smpl_kp2d_ba, 'smpl_kp3d_ra': self.loss_smpl_kp3d_ra, 'smpl_kp3d': self.loss_smpl_kp3d, 'labels': self.loss_labels, 'cardinality': self.loss_cardinality, 'boxes': self.loss_boxes, 'dn_label': self.loss_dn_labels, 'dn_bbox': self.loss_dn_boxes, 'matching': self.loss_matching_cost, } idx = self._get_src_permutation_idx(indices[0]) # pdb.set_trace() assert loss in loss_map, f'do you really want to compute {loss} loss?' return loss_map[loss](outputs, targets, indices, idx, num_boxes, data_batch, **kwargs) def prep_for_dn2(self, mask_dict): known_bboxs = mask_dict['known_bboxs'] known_labels = mask_dict['known_labels'] output_known_coord = mask_dict['output_known_coord'] output_known_class = mask_dict['output_known_class'] num_tgt = mask_dict['pad_size'] return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt ## SMPL losses def forward(self, outputs, targets, data_batch, return_indices=False): """ This performs the loss computation. Parameters: outputs: dict of tensors, see the output specification of the model for the format targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc return_indices: used for vis. if True, the layer0-5 indices will be returned as well. 
""" # import pdb; pdb.set_trace() outputs_without_aux = { k: v for k, v in outputs.items() if k != 'aux_outputs' } device = next(iter(outputs.values())).device # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes = sum(len(t['boxes']) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_boxes) num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() # loss for final layer # pdb.set_trace() indices = self.matcher(outputs_without_aux, targets) if return_indices: indices0_copy = indices indices_list = [] losses = {} smpl_loss = ['smpl_pose', 'smpl_beta', 'smpl_expr', 'smpl_kp2d', 'smpl_kp2d_ba', 'smpl_kp3d', 'smpl_kp3d_ra'] # import pdb; pdb.set_trace() for loss in self.losses: # print(loss) # print(self.get_loss(loss, outputs, targets, indices, num_boxes)) kwargs = {} if loss == 'keypoints' or loss in smpl_loss: kwargs.update({'face_hand_kpt': True}) if loss == 'boxes': kwargs.update({'face_hand_box': True}) losses.update( self.get_loss( loss, outputs, targets, data_batch, indices, num_boxes, **kwargs )) # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
if 'aux_outputs' in outputs: for idx, aux_outputs in enumerate(outputs['aux_outputs']): indices = self.matcher(aux_outputs, targets) if return_indices: indices_list.append(indices) for loss in self.losses: kwargs = {} if loss == 'boxes': kwargs.update({'face_hand_box': False}) if idx >= self.num_box_decoder_layers: kwargs.update({'face_hand_box': True}) if loss == 'masks': continue if loss == 'keypoints': if idx < self.num_box_decoder_layers: continue elif idx < self.num_hand_face_decoder_layers: kwargs.update({'face_hand_kpt': False}) else: kwargs.update({'face_hand_kpt': True}) if loss in smpl_loss: if idx < self.num_box_decoder_layers: continue elif idx < self.num_hand_face_decoder_layers: kwargs.update({'face_hand_kpt': False}) else: kwargs.update({'face_hand_kpt': True}) if loss == 'labels': # Logging is enabled only for the last layer kwargs = {'log': False} # if loss == 'smpl_expr' and idx < self.num_box_decoder_layers: # continue # import pdb;pdb.set_trace() l_dict = self.get_loss(loss, aux_outputs, targets, data_batch, indices, num_boxes, **kwargs) l_dict = {k + f'_{idx}': v for k, v in l_dict.items()} losses.update(l_dict) # interm_outputs loss if 'interm_outputs' in outputs: interm_outputs = outputs['interm_outputs'] indices = self.matcher(interm_outputs, targets) if return_indices: indices_list.append(indices) for loss in self.losses: if loss in ['dn_bbox', 'dn_label', 'keypoints']: continue if loss in [ 'smpl_pose', 'smpl_beta', 'smpl_kp2d_ba', 'smpl_kp2d', 'smpl_kp3d_ra', 'smpl_kp3d', 'smpl_expr' ]: continue kwargs = {} if loss == 'labels': kwargs = {'log': False} l_dict = self.get_loss(loss, interm_outputs, targets, data_batch, indices, num_boxes, **kwargs) l_dict = {k + f'_interm': v for k, v in l_dict.items()} losses.update(l_dict) # aux_init loss if 'query_expand' in outputs: interm_outputs = outputs['query_expand'] indices = self.matcher(interm_outputs, targets) if return_indices: indices_list.append(indices) for loss in self.losses: if loss in 
def tgt_loss_boxes(
    self,
    src_boxes,
    tgt_boxes,
    num_tgt,
):
    """L1 and GIoU losses between denoising predictions and their targets.

    Input:
     - src_boxes: bs, num_dn, 4
     - tgt_boxes: bs, num_dn, 4
     - num_tgt: normaliser applied to both summed terms

    Returns a dict with ``dn_loss_bbox`` and ``dn_loss_giou``.
    """
    l1_total = F.l1_loss(src_boxes, tgt_boxes, reduction='none').sum()

    # Boxes are paired one-to-one, hence only the diagonal of the pairwise
    # GIoU matrix is used.
    src_xyxy = box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1))
    tgt_xyxy = box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1))
    pairwise_giou = box_ops.generalized_box_iou(src_xyxy, tgt_xyxy)
    giou_total = (1 - torch.diag(pairwise_giou)).sum()

    return {
        'dn_loss_bbox': l1_total / num_tgt,
        'dn_loss_giou': giou_total / num_tgt,
    }


def tgt_loss_labels(self,
                    src_logits: Tensor,
                    tgt_labels: Tensor,
                    num_tgt: int,
                    log: bool = True):
    """Sigmoid focal classification loss for the denoising queries.

    Input:
     - src_logits: bs, num_dn, num_classes
     - tgt_labels: bs, num_dn
     - num_tgt: normaliser forwarded to ``sigmoid_focal_loss``
     - log: accepted but not used by this method

    Returns a dict with ``dn_loss_ce``.
    """
    batch_size, num_queries, num_classes = src_logits.shape[:3]

    # One extra class slot absorbs the labels that fall outside the real
    # classes; it is sliced away right after the scatter.
    onehot = torch.zeros([batch_size, num_queries, num_classes + 1],
                         dtype=src_logits.dtype,
                         layout=src_logits.layout,
                         device=src_logits.device)
    onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)
    onehot = onehot[:, :, :-1]

    focal = sigmoid_focal_loss(src_logits,
                               onehot,
                               num_tgt,
                               alpha=self.focal_alpha,
                               gamma=2)
    return {'dn_loss_ce': focal * num_queries}