import copy import pdb import os import math from typing import List import torch import torch.nn.functional as F from torch import nn from torch import Tensor from util import box_ops from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy, get_world_size, interpolate, is_dist_avail_and_initialized, inverse_sigmoid) from .backbones import build_backbone from .matcher import build_matcher from .transformer import build_transformer from .utils import PoseProjector, sigmoid_focal_loss, MLP from .postprocesses import PostProcess_SMPLX, PostProcess_aios from .postprocesses import PostProcess_SMPLX_Multi as PostProcess_SMPLX from .postprocesses import PostProcess_SMPLX_Multi_Box from .postprocesses import PostProcess_SMPLX_Multi_Infer, PostProcess_SMPLX_Multi_Infer_Box from .criterion_smplx import SetCriterion, SetCriterion_Box from ..registry import MODULE_BUILD_FUNCS from detrsmpl.core.conventions.keypoints_mapping import convert_kps from detrsmpl.models.body_models.builder import build_body_model from util.human_models import smpl_x from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idxs_by_part import numpy as np from detrsmpl.utils.geometry import (rot6d_to_rotmat) from detrsmpl.utils.transforms import rotmat_to_aa import cv2 from config.config import cfg class AiOSSMPLX(nn.Module): def __init__( self, backbone, transformer, num_classes, num_queries, aux_loss=False, iter_update=True, query_dim=4, random_refpoints_xy=False, fix_refpoints_hw=-1, num_feature_levels=1, nheads=8, two_stage_type='no', dec_pred_class_embed_share=False, dec_pred_bbox_embed_share=False, dec_pred_pose_embed_share=False, two_stage_class_embed_share=True, two_stage_bbox_embed_share=True, dn_number=100, dn_box_noise_scale=0.4, dn_label_noise_ratio=0.5, dn_batch_gt_fuse=False, dn_labelbook_size=100, dn_attn_mask_type_list=['group2group'], cls_no_bias=False, num_group=100, num_body_points=17, num_hand_points=10, num_face_points=10, num_box_decoder_layers=2, num_hand_face_decoder_layers=4, body_model=dict( type='smplx', keypoint_src='smplx', num_expression_coeffs=10, keypoint_dst='smplx_137', model_path='data/body_models/smplx', use_pca=False, use_face_contour=True), train=True, inference=False, focal_length=[5000., 5000.], camera_3d_size=2.5 ): super().__init__() self.num_queries = num_queries self.transformer = transformer self.num_classes = num_classes self.hidden_dim = hidden_dim = transformer.d_model self.num_feature_levels = num_feature_levels self.nheads = nheads self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim) self.num_body_points = num_body_points self.num_hand_points = num_hand_points self.num_face_points = num_face_points self.num_whole_body_points = num_body_points + 2*num_hand_points + num_face_points self.num_box_decoder_layers = num_box_decoder_layers self.num_hand_face_decoder_layers = num_hand_face_decoder_layers self.focal_length = focal_length self.camera_3d_size=camera_3d_size self.inference = inference if train: self.smpl_convention = 'smplx' else: self.smpl_convention = 'h36m' # setting query dim self.query_dim = query_dim assert query_dim == 4 self.random_refpoints_xy = random_refpoints_xy # False self.fix_refpoints_hw = fix_refpoints_hw # -1 # for dn training self.dn_number = dn_number self.dn_box_noise_scale = dn_box_noise_scale self.dn_label_noise_ratio = dn_label_noise_ratio self.dn_batch_gt_fuse = dn_batch_gt_fuse self.dn_labelbook_size = dn_labelbook_size self.dn_attn_mask_type_list = dn_attn_mask_type_list assert all([ i in ['match2dn', 'dn2dn', 'group2group'] for i in dn_attn_mask_type_list ]) assert not dn_batch_gt_fuse # build human body # if train: # self.body_model = build_body_model(body_model) if inference: body_model=dict( type='smplx', keypoint_src='smplx', num_expression_coeffs=10, num_betas=10, keypoint_dst='smplx', model_path='data/body_models/smplx', use_pca=False, use_face_contour=True) self.body_model = build_body_model(body_model) for param in self.body_model.parameters(): param.requires_grad = False # prepare input projection layers if num_feature_levels > 1: num_backbone_outs = len(backbone.num_channels) # 3 input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone.num_channels[_] input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), )) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, hidden_dim), )) in_channels = hidden_dim self.input_proj = nn.ModuleList(input_proj_list) else: assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!' self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), ) ]) self.backbone = backbone self.aux_loss = aux_loss self.box_pred_damping = box_pred_damping = None self.iter_update = iter_update assert iter_update, 'Why not iter_update?' # prepare pred layers self.dec_pred_class_embed_share = dec_pred_class_embed_share # false self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # false # 1.1 prepare class & box embed _class_embed = nn.Linear(hidden_dim, num_classes, bias=(not cls_no_bias)) if not cls_no_bias: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) _class_embed.bias.data = torch.ones(self.num_classes) * bias_value # 1.2 box embed layer list if dec_pred_class_embed_share: class_embed_layerlist = [ _class_embed for i in range(transformer.num_decoder_layers) ] else: class_embed_layerlist = [ copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers) ] ########################################################################### # body bbox + l/r hand box + face box ########################################################################### # 1.1 body bbox embed _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) # 1.2 body bbox embed layer list self.num_group = num_group if dec_pred_bbox_embed_share: box_body_embed_layerlist = [ _bbox_embed for i in range(transformer.num_decoder_layers) ] else: box_body_embed_layerlist = [ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers) ] # 2.1 lhand bbox embed _bbox_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3) # TODO: the out shape should be 2 not 4 nn.init.constant_(_bbox_hand_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_hand_embed.layers[-1].bias.data, 0) _bbox_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_hand_hw_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_hand_hw_embed.layers[-1].bias.data, 0) # 2.2 lhand bbox embed layer list if dec_pred_pose_embed_share: box_hand_embed_layerlist = \ [_bbox_hand_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers+1)] else: box_hand_embed_layerlist = [ copy.deepcopy(_bbox_hand_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1) ] if dec_pred_pose_embed_share: box_hand_hw_embed_layerlist = [ _bbox_hand_hw_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] else: box_hand_hw_embed_layerlist = [ copy.deepcopy(_bbox_hand_hw_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 4.1 face bbox embed _bbox_face_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_face_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_face_embed.layers[-1].bias.data, 0) _bbox_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_face_hw_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_face_hw_embed.layers[-1].bias.data, 0) # 4.2 face bbox embed layer list if dec_pred_pose_embed_share: box_face_embed_layerlist = [ _bbox_face_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers + 1) ] else: box_face_embed_layerlist = [ copy.deepcopy(_bbox_face_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1) ] if dec_pred_pose_embed_share: box_face_hw_embed_layerlist = [ _bbox_face_hw_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers)] else: box_face_hw_embed_layerlist = [ copy.deepcopy(_bbox_face_hw_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] ########################################################################### # body kp2d + l/r hand kp2d + face kp2d ########################################################################### ######## body ####### # 1.1 body kp2d embed _pose_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_pose_embed.layers[-1].weight.data, 0) nn.init.constant_(_pose_embed.layers[-1].bias.data, 0) # 1.2 body kp2d embed layer list if num_body_points == 17: if dec_pred_pose_embed_share: pose_embed_layerlist = \ [_pose_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers+1)] else: pose_embed_layerlist = [ copy.deepcopy(_pose_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1) ] else: if dec_pred_pose_embed_share: pose_embed_layerlist = [ _pose_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: pose_embed_layerlist = [ copy.deepcopy(_pose_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 1.3 body kp bbox embed _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) # 1.4 body kp bbox embed layer list pose_hw_embed_layerlist = [ _pose_hw_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] ######## lhand ####### # 2.1 lhand kp2d embed _pose_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_pose_hand_embed.layers[-1].weight.data, 0) nn.init.constant_(_pose_hand_embed.layers[-1].bias.data, 0) # 2.2 lhand kp2d embed layer list if dec_pred_pose_embed_share: pose_hand_embed_layerlist = \ [_pose_hand_embed for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers+1)] else: pose_hand_embed_layerlist = [ copy.deepcopy(_pose_hand_embed) for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers + 1) ] # 2.3 lhand kp bbox embed _pose_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) # 2.4 lhand kp bbox embed layer list pose_hand_hw_embed_layerlist = [ _pose_hand_hw_embed for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers) ] ######## face ####### # 4.1 face kp2d embed _pose_face_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_pose_face_embed.layers[-1].weight.data, 0) nn.init.constant_(_pose_face_embed.layers[-1].bias.data, 0) # 4.2 face kp2d embed layer list if dec_pred_pose_embed_share: pose_face_embed_layerlist = \ [_pose_face_embed for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers+1)] else: pose_face_embed_layerlist = [ copy.deepcopy(_pose_face_embed) for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers + 1) ] # 4.3 face kp bbox embed _pose_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) # 4.4 face kp bbox embed layer list pose_face_hw_embed_layerlist = [ _pose_face_hw_embed for i in range(transformer.num_decoder_layers - num_hand_face_decoder_layers) ] ########################################################################### # smpl pose + betas + kp2d + kp3d + cam ########################################################################### # 1. smpl pose embed if body_model['type'].upper()=='SMPL': self.body_model_joint_num = 24 elif body_model['type'].upper()=='SMPLX': self.body_model_joint_num = 22 else: raise ValueError( f'Only supports SMPL or SMPLX, but get {body_model.type}') #TODO: _smpl_pose_embed = MLP(hidden_dim * (self.num_body_points + 4), hidden_dim, self.body_model_joint_num * 6, 3) nn.init.constant_(_smpl_pose_embed.layers[-1].weight.data, 0) nn.init.constant_(_smpl_pose_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smpl_pose_embed_layerlist = [ _smpl_pose_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smpl_pose_embed_layerlist = [ copy.deepcopy(_smpl_pose_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 2. smpl betas embed _smpl_beta_embed = MLP(hidden_dim * (self.num_body_points + 4), hidden_dim, 10, 3) nn.init.constant_(_smpl_beta_embed.layers[-1].weight.data, 0) nn.init.constant_(_smpl_beta_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smpl_beta_embed_layerlist = [ _smpl_beta_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smpl_beta_embed_layerlist = [ copy.deepcopy(_smpl_beta_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 3. smpl cam embed _cam_embed = MLP(hidden_dim * (self.num_body_points + 4), hidden_dim, 3, 3) nn.init.constant_(_cam_embed.layers[-1].weight.data, 0) nn.init.constant_(_cam_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: cam_embed_layerlist = [ _cam_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: cam_embed_layerlist = [ copy.deepcopy(_cam_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] ########################################################################### # smplx body pose + hand pose + expression + betas + kp2d + kp3d + cam ########################################################################### # 1. smplx body pose embed # _smplx_pose_embed = MLP(hidden_dim * (self.num_body_points + 1), # hidden_dim, 23 * 6, 3) # nn.init.constant_(_smplx_pose_embed.layers[-1].weight.data, 0) # nn.init.constant_(_smplx_pose_embed.layers[-1].bias.data, 0) # if dec_pred_bbox_embed_share: # smplx_pose_embed_layerlist = [ # _smplx_pose_embed # for i in range(transformer.num_decoder_layers - # num_box_decoder_layers + 1) # ] # else: # smplx_pose_embed_layerlist = [ # copy.deepcopy(_smplx_pose_embed) # for i in range(transformer.num_decoder_layers - # num_box_decoder_layers + 1) # ] # 2. smplx hand pose embed _smplx_hand_pose_embed_layer_2_3 = \ MLP(hidden_dim, hidden_dim, 15 * 6, 3) nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].bias.data, 0) _smplx_hand_pose_embed_layer_4_5 = \ MLP(hidden_dim * (self.num_hand_points + 3), hidden_dim, 15 * 6, 3) nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_hand_pose_embed_layerlist = [ _smplx_hand_pose_embed_layer_2_3 if i<2 else _smplx_hand_pose_embed_layer_4_5 for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_hand_pose_embed_layerlist = [ copy.deepcopy(_smplx_hand_pose_embed_layer_2_3) if i<2 else copy.deepcopy(_smplx_hand_pose_embed_layer_4_5) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 3. smplx face expression _smplx_expression_embed_layer_2_3 = \ MLP(hidden_dim, hidden_dim, 10, 3) nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].bias.data, 0) _smplx_expression_embed_layer_4_5 = \ MLP(hidden_dim * (self.num_hand_points + 2), hidden_dim, 10, 3) nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_expression_embed_layerlist = [ _smplx_expression_embed_layer_2_3 if i<2 else _smplx_expression_embed_layer_4_5 for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_expression_embed_layerlist = [ copy.deepcopy(_smplx_expression_embed_layer_2_3) if i<2 else copy.deepcopy(_smplx_expression_embed_layer_4_5) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # _smplx_expression_embed = MLP(hidden_dim * (self.num_face_points + 2), # hidden_dim, 10, 3) # nn.init.constant_(_smplx_expression_embed.layers[-1].weight.data, 0) # nn.init.constant_(_smplx_expression_embed.layers[-1].bias.data, 0) # if dec_pred_bbox_embed_share: # smplx_expression_embed_layerlist = [ # _smplx_expression_embed # for i in range(transformer.num_decoder_layers - # num_hand_face_decoder_layers) # ] # else: # smplx_expression_embed_layerlist = [ # copy.deepcopy(_smplx_expression_embed) # for i in range(transformer.num_decoder_layers - # num_hand_face_decoder_layers) # ] # 4. smplx jaw pose embed _smplx_jaw_embed_2_3 = MLP(hidden_dim * 1, hidden_dim, 6, 3) nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].bias.data, 0) _smplx_jaw_embed_4_5 = MLP(hidden_dim * (self.num_face_points + 2), hidden_dim, 6, 3) nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_jaw_embed_layerlist = [ _smplx_jaw_embed_2_3 if i<2 else _smplx_jaw_embed_4_5 for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_jaw_embed_layerlist = [ copy.deepcopy(_smplx_jaw_embed_2_3) if i<2 else copy.deepcopy(_smplx_jaw_embed_4_5) for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] ############### self.bbox_embed = nn.ModuleList(box_body_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist) self.pose_embed = nn.ModuleList(pose_embed_layerlist) self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.pose_embed = self.pose_embed self.transformer.decoder.pose_hw_embed = self.pose_hw_embed self.transformer.decoder.class_embed = self.class_embed # smpl self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist) self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist) self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist) # self.smpl_cam_f_embed = nn.ModuleList(f_embed_layerlist) # self.transformer.decoder.smpl_pose_embed = self.smpl_pose_embed # self.transformer.decoder.smpl_beta_embed = self.smpl_beta_embed # self.transformer.decoder.smpl_cam_embed = self.smpl_cam_embed # smplx lhand kp self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist) self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist) self.pose_hand_embed = nn.ModuleList(pose_hand_embed_layerlist) self.pose_hand_hw_embed = nn.ModuleList(pose_hand_hw_embed_layerlist) self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed self.transformer.decoder.pose_hand_embed = self.pose_hand_embed self.transformer.decoder.pose_hand_hw_embed = self.pose_hand_hw_embed # smplx face kp self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist) self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist) self.pose_face_embed = nn.ModuleList(pose_face_embed_layerlist) self.pose_face_hw_embed = nn.ModuleList(pose_face_hw_embed_layerlist) self.transformer.decoder.bbox_face_embed = self.bbox_face_embed self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed self.transformer.decoder.pose_face_embed = self.pose_face_embed self.transformer.decoder.pose_face_hw_embed = self.pose_face_hw_embed # smplx self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist) # self.smplx_rhand_pose_embed = nn.ModuleList(smplx_rhand_pose_embed_layerlist) self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist) self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist) # self.transformer.decoder.smplx_hand_pose_embed = self.smplx_hand_pose_embed # self.transformer.decoder.smplx_rhand_pose_embed = self.smplx_rhand_pose_embed # self.transformer.decoder.num_whole_bosmpl_expr_embeddy_points = self.smplx_expression_embed # self.transformer.decoder.smpl_jaw_embed = self.smplx_jaw_embed ######### self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers self.transformer.decoder.num_body_points = num_body_points self.transformer.decoder.num_hand_points = num_hand_points self.transformer.decoder.num_face_points = num_face_points # two stage self.two_stage_type = two_stage_type assert two_stage_type in [ 'no', 'standard' ], 'unknown param {} of two_stage_type'.format(two_stage_type) if two_stage_type != 'no': if two_stage_bbox_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_bbox_embed = _bbox_embed else: self.transformer.enc_out_bbox_embed = copy.deepcopy( _bbox_embed) if two_stage_class_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_class_embed = _class_embed else: self.transformer.enc_out_class_embed = copy.deepcopy( _class_embed) self.refpoint_embed = None self._reset_parameters() def get_camera_trans(self, cam_param, input_body_shape): # camera translation t_xy = cam_param[:, :2] gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive k_value = torch.FloatTensor( [ math.sqrt( self.focal_length[0] * self.focal_length[1] * self.camera_3d_size * self.camera_3d_size / (input_body_shape[0] * input_body_shape[1]) ) ] ).cuda().view(-1) t_z = k_value * gamma cam_trans = torch.cat((t_xy, t_z[:, None]), 1) return cam_trans def _reset_parameters(self): # init input_proj for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) def prepare_for_dn2(self, targets): if not self.training: device = targets[0]['boxes'].device bs = len(targets) num_points = self.num_body_points + 4 attn_mask2 = torch.zeros( bs, self.nheads, self.num_group * num_points, self.num_group * num_points, device=device, dtype=torch.bool) group_bbox_kpt = num_points group_nobbox_kpt = self.num_body_points kpt_index = [ x for x in range(self.num_group * num_points) if x % num_points in [ 0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3 ] ] for matchj in range(self.num_group * num_points): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt if sj > 0: attn_mask2[:, :, matchj, :sj] = True if ej < self.num_group * num_points: attn_mask2[:, :, matchj, ej:] = True for match_x in range(self.num_group * num_points): if match_x % group_bbox_kpt in [0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3]: attn_mask2[:,:,match_x,kpt_index]=False num_points = self.num_whole_body_points + 4 attn_mask3 = torch.zeros( bs, self.nheads, self.num_group * (num_points), self.num_group * (num_points), device=device, dtype=torch.bool) group_bbox_kpt = (num_points) # group_nobbox_kpt = self.num_body_points kpt_index = [ x for x in range(self.num_group * (num_points)) if x % (num_points) in [0, 1+self.num_body_points, 2+self.num_body_points+self.num_hand_points, 3+self.num_body_points+self.num_hand_points*2 ] ] for matchj in range(self.num_group * num_points): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt if sj > 0: attn_mask3[:, :, matchj, :sj] = True if ej < self.num_group * num_points: attn_mask3[:, :, matchj, ej:] = True for match_x in range(self.num_group * num_points): if match_x % group_bbox_kpt in [ 0, 1 + self.num_body_points, 2 + self.num_body_points + self.num_hand_points, 3 + self.num_body_points + self.num_hand_points * 2]: attn_mask3[:, :, match_x, kpt_index] = False # num_points = self.num_whole_body_points + 4 # device = targets[0]['boxes'].device # bs = len(targets) # attn_mask_infere = torch.zeros( # bs, # self.nheads, # self.num_group * num_points, # self.num_group * num_points, # device=device, # dtype=torch.bool) # group_bbox_kpt = num_points # group_nobbox_kpt = self.num_body_points # kpt_index = [ # x for x in range(self.num_group * num_points) # if x % num_points == 0 # ] # for matchj in range(self.num_group * num_points): # sj = (matchj // group_bbox_kpt) * group_bbox_kpt # ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt # if sj > 0: # attn_mask_infere[:, :, matchj, :sj] = True # if ej < self.num_group * num_points: # attn_mask_infere[:, :, matchj, ej:] = True # for match_x in range(self.num_group * num_points): # if match_x % group_bbox_kpt == 0: # attn_mask_infere[:, :, match_x, kpt_index] = False # attn_mask_infere = attn_mask_infere.flatten(0, 1) attn_mask2 = attn_mask2.flatten(0, 1) attn_mask3 = attn_mask3.flatten(0, 1) return None, None, None, attn_mask2, attn_mask3, None # targets, dn_scalar, noise_scale = dn_args device = targets[0]['boxes'].device bs = len(targets) dn_number = self.dn_number # 100 dn_box_noise_scale = self.dn_box_noise_scale # 0.4 dn_label_noise_ratio = self.dn_label_noise_ratio # 0.5 # gather gt boxes and labels gt_boxes = [t['boxes'] for t in targets] gt_labels = [t['labels'] for t in targets] gt_keypoints = [t['keypoints'] for t in targets] # repeat them def get_indices_for_repeat(now_num, target_num, device='cuda'): """ Input: - now_num: int - target_num: int Output: - indices: tensor[target_num] """ out_indice = [] base_indice = torch.arange(now_num).to(device) multiplier = target_num // now_num out_indice.append(base_indice.repeat(multiplier)) residue = target_num % now_num out_indice.append(base_indice[torch.randint(0, now_num, (residue, ), device=device)]) return torch.cat(out_indice) if self.dn_batch_gt_fuse: raise NotImplementedError gt_boxes_bsall = torch.cat(gt_boxes) # num_boxes, 4 gt_labels_bsall = torch.cat(gt_labels) num_gt_bsall = gt_boxes_bsall.shape[0] if num_gt_bsall > 0: indices = get_indices_for_repeat(num_gt_bsall, dn_number, device) gt_boxes_expand = gt_boxes_bsall[indices][None].repeat( bs, 1, 1) # bs, num_dn, 4 gt_labels_expand = gt_labels_bsall[indices][None].repeat( bs, 1) # bs, num_dn else: # all negative samples when no gt boxes gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device) gt_labels_expand = torch.ones( bs, dn_number, dtype=torch.int64, device=device) * int( self.num_classes) else: gt_boxes_expand = [] gt_labels_expand = [] gt_keypoints_expand = [] # here for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( zip(gt_boxes, gt_labels, gt_keypoints)): # idx -> batch id num_gt_i = gt_boxes_i.shape[0] # instance num if num_gt_i > 0: indices = get_indices_for_repeat(num_gt_i, dn_number, device) gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 gt_labels_expand_i = gt_labels_i[indices] # add smpl gt_keypoints_expand_i = gt_keypoint_i[indices] else: # all negative samples when no gt boxes gt_boxes_expand_i = torch.rand(dn_number, 4, device=device) gt_labels_expand_i = torch.ones( dn_number, dtype=torch.int64, device=device) * int( self.num_classes) gt_keypoints_expand_i = torch.rand(dn_number, self.num_body_points * 3, device=device) gt_boxes_expand.append(gt_boxes_expand_i) # add smpl gt_labels_expand.append(gt_labels_expand_i) gt_keypoints_expand.append(gt_keypoints_expand_i) gt_boxes_expand = torch.stack(gt_boxes_expand) gt_labels_expand = torch.stack(gt_labels_expand) gt_keypoints_expand = torch.stack(gt_keypoints_expand) knwon_boxes_expand = gt_boxes_expand.clone() knwon_labels_expand = gt_labels_expand.clone() # add noise if dn_label_noise_ratio > 0: prob = torch.rand_like(knwon_labels_expand.float()) chosen_indice = prob < dn_label_noise_ratio new_label = torch.randint_like( knwon_labels_expand[chosen_indice], 0, self.dn_labelbook_size) # randomly put a new one here knwon_labels_expand[chosen_indice] = new_label if dn_box_noise_scale > 0: diff = torch.zeros_like(knwon_boxes_expand) diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 diff[..., 2:] = knwon_boxes_expand[..., 2:] knwon_boxes_expand += torch.mul( (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), diff) * dn_box_noise_scale knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) input_query_label = self.label_enc(knwon_labels_expand) input_query_bbox = inverse_sigmoid(knwon_boxes_expand) # prepare mask if 'group2group' in self.dn_attn_mask_type_list: attn_mask = torch.zeros(bs, self.nheads, dn_number + self.num_queries, dn_number + self.num_queries, device=device, dtype=torch.bool) attn_mask[:, :, dn_number:, :dn_number] = True for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): # for batch num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask[idx, :, matchi, :si] = True if ei < dn_number: attn_mask[idx, :, matchi, ei:dn_number] = True attn_mask = attn_mask.flatten(0, 1) if 'group2group' in self.dn_attn_mask_type_list: # self.num_body_points = self.num_body_points +3 num_points = self.num_body_points + 4 attn_mask2 = torch.zeros( bs, self.nheads, dn_number + self.num_group * num_points, dn_number + self.num_group * num_points, device=device, dtype=torch.bool) attn_mask2[:, :, dn_number:, :dn_number] = True group_bbox_kpt = num_points # group_nobbox_kpt = self.num_body_points kpt_index = [x for x in range(self.num_group * num_points) if x % num_points in [ 0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3]] for matchj in range(self.num_group * num_points): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt if sj > 0: attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True if ej < self.num_group * num_points: attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True for match_x in range(self.num_group * num_points): if match_x % group_bbox_kpt in [0, self.num_body_points+1, self.num_body_points+2, self.num_body_points+3]: attn_mask2[:, :, dn_number:, dn_number:][:,:,match_x,kpt_index]=False for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask2[idx, :, matchi, :si] = True if ei < dn_number: attn_mask2[idx, :, matchi, ei:dn_number] = True attn_mask2 = attn_mask2.flatten(0, 1) if 'group2group' in self.dn_attn_mask_type_list: # self.num_body_points = self.num_body_points +3 num_points = self.num_whole_body_points + 4 attn_mask3 = torch.zeros( bs, self.nheads, dn_number + self.num_group * (num_points), dn_number + self.num_group * (num_points), device=device, dtype=torch.bool) attn_mask3[:, :, dn_number:, :dn_number] = True group_bbox_kpt = (num_points) # group_nobbox_kpt = self.num_body_points kpt_index = [ x for x in range(self.num_group * (num_points)) if x % (num_points) in [0, 1+self.num_body_points, 2+self.num_body_points+self.num_hand_points, 3+self.num_body_points+self.num_hand_points*2 ] ] for matchj in range(self.num_group * num_points): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt if sj > 0: attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True if ej < self.num_group * num_points: attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True for match_x in range(self.num_group * num_points): if match_x % group_bbox_kpt in [0, 1 + self.num_body_points, 2 + self.num_body_points + self.num_hand_points, 3 + self.num_body_points + self.num_hand_points * 2]: attn_mask3[:, :, dn_number:, dn_number:][:,:,match_x,kpt_index]=False for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask3[idx, :, matchi, :si] = True if ei < dn_number: attn_mask3[idx, :, matchi, ei:dn_number] = True attn_mask3 = attn_mask3.flatten(0, 1) mask_dict = { 'pad_size': dn_number, 'known_bboxs': gt_boxes_expand, 'known_labels': gt_labels_expand, 'known_keypoints': gt_keypoints_expand } return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict def dn_post_process2(self, outputs_class, outputs_coord, outputs_body_keypoints_list, mask_dict): if mask_dict and mask_dict['pad_size'] > 0: output_known_class = [ outputs_class_i[:, :mask_dict['pad_size'], :] for outputs_class_i in outputs_class ] output_known_coord = [ outputs_coord_i[:, :mask_dict['pad_size'], :] for outputs_coord_i in outputs_coord ] outputs_class = [ outputs_class_i[:, mask_dict['pad_size']:, :] for outputs_class_i in outputs_class ] outputs_coord = [ outputs_coord_i[:, mask_dict['pad_size']:, :] for outputs_coord_i in outputs_coord ] outputs_keypoint = outputs_body_keypoints_list mask_dict.update({ 'output_known_coord': output_known_coord, 'output_known_class': output_known_class }) return outputs_class, outputs_coord, outputs_keypoint def forward(self, data_batch: NestedTensor, targets: List = None): """The forward expects a NestedTensor, which consists of: - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels It returns a dict with the following elements: - "pred_logits": the classification logits (including no-object) for all queries. Shape= [batch_size x num_queries x num_classes] - "pred_boxes": The normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the unnormalized bounding box. - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of dictionnaries containing the two above keys for each decoder layer. """ if isinstance(data_batch, dict): samples, targets = self.prepare_targets(data_batch) # import pdb; pdb.set_trace() elif isinstance(data_batch, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(data_batch) else: samples = data_batch # print(samples.data['img'].shape) # exit() features, poss = self.backbone(samples) srcs = [] masks = [] for l, feat in enumerate(features): # len(features=3) src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) masks.append(mask) assert mask is not None if self.num_feature_levels > len(srcs): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) srcs.append(src) masks.append(mask) poss.append(pos_l) if self.dn_number > 0 or targets is not None: input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\ self.prepare_for_dn2(targets) else: assert targets is None input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, attn_mask2, attn_mask3) # update human boxes effective_dn_number = self.dn_number if self.training else 0 outputs_body_bbox_list = [] outputs_class = [] for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed, layer_hs) in enumerate( zip(reference[:-1], self.bbox_embed, self.class_embed, hs)): if dec_lid < self.num_box_decoder_layers: # human det layer_delta_unsig = layer_body_bbox_embed(layer_hs) layer_body_box_outputs_unsig = \ layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid() layer_cls = layer_cls_embed(layer_hs) # import mmcv # import cv2 # img = (data_batch['img'][0]*255).permute(1,2,0).int().detach().cpu().numpy() # bbox = (box_ops.box_cxcywh_to_xyxy(layer_body_box_outputs_unsig[0][0]).reshape(2,2).detach().cpu().numpy()*data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1,4) # img = mmcv.imshow_bboxes(img.copy(), bbox, show=False) # cv2.imwrite('test.png',img) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) outputs_class.append(layer_cls) elif dec_lid < self.num_box_decoder_layers + 2: bs = layer_ref_sig.shape[0] # dn body bbox layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( reference_before_sigmoid_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() # norm body bbox layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ :, 0::(self.num_body_points + 4), :] # norm content query reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][ :, 0::(self.num_body_points+ 4), :] # norm position query layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( reference_before_sigmoid_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() layer_body_box_outputs_unsig = torch.cat( (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) # classfication layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) else: bs = layer_ref_sig.shape[0] # dn body bbox layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( reference_before_sigmoid_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() # norm body bbox layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ :, 0::(self.num_whole_body_points + 4), :] # norm content query reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][ :, 0::(self.num_whole_body_points + 4), :] # norm position query layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( reference_before_sigmoid_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() layer_body_box_outputs_unsig = torch.cat( (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) # classfication layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) # 找query q_index = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[1] q_value = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)[0] # update hand and face boxes outputs_lhand_bbox_list = [] outputs_rhand_bbox_list = [] outputs_face_bbox_list = [] # update keypoints boxes outputs_body_keypoints_list = [] outputs_body_keypoints_hw = [] outputs_lhand_keypoints_list = [] outputs_lhand_keypoints_hw = [] outputs_rhand_keypoints_list = [] outputs_rhand_keypoints_hw = [] outputs_face_keypoints_list = [] outputs_face_keypoints_hw = [] outputs_smpl_pose_list = [] outputs_smpl_lhand_pose_list = [] outputs_smpl_rhand_pose_list = [] outputs_smpl_expr_list = [] outputs_smpl_jaw_pose_list = [] outputs_smpl_beta_list = [] outputs_smpl_cam_list = [] # outputs_smpl_cam_f_list = [] outputs_smpl_kp2d_list = [] outputs_smpl_kp3d_list = [] outputs_smpl_verts_list = [] body_kpt_index = [ x for x in range(self.num_group * (self.num_body_points + 4)) if x % (self.num_body_points + 4) in range(1,self.num_body_points+1) ] body_kpt_index_2 = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in range(1,self.num_body_points+1)) ] lhand_bbox_index = [ x for x in range(self.num_group * (self.num_body_points + 4)) if x % (self.num_body_points + 4) != 1 ] lhand_kpt_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in range( self.num_body_points+2, self.num_body_points+self.num_hand_points+2))] rhand_bbox_index = [ x for x in range(self.num_group * (self.num_body_points + 4)) if x % (self.num_body_points + 4) != 2 ] rhand_kpt_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in range( self.num_body_points+self.num_hand_points+3, self.num_body_points+self.num_hand_points*2+3)) ] face_bbox_index = [ x for x in range(self.num_group * (self.num_body_points + 4)) if x % (self.num_body_points + 4) != 3 ] face_kpt_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in range( self.num_body_points+self.num_hand_points*2+4, self.num_body_points+self.num_hand_points*2+self.num_face_points+4)) ] # smpl pose # body box, kps, lhand box body_index = list(range(0,self.num_body_points+2)) # rhand box and face box body_index.extend( [self.num_body_points + self.num_hand_points + 2, self.num_body_points + 2 * self.num_hand_points + 3] ) smpl_pose_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in body_index) ] # smpl lhand lhand_index = list(range(self.num_body_points+1, self.num_body_points+self.num_hand_points+3)) # body box lhand_index.insert(0, 0) smpl_lhand_pose_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in lhand_index)] # smpl rhand rhand_index = list(range(self.num_body_points + self.num_hand_points + 2, self.num_body_points + self.num_hand_points * 2 +3)) rhand_index.insert(0,self.num_body_points+1) rhand_index.insert(0,0) smpl_rhand_pose_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in rhand_index)] # smpl face face_index = list(range(self.num_body_points + self.num_hand_points * 2 + 3, self.num_body_points + self.num_hand_points * 2 + self.num_face_points + 4)) face_index.insert(0,0) smpl_face_pose_index = [ x for x in range(self.num_group * (self.num_whole_body_points + 4)) if (x % (self.num_whole_body_points + 4) in face_index)] for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)): if dec_lid < self.num_box_decoder_layers: assert isinstance(layer_hs, torch.Tensor) bs = layer_hs.shape[0] layer_body_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_body_points * 3)) # [-, 900, 42] outputs_body_keypoints_list.append(layer_body_kps_res) # lhand layer_lhand_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_lhand_bbox_list.append(layer_lhand_bbox_res) layer_lhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_lhand_keypoints_list.append(layer_lhand_kps_res) # rhand layer_rhand_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_rhand_bbox_list.append(layer_rhand_bbox_res) layer_rhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_rhand_keypoints_list.append(layer_rhand_kps_res) # face layer_face_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_face_bbox_list.append(layer_face_bbox_res) layer_face_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_face_points * 3)) # [-, 900, 42] outputs_face_keypoints_list.append(layer_face_kps_res) # smpl or smplx smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3)) smpl_rhand_pose = layer_hs.new_zeros( (bs, self.num_queries, 15 * 3)) smpl_lhand_pose = layer_hs.new_zeros( (bs, self.num_queries, 15 * 3)) smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10)) smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6)) smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10)) smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3)) # smpl_cam_f = layer_hs.new_zeros((bs, self.num_queries, 1)) # smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3)) smpl_kp3d = layer_hs.new_zeros( (bs, self.num_queries, self.num_body_points, 4)) outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) # outputs_smpl_cam_f_list.append(smpl_cam_f) # outputs_smpl_kp2d_list.append(smpl_kp2d) outputs_smpl_kp3d_list.append(smpl_kp3d) elif dec_lid < self.num_box_decoder_layers +2: bs = layer_ref_sig.shape[0] layer_hs_body_kpt = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(body_kpt_index, device=layer_hs.device)) # body kp2d delta_body_kp_xy_unsig = \ self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) layer_ref_sig_body_kpt = \ layer_ref_sig[:,effective_dn_number:, :].index_select(1,torch.tensor(body_kpt_index,device=layer_hs.device)) layer_outputs_unsig_body_keypoints = delta_body_kp_xy_unsig + inverse_sigmoid( layer_ref_sig_body_kpt[..., :2]) vis_xy_unsig = torch.ones_like( layer_outputs_unsig_body_keypoints, device=layer_outputs_unsig_body_keypoints.device) xyv = torch.cat((layer_outputs_unsig_body_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) xyv = xyv.sigmoid() # from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d # img =(data_batch['img'][0].permute(1,2,0)*255).int().cpu().numpy() # gt_kp2d = xyv[0,:17] # coco_kps = gt_kp2d[:,:2].reshape(17,2).detach().cpu().numpy() * data_batch['img_shape'].cpu().numpy()[0,None,None,::-1] # visualize_kp2d( # coco_kps, # output_path='.', # image_array=img.copy()[None], # data_source='coco', # overwrite=True) layer_res = xyv.reshape( (bs, self.num_group, self.num_body_points, 3)).flatten(2, 3) layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( bs, self.num_group, self.num_body_points, 2).flatten(2, 3) layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) outputs_body_keypoints_list.append(layer_res) outputs_body_keypoints_hw.append(layer_hw) # lhand bbox layer_hs_lhand_bbox = \ layer_hs[:, effective_dn_number:, :][:, (self.num_body_points + 1)::(self.num_body_points + 4), :] delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_sig_lhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + 1)::(self.num_body_points + 4), :].clone() layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) layer_lhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_lhand_keypoints_list.append(layer_lhand_kps_res) # rhand bbox layer_hs_rhand_bbox = \ layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 2)::(self.num_body_points + 4), :] delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_sig_rhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + 2)::(self.num_body_points + 4), :].clone() layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) # rhand kps layer_rhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_rhand_keypoints_list.append(layer_rhand_kps_res) # face bbox layer_hs_face_bbox = \ layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 3)::(self.num_body_points + 4), :] delta_face_bbox_xy_unsig = self.bbox_face_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_sig_face_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + 3)::(self.num_body_points + 4), :].clone() layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() outputs_face_bbox_list.append(layer_ref_sig_face_bbox) # face kps layer_face_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_face_points * 3)) # [-, 900, 42] outputs_face_keypoints_list.append(layer_face_kps_res) # smpl or smplx bs, _, feat_dim = layer_hs.shape smpl_feats = layer_hs[:, effective_dn_number:, :].reshape( bs, -1, feat_dim * (self.num_body_points + 4)) smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 1):: (self.num_body_points + 4), :].reshape( bs, -1, feat_dim) smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 2):: (self.num_body_points + 4), :].reshape( bs, -1, feat_dim) smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 3):: (self.num_body_points + 4), :].reshape( bs, -1, feat_dim) smpl_pose = self.smpl_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( bs, self.num_group, self.body_model_joint_num, 3, 3) smpl_lhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_rhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_jaw_pose = self.smpl_jaw_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( bs, self.num_group, 1, 3, 3) smpl_beta = self.smpl_beta_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) smpl_cam = self.smpl_cam_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) # smpl_cam_f = self.smpl_cam_f_embed[ # dec_lid - self.num_box_decoder_layers](smpl_feats) # zero # smpl_lhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) # smpl_rhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) # smpl_expr = layer_hs.new_zeros(bs, self.num_group, 10) smpl_expr = self.smpl_expr_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) # smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3) leye_pose = torch.zeros_like(smpl_jaw_pose) reye_pose = torch.zeros_like(smpl_jaw_pose) if self.body_model is not None: smpl_pose_ = rotmat_to_aa(smpl_pose) # smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) # smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) leye_pose_ = rotmat_to_aa(leye_pose) reye_pose_ = rotmat_to_aa(reye_pose) pred_output = self.body_model( betas=smpl_beta.reshape(-1, 10), body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), global_orient=smpl_pose_[:, :, 0].reshape( -1, 3).unsqueeze(1), left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), leye_pose=leye_pose_, reye_pose=reye_pose_, jaw_pose=smpl_jaw_pose_.reshape(-1, 3), # expression=smpl_expr.reshape(-1, 10), expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10) ) smpl_kp3d = pred_output['joints'].reshape( bs, self.num_group, -1, 3) smpl_verts = pred_output['vertices'].reshape( bs, self.num_group, -1, 3) # pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) # outputs_smpl_cam_f_list.append(smpl_cam_f) outputs_smpl_kp3d_list.append(smpl_kp3d) else: bs = layer_ref_sig.shape[0] layer_hs_body_kpt = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(body_kpt_index_2, device=layer_hs.device)) # body kp2d delta_body_kp_xy_unsig = \ self.pose_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) layer_ref_sig_body_kpt = \ layer_ref_sig[:,effective_dn_number:, :].index_select( 1,torch.tensor(body_kpt_index_2,device=layer_hs.device)) layer_outputs_unsig_body_keypoints = \ delta_body_kp_xy_unsig + inverse_sigmoid( layer_ref_sig_body_kpt[..., :2]) vis_xy_unsig = torch.ones_like( layer_outputs_unsig_body_keypoints, device=layer_outputs_unsig_body_keypoints.device) xyv = torch.cat((layer_outputs_unsig_body_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) xyv = xyv.sigmoid() layer_res = xyv.reshape( (bs, self.num_group, self.num_body_points, 3)).flatten(2, 3) layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( bs, self.num_group, self.num_body_points, 2).flatten(2, 3) layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) outputs_body_keypoints_list.append(layer_res) outputs_body_keypoints_hw.append(layer_hw) # lhand bbox layer_hs_lhand_bbox = \ layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :] delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_sig_lhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :].clone() layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) # lhand kps layer_hs_lhand_kps_res = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(lhand_kpt_index, device=layer_hs.device)) delta_lhand_kp_xy_unsig = \ self.pose_hand_embed[ dec_lid - self.num_hand_face_decoder_layers](layer_hs_lhand_kps_res) layer_ref_sig_lhand_kpt = \ layer_ref_sig[:,effective_dn_number:, :].index_select( 1,torch.tensor(lhand_kpt_index,device=layer_hs.device)) layer_outputs_unsig_lhand_keypoints = delta_lhand_kp_xy_unsig + inverse_sigmoid( layer_ref_sig_lhand_kpt[..., :2]) lhand_vis_xy_unsig = torch.ones_like( layer_outputs_unsig_lhand_keypoints, device=layer_outputs_unsig_lhand_keypoints.device) lhand_xyv = torch.cat((layer_outputs_unsig_lhand_keypoints, lhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) lhand_xyv = lhand_xyv.sigmoid() layer_lhand_kps_res = lhand_xyv.reshape( (bs, self.num_group, self.num_hand_points, 3)).flatten(2, 3) layer_lhand_hw = layer_ref_sig_lhand_kpt[..., 2:].reshape( bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) layer_lhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_lhand_kps_res) outputs_lhand_keypoints_list.append(layer_lhand_kps_res) outputs_lhand_keypoints_hw.append(layer_lhand_hw) # rhand bbox layer_hs_rhand_bbox = \ layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :] delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_sig_rhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :].clone() layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) # rhand kps layer_hs_rhand_kps_res = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(rhand_kpt_index, device=layer_hs.device)) delta_rhand_kp_xy_unsig = \ self.pose_hand_embed[ dec_lid - self.num_hand_face_decoder_layers](layer_hs_rhand_kps_res) layer_ref_sig_rhand_kpt = \ layer_ref_sig[:,effective_dn_number:, :].index_select( 1,torch.tensor(rhand_kpt_index,device=layer_hs.device)) layer_outputs_unsig_rhand_keypoints = delta_rhand_kp_xy_unsig + inverse_sigmoid( layer_ref_sig_rhand_kpt[..., :2]) rhand_vis_xy_unsig = torch.ones_like( layer_outputs_unsig_rhand_keypoints, device=layer_outputs_unsig_rhand_keypoints.device) rhand_xyv = torch.cat((layer_outputs_unsig_rhand_keypoints, rhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) rhand_xyv = rhand_xyv.sigmoid() layer_rhand_kps_res = rhand_xyv.reshape( (bs, self.num_group, self.num_hand_points, 3)).flatten(2, 3) layer_rhand_hw = layer_ref_sig_rhand_kpt[..., 2:].reshape( bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) layer_rhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_rhand_kps_res) outputs_rhand_keypoints_list.append(layer_rhand_kps_res) outputs_rhand_keypoints_hw.append(layer_rhand_hw) # face bbox layer_hs_face_bbox = \ layer_hs[:, effective_dn_number:, :][ :, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :] delta_face_bbox_xy_unsig = self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_sig_face_bbox = \ layer_ref_sig[:,effective_dn_number:, :][ :, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :].clone() layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() outputs_face_bbox_list.append(layer_ref_sig_face_bbox) # face kps layer_hs_face_kps_res = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(face_kpt_index, device=layer_hs.device)) delta_face_kp_xy_unsig = \ self.pose_face_embed[ dec_lid - self.num_hand_face_decoder_layers](layer_hs_face_kps_res) layer_ref_sig_face_kpt = \ layer_ref_sig[:,effective_dn_number:, :].index_select( 1,torch.tensor(face_kpt_index,device=layer_hs.device)) layer_outputs_unsig_face_keypoints = delta_face_kp_xy_unsig + inverse_sigmoid( layer_ref_sig_face_kpt[..., :2]) face_vis_xy_unsig = torch.ones_like( layer_outputs_unsig_face_keypoints, device=layer_outputs_unsig_face_keypoints.device) face_xyv = torch.cat((layer_outputs_unsig_face_keypoints, face_vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) face_xyv = face_xyv.sigmoid() layer_face_kps_res = face_xyv.reshape( (bs, self.num_group, self.num_face_points, 3)).flatten(2, 3) layer_face_hw = layer_ref_sig_face_kpt[..., 2:].reshape( bs, self.num_group, self.num_face_points, 2).flatten(2, 3) layer_face_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_face_kps_res) outputs_face_keypoints_list.append(layer_face_kps_res) outputs_face_keypoints_hw.append(layer_face_hw) # pdb.set_trace() bs, _, feat_dim = layer_hs.shape smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * (self.num_body_points + 4)) smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * (self.num_face_points + 2)) smpl_pose = self.smpl_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( bs, self.num_group, self.body_model_joint_num, 3, 3) smpl_lhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_rhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_expr = self.smpl_expr_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = self.smpl_jaw_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( bs, self.num_group, 1, 3, 3) smpl_beta = self.smpl_beta_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) smpl_cam = self.smpl_cam_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) # smpl_cam_f = self.smpl_cam_f_embed[ # dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) num_samples = smpl_beta.reshape(-1, 10).shape[0] device = smpl_beta.device leye_pose = torch.zeros_like(smpl_jaw_pose) reye_pose = torch.zeros_like(smpl_jaw_pose) if self.body_model is not None: # print(smpl_pose) # exit() smpl_pose_ = rotmat_to_aa(smpl_pose) smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) leye_pose_ = rotmat_to_aa(leye_pose) reye_pose_ = rotmat_to_aa(reye_pose) pred_output = self.body_model( betas=smpl_beta.reshape(-1, 10), body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), global_orient=smpl_pose_[:, :, 0].reshape( -1, 3).unsqueeze(1), left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), leye_pose=leye_pose_, reye_pose=reye_pose_, jaw_pose=smpl_jaw_pose_.reshape(-1, 3), expression=smpl_expr.reshape(-1, 10), # expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10), ) smpl_kp3d = pred_output['joints'].reshape( bs, self.num_group, -1, 3) smpl_verts = pred_output['vertices'].reshape( bs, self.num_group, -1, 3) # pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) # from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d # visualize_kp3d(smpl_kp3d[0,:100].detach().cpu().numpy(), # output_path='./figs/pred3d', # data_source='smplx_137') # import numpy as np # from pytorch3d.io import save_obj # save_obj( # '1.obj', # torch.tensor(pred_output['vertices'][0]), # torch.tensor(self.body_model.faces.astype(np.float))) # exit() outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) # outputs_smpl_cam_f_list.append(smpl_cam_f) outputs_smpl_kp3d_list.append(smpl_kp3d) if not self.training: outputs_smpl_verts_list.append(smpl_verts) dn_mask_dict = mask_dict if self.dn_number > 0 and dn_mask_dict is not None: outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list = self.dn_post_process2( outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list, dn_mask_dict) dn_class_input = dn_mask_dict['known_labels'] dn_bbox_input = dn_mask_dict['known_bboxs'] dn_class_pred = dn_mask_dict['output_known_class'] dn_bbox_pred = dn_mask_dict['output_known_coord'] for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate( zip(outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list)): assert _out_class.shape[1] == _out_bbox.shape[ 1] == _out_keypoint.shape[1] out = { 'pred_logits': outputs_class[-1], 'pred_boxes': outputs_body_bbox_list[-1], 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], 'pred_face_boxes': outputs_face_bbox_list[-1], 'pred_keypoints': outputs_body_keypoints_list[-1], 'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], 'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], 'pred_face_keypoints': outputs_face_keypoints_list[-1], 'pred_smpl_pose': outputs_smpl_pose_list[-1], 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], 'pred_smpl_expr': outputs_smpl_expr_list[-1], 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] 'pred_smpl_cam': outputs_smpl_cam_list[-1], # 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1] } if not self.training: full_pose = torch.cat((outputs_smpl_pose_list[-1], outputs_smpl_lhand_pose_list[-1], outputs_smpl_rhand_pose_list[-1], outputs_smpl_jaw_pose_list[-1]),dim=2) bs,num_q,_,_,_ = full_pose.shape full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3) out = { 'pred_logits': outputs_class[-1], 'pred_boxes': outputs_body_bbox_list[-1], 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], 'pred_face_boxes': outputs_face_bbox_list[-1], 'pred_keypoints': outputs_body_keypoints_list[-1], 'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], 'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], 'pred_face_keypoints': outputs_face_keypoints_list[-1], 'pred_smpl_pose': outputs_smpl_pose_list[-1], 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], 'pred_smpl_expr': outputs_smpl_expr_list[-1], 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] 'pred_smpl_cam': outputs_smpl_cam_list[-1], # 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1], 'pred_smpl_verts': outputs_smpl_verts_list[-1], 'pred_smpl_fullpose': full_pose } if self.dn_number > 0 and dn_mask_dict is not None: out.update({ 'dn_class_input': dn_class_input, 'dn_bbox_input': dn_bbox_input, 'dn_class_pred': dn_class_pred[-1], 'dn_bbox_pred': dn_bbox_pred[-1], 'num_tgt': dn_mask_dict['pad_size'] }) if self.aux_loss: out['aux_outputs'] = \ self._set_aux_loss( outputs_class, outputs_body_bbox_list, outputs_lhand_bbox_list, outputs_rhand_bbox_list, outputs_face_bbox_list, outputs_body_keypoints_list, outputs_lhand_keypoints_list, outputs_rhand_keypoints_list, outputs_face_keypoints_list, outputs_smpl_pose_list, outputs_smpl_rhand_pose_list, outputs_smpl_lhand_pose_list, outputs_smpl_jaw_pose_list, outputs_smpl_expr_list, outputs_smpl_beta_list, outputs_smpl_cam_list, # outputs_smpl_cam_f_list, outputs_smpl_kp3d_list ) # with key pred_logits, pred_bbox, pred_keypoints if self.dn_number > 0 and dn_mask_dict is not None: assert len(dn_class_pred[:-1]) == len( dn_bbox_pred[:-1]) == len(out['aux_outputs']) for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip( out['aux_outputs'], dn_class_pred, dn_bbox_pred): aux_out.update({ 'dn_class_input': dn_class_input, 'dn_bbox_input': dn_bbox_input, 'dn_class_pred': dn_class_pred_i, 'dn_bbox_pred': dn_bbox_pred_i, 'num_tgt': dn_mask_dict['pad_size'] }) # for encoder output if hs_enc is not None: interm_coord = ref_enc[-1] interm_class = self.transformer.enc_out_class_embed(hs_enc[-1]) interm_pose = torch.zeros_like(outputs_body_keypoints_list[0]) out['interm_outputs'] = { 'pred_logits': interm_class, 'pred_boxes': interm_coord, 'pred_keypoints': interm_pose } return out, targets, data_batch @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_body_coord, outputs_lhand_coord, outputs_rhand_coord, outputs_face_coord, outputs_body_keypoints, outputs_lhand_keypoints, outputs_rhand_keypoints, outputs_face_keypoints, outputs_smpl_pose, outputs_smpl_rhand_pose, outputs_smpl_lhand_pose, outputs_smpl_jaw_pose, outputs_smpl_expr, outputs_smpl_beta, outputs_smpl_cam, # outputs_smpl_cam_f, outputs_smpl_kp3d): return [{ 'pred_logits': a, 'pred_boxes': b, 'pred_lhand_boxes': c, 'pred_rhand_boxes': d, 'pred_face_boxes': e, 'pred_keypoints': f, 'pred_lhand_keypoints': g, 'pred_rhand_keypoints': h, 'pred_face_keypoints': i, 'pred_smpl_pose': j, 'pred_smpl_rhand_pose': k, 'pred_smpl_lhand_pose': l, 'pred_smpl_jaw_pose': m, 'pred_smpl_expr': n, 'pred_smpl_beta': o, 'pred_smpl_cam': p, # 'pred_smpl_cam_f': q, 'pred_smpl_kp3d': q } for a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q in zip( outputs_class[:-1], outputs_body_coord[:-1], outputs_lhand_coord[:-1], outputs_rhand_coord[:-1], outputs_face_coord[:-1], outputs_body_keypoints[:-1], outputs_lhand_keypoints[:-1], outputs_rhand_keypoints[:-1], outputs_face_keypoints[:-1], outputs_smpl_pose[:-1], outputs_smpl_rhand_pose[:-1], outputs_smpl_lhand_pose[:-1], outputs_smpl_jaw_pose[:-1], outputs_smpl_expr[:-1], outputs_smpl_beta[:-1], outputs_smpl_cam[:-1], outputs_smpl_kp3d[:-1])] def prepare_targets(self, data_batch): data_batch_coco = [] instance_dict = {} img_list = data_batch['img'].float() # input_img_h, input_img_w = data_batch['image_metas'][0]['batch_input_shape'] batch_size, _, input_img_h, input_img_w = img_list.shape device = img_list.device masks = torch.ones((batch_size, input_img_h, input_img_w), dtype=torch.bool, device=device) # cv2.imread(data_batch['img_metas'][img_id]['image_path']).shape for img_id in range(batch_size): img_h, img_w = data_batch['img_shape'][img_id] masks[img_id, :img_h, :img_w] = 0 if not self.inference: instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ data_batch['body_bbox_size'][img_id]],dim=-1) instance_face_bbox = torch.cat([data_batch['face_bbox_center'][img_id],\ data_batch['face_bbox_size'][img_id]],dim=-1) instance_lhand_bbox = torch.cat([data_batch['lhand_bbox_center'][img_id],\ data_batch['lhand_bbox_size'][img_id]],dim=-1) instance_rhand_bbox = torch.cat([data_batch['rhand_bbox_center'][img_id],\ data_batch['rhand_bbox_size'][img_id]],dim=-1) instance_kp2d = data_batch['joint_img'][img_id].clone().float() instance_kp2d_mask = data_batch['joint_trunc'][img_id].clone().float() instance_kp2d[:,:,2:] = instance_kp2d_mask body_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'coco', approximate=True) lhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_lhand', approximate=True) rhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_rhand', approximate=True) face_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_face', approximate=True) # from util.vis_utils import show_bbox # show_bbox(img_list[img_id],instance_kp2d.cpu().numpy(),data_batch['bbox_xywh'][img_id].cpu().numpy) body_kp2d[:,:,0] = body_kp2d[:,:,0]/cfg.output_hm_shape[2] body_kp2d[:,:,1] = body_kp2d[:,:,1]/cfg.output_hm_shape[1] body_kp2d = torch.cat([body_kp2d[:,:,:2].flatten(1),body_kp2d[:,:,2]],dim=-1) lhand_kp2d[:,:,0] = lhand_kp2d[:,:,0]/cfg.output_hm_shape[2] lhand_kp2d[:,:,1] = lhand_kp2d[:,:,1]/cfg.output_hm_shape[1] lhand_kp2d = torch.cat([lhand_kp2d[:,:,:2].flatten(1),lhand_kp2d[:,:,2]],dim=-1) rhand_kp2d[:,:,0] = rhand_kp2d[:,:,0]/cfg.output_hm_shape[2] rhand_kp2d[:,:,1] = rhand_kp2d[:,:,1]/cfg.output_hm_shape[1] rhand_kp2d = torch.cat([rhand_kp2d[:,:,:2].flatten(1),rhand_kp2d[:,:,2]],dim=-1) face_kp2d[:,:,0] = face_kp2d[:,:,0]/cfg.output_hm_shape[2] face_kp2d[:,:,1] = face_kp2d[:,:,1]/cfg.output_hm_shape[1] face_kp2d = torch.cat([face_kp2d[:,:,:2].flatten(1),face_kp2d[:,:,2]],dim=-1) instance_dict = {} instance_dict['boxes'] = instance_body_bbox.float() instance_dict['face_boxes'] = instance_face_bbox.float() instance_dict['lhand_boxes'] = instance_lhand_bbox.float() instance_dict['rhand_boxes'] = instance_rhand_bbox.float() instance_dict['keypoints'] = body_kp2d.float() instance_dict['lhand_keypoints'] = lhand_kp2d.float() instance_dict['rhand_keypoints'] = rhand_kp2d.float() instance_dict['face_keypoints'] = face_kp2d.float() # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation instance_dict['area'] = instance_body_bbox[:, 2] * instance_body_bbox[:, 3] instance_dict['lhand_area'] = instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3] instance_dict['rhand_area'] = instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3] instance_dict['face_area'] = instance_face_bbox[:, 2] * instance_face_bbox[:, 3] instance_dict['labels'] = torch.ones(instance_body_bbox.shape[0], dtype=torch.long, device=device) data_batch_coco.append(instance_dict) # body_bbox = data_batch['body_bbox'][img_id].clone().float().reshape(-1, 4) # lhand_bbox = data_batch['lhand_bbox'][img_id].clone().float().reshape(-1, 4) # rhand_bbox = data_batch['rhand_bbox'][img_id].clone().float().reshape(-1, 4) # face_bbox = data_batch['face_bbox'][img_id].clone().float().reshape(-1, 4) # vis = False # if vis: # import mmcv # body_bbox[:, 0] *= img_w # body_bbox[:, 1] *= img_h # body_bbox[:, 2] *= img_w # body_bbox[:, 3] *= img_h # img = (data_batch['img'][img_id]*255).int().permute(1,2,0).cpu().detach().numpy() # img = mmcv.imshow_bboxes(img.copy(), face_bbox.cpu().numpy(), show=False) # cv2.imwrite('test.png', img) # instance_kp2d[:,:,0] = instance_kp2d[:,:,0]/cfg.output_hm_shape[2]*img_w # instance_kp2d[:,:,1] = instance_kp2d[:,:,1]/cfg.output_hm_shape[1]*img_h # from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d # img = (data_batch['img'][img_id]*255).int().permute(1,2,0).cpu().detach().numpy() # img1 = visualize_kp2d(instance_kp2d.cpu().detach().numpy(),image_array=img[None].copy(),return_array=True) # cv2.imwrite('test.png',img1[0]) # lhand_kp2d[:,:,0] = lhand_kp2d[:,:,0]/cfg.output_hm_shape[2]*img_w # lhand_kp2d[:,:,1] = lhand_kp2d[:,:,1]/cfg.output_hm_shape[1]*img_h # lhand_kp2d = convert_kps(lhand_kp2d, 'smplx_lhand', 'smplx', approximate=True)[0] else: instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ data_batch['body_bbox_size'][img_id]],dim=-1) instance_dict = {} # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation instance_dict['boxes'] = instance_body_bbox.float() data_batch_coco.append(instance_dict) input_img = NestedTensor(img_list, masks) return input_img, data_batch_coco def keypoints_to_scaled_bbox_bfh( self, keypoints, occ=None, body_scale=1.0, fh_scale=1.0, convention='smplx'): '''Obtain scaled bbox in xyxy format given keypoints Args: keypoints (np.ndarray): Keypoints scale (float): Bounding Box scale Returns: bbox_xyxy (np.ndarray): Bounding box in xyxy format ''' bboxs = [] # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3 if keypoints.ndim == 3: keypoints = keypoints[0] if keypoints.shape[-1] != 2: keypoints = keypoints[:, :2] for body_part in ['body', 'head', 'left_hand', 'right_hand']: if body_part == 'body': scale = body_scale kps = keypoints else: scale = fh_scale kp_id = get_keypoint_idxs_by_part(body_part, convention=convention) kps = keypoints[kp_id] if not occ is None: occ_p = occ[kp_id] if np.sum(occ_p) / len(kp_id) >= 0.1: conf = 0 # print(f'{body_part} occluded, occlusion: {np.sum(occ_p) / len(kp_id)}, skip') else: # print(f'{body_part} good, {np.sum(self_occ_p + occ_p) / len(kp_id)}') conf = 1 else: conf = 1 if body_part == 'body': conf = 1 xmin, ymin = np.amin(kps, axis=0) xmax, ymax = np.amax(kps, axis=0) width = (xmax - xmin) * scale height = (ymax - ymin) * scale x_center = 0.5 * (xmax + xmin) y_center = 0.5 * (ymax + ymin) xmin = x_center - 0.5 * width xmax = x_center + 0.5 * width ymin = y_center - 0.5 * height ymax = y_center + 0.5 * height bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32) bboxs.append(bbox) return bboxs @MODULE_BUILD_FUNCS.registe_with_name(module_name='aios_smplx') def build_aios_smplx(args, cfg): # pdb.set_trace() num_classes = args.num_classes # 2 device = torch.device(args.device) backbone = build_backbone(args) transformer = build_transformer(args) dn_labelbook_size = args.dn_labelbook_size dec_pred_class_embed_share = args.dec_pred_class_embed_share dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share if args.eval: body_model = args.body_model_test train = False else: body_model = args.body_model_train train = True model = AiOSSMPLX( backbone, transformer, num_classes=num_classes, # 2 num_queries=args.num_queries, # 900 aux_loss=True, iter_update=True, query_dim=4, random_refpoints_xy=args.random_refpoints_xy, # False fix_refpoints_hw=args.fix_refpoints_hw, # -1 num_feature_levels=args.num_feature_levels, # 4 nheads=args.nheads, # 8 dec_pred_class_embed_share=dec_pred_class_embed_share, # false dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, # False # two stage two_stage_type=args.two_stage_type, # box_share two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, # False two_stage_class_embed_share=args.two_stage_class_embed_share, # False dn_number=args.dn_number if args.use_dn else 0, # 100 dn_box_noise_scale=args.dn_box_noise_scale, # 0.4 dn_label_noise_ratio=args.dn_label_noise_ratio, # 0.5 dn_batch_gt_fuse=args.dn_batch_gt_fuse, # false dn_attn_mask_type_list=args.dn_attn_mask_type_list, dn_labelbook_size=dn_labelbook_size, # 100 cls_no_bias=args.cls_no_bias, # False num_group=args.num_group, # 100 num_body_points=args.num_body_points, # 17 num_hand_points=args.num_hand_points, # 17 num_face_points=args.num_face_points, # 17 num_box_decoder_layers=args.num_box_decoder_layers, # 2 num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, # smpl_convention=convention body_model=body_model, train=train, inference=args.inference) matcher = build_matcher(args) # prepare weight dict weight_dict = { 'loss_ce': args.cls_loss_coef, # 2 # bbox 'loss_body_bbox': args.body_bbox_loss_coef, # 5 'loss_rhand_bbox': args.rhand_bbox_loss_coef, # 5 'loss_lhand_bbox': args.lhand_bbox_loss_coef, # 5 'loss_face_bbox': args.face_bbox_loss_coef, # 5 # bbox giou 'loss_body_giou': args.body_giou_loss_coef, # 2 'loss_rhand_giou': args.rhand_giou_loss_coef, # 2 'loss_lhand_giou': args.lhand_giou_loss_coef, # 2 'loss_face_giou': args.face_giou_loss_coef, # 2 # 2d kp 'loss_keypoints': args.keypoints_loss_coef, # 10 'loss_rhand_keypoints': args.rhand_keypoints_loss_coef, # 10 'loss_lhand_keypoints': args.lhand_keypoints_loss_coef, # 10 'loss_face_keypoints': args.face_keypoints_loss_coef, # 10 # 2d kp oks 'loss_oks': args.oks_loss_coef, # 4 'loss_rhand_oks': args.rhand_oks_loss_coef, # 4 'loss_lhand_oks': args.lhand_oks_loss_coef, # 4 'loss_face_oks': args.face_oks_loss_coef, # 4 # smpl param 'loss_smpl_pose_root': args.smpl_pose_loss_root_coef, # 0 'loss_smpl_pose_body': args.smpl_pose_loss_body_coef, # 0 'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef, # 0 'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef, # 0 'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef, # 0 'loss_smpl_beta': args.smpl_beta_loss_coef, # 0 'loss_smpl_expr': args.smpl_expr_loss_coef, # smpl kp3d ra 'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef, # 0 'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef, # 0 'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef, # 0 'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef, # 0 # smpl kp3d 'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef, # 0 'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef, # 0 'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef, # 0 'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef, # 0 # smpl kp2d 'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef, # 0 'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef, # 0 'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef, # 0 'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef, # 0 # smpl kp2d ba 'loss_smpl_body_kp2d_ba': args.smpl_body_kp2d_ba_loss_coef, 'loss_smpl_face_kp2d_ba': args.smpl_face_kp2d_ba_loss_coef, 'loss_smpl_lhand_kp2d_ba': args.smpl_lhand_kp2d_ba_loss_coef, 'loss_smpl_rhand_kp2d_ba': args.smpl_rhand_kp2d_ba_loss_coef, } clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) if args.use_dn: weight_dict.update({ 'dn_loss_ce': args.dn_label_coef, # 0.3 'dn_loss_bbox': args.bbox_loss_coef * args.dn_bbox_coef, # 5 * 0.5 'dn_loss_giou': args.giou_loss_coef * args.dn_bbox_coef, # 2 * 0.5 }) clean_weight_dict = copy.deepcopy(weight_dict) if args.aux_loss: aux_weight_dict = {} for i in range(args.dec_layers - 1): # from 0 t 4 # ??? for k, v in clean_weight_dict.items(): if i < args.num_box_decoder_layers and ('keypoints' in k or 'oks' in k): continue if i < args.num_box_decoder_layers and k in [ 'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox', 'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou']: continue if i < args.num_hand_face_decoder_layers and k in [ 'loss_rhand_keypoints', 'loss_lhand_keypoints', 'loss_face_keypoints', 'loss_rhand_oks', 'loss_lhand_oks', 'loss_face_oks']: continue if i < args.num_box_decoder_layers and 'smpl' in k: continue aux_weight_dict.update({k + f'_{i}': v}) weight_dict.update(aux_weight_dict) if args.two_stage_type != 'no': interm_weight_dict = {} try: no_interm_box_loss = args.no_interm_box_loss except: no_interm_box_loss = False _coeff_weight_dict = { 'loss_ce': 1.0, # bbox 'loss_body_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_bbox': 1.0 if not no_interm_box_loss else 0.0, # bbox giou 'loss_body_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_giou': 1.0 if not no_interm_box_loss else 0.0, # 2d kp 'loss_keypoints': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_keypoints': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_keypoints': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_keypoints': 1.0 if not no_interm_box_loss else 0.0, # 2d oks 'loss_oks': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_oks': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_oks': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_oks': 1.0 if not no_interm_box_loss else 0.0, # smpl param 'loss_smpl_pose_root': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_body': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_lhand': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_rhand': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_jaw': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_beta': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_expr': 1.0 if not no_interm_box_loss else 0.0, # smpl kp3d ra 'loss_smpl_body_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, # smpl kp3d 'loss_smpl_body_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, # smpl kp2d 'loss_smpl_body_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp2d': 1.0 if not no_interm_box_loss else 0.0, # smpl kp2d ba 'loss_smpl_body_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, } try: interm_loss_coef = args.interm_loss_coef # 1 except: interm_loss_coef = 1.0 interm_weight_dict.update({ k + f'_interm': v * interm_loss_coef * _coeff_weight_dict[k] for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k }) weight_dict.update(interm_weight_dict) interm_weight_dict.update({ k + f'_query_expand': v * interm_loss_coef * _coeff_weight_dict[k] for k, v in clean_weight_dict_wo_dn.items() }) # ??? weight_dict.update(interm_weight_dict) losses = cfg.losses if args.dn_number > 0: losses += ['dn_label', 'dn_bbox'] losses += ['matching'] criterion = SetCriterion( num_classes, matcher=matcher, weight_dict=weight_dict, focal_alpha=args.focal_alpha, losses=losses, num_box_decoder_layers=args.num_box_decoder_layers, num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, num_body_points=args.num_body_points, num_hand_points=args.num_hand_points, num_face_points=args.num_face_points, ) criterion.to(device) if args.inference: postprocessors = { 'bbox': PostProcess_SMPLX_Multi_Infer( num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=args.num_body_points), } else: postprocessors = { 'bbox': PostProcess_SMPLX( num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=args.num_body_points), } postprocessors_aios = { 'bbox': PostProcess_aios(num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=args.num_body_points), } # criterion_smpl=build_architecture(cfg['smpl_loss']) return model, criterion, postprocessors, postprocessors_aios class AiOSSMPLX_Box(nn.Module): def __init__( self, backbone, transformer, num_classes, num_queries, aux_loss=False, iter_update=True, query_dim=4, random_refpoints_xy=False, fix_refpoints_hw=-1, num_feature_levels=1, nheads=8, two_stage_type='no', dec_pred_class_embed_share=False, dec_pred_bbox_embed_share=False, dec_pred_pose_embed_share=False, two_stage_class_embed_share=True, two_stage_bbox_embed_share=True, dn_number=100, dn_box_noise_scale=0.4, dn_label_noise_ratio=0.5, dn_batch_gt_fuse=False, dn_labelbook_size=100, dn_attn_mask_type_list=['group2group'], cls_no_bias=False, num_group=100, num_body_points=0, num_hand_points=0, num_face_points=0, num_box_decoder_layers=2, num_hand_face_decoder_layers=4, body_model=dict( type='smplx', keypoint_src='smplx', num_expression_coeffs=10, keypoint_dst='smplx_137', model_path='data/body_models/smplx', use_pca=False, use_face_contour=True), train=True, inference=False, focal_length=[5000., 5000.], camera_3d_size=2.5 ): super().__init__() self.num_queries = num_queries self.transformer = transformer self.num_classes = num_classes self.hidden_dim = hidden_dim = transformer.d_model self.num_feature_levels = num_feature_levels self.nheads = nheads self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim) self.num_body_points = num_body_points self.num_hand_points = num_hand_points self.num_face_points = num_face_points self.num_whole_body_points = num_body_points + 2*num_hand_points + num_face_points self.num_box_decoder_layers = num_box_decoder_layers self.num_hand_face_decoder_layers = num_hand_face_decoder_layers self.focal_length = focal_length self.camera_3d_size=camera_3d_size self.inference = inference if train: self.smpl_convention = 'smplx' else: self.smpl_convention = 'h36m' # setting query dim self.query_dim = query_dim assert query_dim == 4 self.random_refpoints_xy = random_refpoints_xy # False self.fix_refpoints_hw = fix_refpoints_hw # -1 # for dn training self.dn_number = dn_number self.dn_box_noise_scale = dn_box_noise_scale self.dn_label_noise_ratio = dn_label_noise_ratio self.dn_batch_gt_fuse = dn_batch_gt_fuse self.dn_labelbook_size = dn_labelbook_size self.dn_attn_mask_type_list = dn_attn_mask_type_list assert all([ i in ['match2dn', 'dn2dn', 'group2group'] for i in dn_attn_mask_type_list ]) assert not dn_batch_gt_fuse # build human body # if train: # self.body_model = build_body_model(body_model) if inference: body_model=dict( type='smplx', keypoint_src='smplx', num_expression_coeffs=10, num_betas=10, keypoint_dst='smplx', model_path='data/body_models/smplx', use_pca=False, use_face_contour=True) self.body_model = build_body_model(body_model) for param in self.body_model.parameters(): param.requires_grad = False # prepare input projection layers if num_feature_levels > 1: num_backbone_outs = len(backbone.num_channels) # 3 input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone.num_channels[_] input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), )) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, hidden_dim), )) in_channels = hidden_dim self.input_proj = nn.ModuleList(input_proj_list) else: assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!' self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), ) ]) self.backbone = backbone self.aux_loss = aux_loss self.box_pred_damping = box_pred_damping = None self.iter_update = iter_update assert iter_update, 'Why not iter_update?' # prepare pred layers self.dec_pred_class_embed_share = dec_pred_class_embed_share # false self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # false # 1.1 prepare class & box embed _class_embed = nn.Linear(hidden_dim, num_classes, bias=(not cls_no_bias)) if not cls_no_bias: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) _class_embed.bias.data = torch.ones(self.num_classes) * bias_value # 1.2 box embed layer list if dec_pred_class_embed_share: class_embed_layerlist = [ _class_embed for i in range(transformer.num_decoder_layers) ] else: class_embed_layerlist = [ copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers) ] ########################################################################### # body bbox + l/r hand box + face box ########################################################################### # 1.1 body bbox embed _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) # 1.2 body bbox embed layer list self.num_group = num_group if dec_pred_bbox_embed_share: box_body_embed_layerlist = [ _bbox_embed for i in range(transformer.num_decoder_layers) ] else: box_body_embed_layerlist = [ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers) ] # 2.1 lhand bbox embed _bbox_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3) # TODO: the out shape should be 2 not 4 nn.init.constant_(_bbox_hand_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_hand_embed.layers[-1].bias.data, 0) _bbox_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_hand_hw_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_hand_hw_embed.layers[-1].bias.data, 0) # 2.2 lhand bbox embed layer list if dec_pred_pose_embed_share: box_hand_embed_layerlist = \ [_bbox_hand_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers+1)] else: box_hand_embed_layerlist = [ copy.deepcopy(_bbox_hand_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1) ] if dec_pred_pose_embed_share: box_hand_hw_embed_layerlist = [ _bbox_hand_hw_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] else: box_hand_hw_embed_layerlist = [ copy.deepcopy(_bbox_hand_hw_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 4.1 face bbox embed _bbox_face_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_face_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_face_embed.layers[-1].bias.data, 0) _bbox_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) nn.init.constant_(_bbox_face_hw_embed.layers[-1].weight.data, 0) nn.init.constant_(_bbox_face_hw_embed.layers[-1].bias.data, 0) # 4.2 face bbox embed layer list if dec_pred_pose_embed_share: box_face_embed_layerlist = [ _bbox_face_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers + 1) ] else: box_face_embed_layerlist = [ copy.deepcopy(_bbox_face_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers + 1) ] if dec_pred_pose_embed_share: box_face_hw_embed_layerlist = [ _bbox_face_hw_embed for i in range( transformer.num_decoder_layers - num_box_decoder_layers)] else: box_face_hw_embed_layerlist = [ copy.deepcopy(_bbox_face_hw_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 1. smpl pose embed if body_model['type'].upper()=='SMPL': self.body_model_joint_num = 24 elif body_model['type'].upper()=='SMPLX': self.body_model_joint_num = 22 else: raise ValueError( f'Only supports SMPL or SMPLX, but get {body_model.type}') #TODO: _smpl_pose_embed = MLP(hidden_dim * 4, hidden_dim, self.body_model_joint_num * 6, 3) nn.init.constant_(_smpl_pose_embed.layers[-1].weight.data, 0) nn.init.constant_(_smpl_pose_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smpl_pose_embed_layerlist = [ _smpl_pose_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smpl_pose_embed_layerlist = [ copy.deepcopy(_smpl_pose_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 2. smpl betas embed _smpl_beta_embed = MLP(hidden_dim * 4, hidden_dim, 10, 3) nn.init.constant_(_smpl_beta_embed.layers[-1].weight.data, 0) nn.init.constant_(_smpl_beta_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smpl_beta_embed_layerlist = [ _smpl_beta_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smpl_beta_embed_layerlist = [ copy.deepcopy(_smpl_beta_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 3. smpl cam embed _cam_embed = MLP(hidden_dim * 4, hidden_dim, 3, 3) nn.init.constant_(_cam_embed.layers[-1].weight.data, 0) nn.init.constant_(_cam_embed.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: cam_embed_layerlist = [ _cam_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: cam_embed_layerlist = [ copy.deepcopy(_cam_embed) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] ########################################################################### # smplx body pose + hand pose + expression + betas + kp2d + kp3d + cam ########################################################################### # 2. smplx hand pose embed _smplx_hand_pose_embed_layer_2_3 = \ MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3) nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].bias.data, 0) _smplx_hand_pose_embed_layer_4_5 = \ MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3) nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_hand_pose_embed_layerlist = [ _smplx_hand_pose_embed_layer_2_3 if i<2 else _smplx_hand_pose_embed_layer_4_5 for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_hand_pose_embed_layerlist = [ copy.deepcopy(_smplx_hand_pose_embed_layer_2_3) if i<2 else copy.deepcopy(_smplx_hand_pose_embed_layer_4_5) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 3. smplx face expression _smplx_expression_embed_layer_2_3 = \ MLP(hidden_dim*2, hidden_dim, 10, 3) nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].bias.data, 0) _smplx_expression_embed_layer_4_5 = \ MLP(hidden_dim * 2, hidden_dim, 10, 3) nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_expression_embed_layerlist = [ _smplx_expression_embed_layer_2_3 if i<2 else _smplx_expression_embed_layer_4_5 for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_expression_embed_layerlist = [ copy.deepcopy(_smplx_expression_embed_layer_2_3) if i<2 else copy.deepcopy(_smplx_expression_embed_layer_4_5) for i in range(transformer.num_decoder_layers - num_box_decoder_layers) ] # 4. smplx jaw pose embed _smplx_jaw_embed_2_3 = MLP(hidden_dim * 2, hidden_dim, 6, 3) nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].weight.data, 0) nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].bias.data, 0) _smplx_jaw_embed_4_5 = MLP(hidden_dim * 2, hidden_dim, 6, 3) nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].weight.data, 0) nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].bias.data, 0) if dec_pred_bbox_embed_share: smplx_jaw_embed_layerlist = [ _smplx_jaw_embed_2_3 if i<2 else _smplx_jaw_embed_4_5 for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] else: smplx_jaw_embed_layerlist = [ copy.deepcopy(_smplx_jaw_embed_2_3) if i<2 else copy.deepcopy(_smplx_jaw_embed_4_5) for i in range( transformer.num_decoder_layers - num_box_decoder_layers) ] self.bbox_embed = nn.ModuleList(box_body_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist) self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.class_embed = self.class_embed # smpl self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist) self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist) self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist) # smplx lhand kp self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist) self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist) self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed # smplx face kp self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist) self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist) self.transformer.decoder.bbox_face_embed = self.bbox_face_embed self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed # smplx self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist) self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist) self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist) self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers self.transformer.decoder.num_body_points = num_body_points self.transformer.decoder.num_hand_points = num_hand_points self.transformer.decoder.num_face_points = num_face_points # two stage self.two_stage_type = two_stage_type assert two_stage_type in [ 'no', 'standard' ], 'unknown param {} of two_stage_type'.format(two_stage_type) if two_stage_type != 'no': if two_stage_bbox_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_bbox_embed = _bbox_embed else: self.transformer.enc_out_bbox_embed = copy.deepcopy( _bbox_embed) if two_stage_class_embed_share: assert dec_pred_class_embed_share and dec_pred_bbox_embed_share self.transformer.enc_out_class_embed = _class_embed else: self.transformer.enc_out_class_embed = copy.deepcopy( _class_embed) self.refpoint_embed = None self._reset_parameters() def get_camera_trans(self, cam_param, input_body_shape): # camera translation t_xy = cam_param[:, :2] gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive k_value = torch.FloatTensor( [ math.sqrt( self.focal_length[0] * self.focal_length[1] * self.camera_3d_size * self.camera_3d_size / (input_body_shape[0] * input_body_shape[1]) ) ] ).cuda().view(-1) t_z = k_value * gamma cam_trans = torch.cat((t_xy, t_z[:, None]), 1) return cam_trans def _reset_parameters(self): # init input_proj for proj in self.input_proj: nn.init.xavier_uniform_(proj[0].weight, gain=1) nn.init.constant_(proj[0].bias, 0) def prepare_for_dn2(self, targets): if not self.training: device = targets[0]['boxes'].device bs = len(targets) num_points = 4 attn_mask2 = torch.zeros( bs, self.nheads, self.num_group * 4, self.num_group * 4, device=device, dtype=torch.bool) group_bbox_kpt = 4 # body bbox index kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] for matchj in range(self.num_group * 4): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt # for each instance, they should associate with their query (body hand face) if sj > 0: attn_mask2[:, :, matchj, :sj] = True if ej < self.num_group * 4: attn_mask2[:, :, matchj, ej:] = True for match_x in range(self.num_group * 4): if match_x % group_bbox_kpt in [0, 1, 2, 3]: # each query (hand face body) should associate with all body query attn_mask2[:,:,match_x, kpt_index]=False num_points = 4 attn_mask3 = torch.zeros( bs, self.nheads, self.num_group * 4, self.num_group * 4, device=device, dtype=torch.bool) group_bbox_kpt = 4 kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] for matchj in range(self.num_group * 4): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt # for each instance, they should associate with their query (body hand face) if sj > 0: attn_mask3[:, :, matchj, :sj] = True if ej < self.num_group * 4: attn_mask3[:, :, matchj, ej:] = True for match_x in range(self.num_group * 4): if match_x % group_bbox_kpt in [0, 1, 2, 3]: # each query (hand face body) should associate with all body query attn_mask3[:, :, match_x, kpt_index] = False attn_mask2 = attn_mask2.flatten(0, 1) attn_mask3 = attn_mask3.flatten(0, 1) return None, None, None, attn_mask2, attn_mask3, None # targets, dn_scalar, noise_scale = dn_args device = targets[0]['boxes'].device bs = len(targets) dn_number = self.dn_number # 100 dn_box_noise_scale = self.dn_box_noise_scale # 0.4 dn_label_noise_ratio = self.dn_label_noise_ratio # 0.5 # gather gt boxes and labels gt_boxes = [t['boxes'] for t in targets] gt_labels = [t['labels'] for t in targets] gt_keypoints = [t['keypoints'] for t in targets] # repeat them def get_indices_for_repeat(now_num, target_num, device='cuda'): """ Input: - now_num: int - target_num: int Output: - indices: tensor[target_num] """ out_indice = [] base_indice = torch.arange(now_num).to(device) multiplier = target_num // now_num out_indice.append(base_indice.repeat(multiplier)) residue = target_num % now_num out_indice.append(base_indice[torch.randint(0, now_num, (residue, ), device=device)]) return torch.cat(out_indice) if self.dn_batch_gt_fuse: raise NotImplementedError gt_boxes_bsall = torch.cat(gt_boxes) # num_boxes, 4 gt_labels_bsall = torch.cat(gt_labels) num_gt_bsall = gt_boxes_bsall.shape[0] if num_gt_bsall > 0: indices = get_indices_for_repeat(num_gt_bsall, dn_number, device) gt_boxes_expand = gt_boxes_bsall[indices][None].repeat( bs, 1, 1) # bs, num_dn, 4 gt_labels_expand = gt_labels_bsall[indices][None].repeat( bs, 1) # bs, num_dn else: # all negative samples when no gt boxes gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device) gt_labels_expand = torch.ones( bs, dn_number, dtype=torch.int64, device=device) * int( self.num_classes) else: gt_boxes_expand = [] gt_labels_expand = [] gt_keypoints_expand = [] # here for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( zip(gt_boxes, gt_labels, gt_keypoints)): # idx -> batch id num_gt_i = gt_boxes_i.shape[0] # instance num if num_gt_i > 0: indices = get_indices_for_repeat(num_gt_i, dn_number, device) gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 gt_labels_expand_i = gt_labels_i[indices] # add smpl gt_keypoints_expand_i = gt_keypoint_i[indices] else: # all negative samples when no gt boxes gt_boxes_expand_i = torch.rand(dn_number, 4, device=device) gt_labels_expand_i = torch.ones( dn_number, dtype=torch.int64, device=device) * int( self.num_classes) gt_keypoints_expand_i = torch.rand(dn_number, self.num_body_points * 3, device=device) gt_boxes_expand.append(gt_boxes_expand_i) # add smpl gt_labels_expand.append(gt_labels_expand_i) gt_keypoints_expand.append(gt_keypoints_expand_i) gt_boxes_expand = torch.stack(gt_boxes_expand) gt_labels_expand = torch.stack(gt_labels_expand) gt_keypoints_expand = torch.stack(gt_keypoints_expand) knwon_boxes_expand = gt_boxes_expand.clone() knwon_labels_expand = gt_labels_expand.clone() # add noise if dn_label_noise_ratio > 0: prob = torch.rand_like(knwon_labels_expand.float()) chosen_indice = prob < dn_label_noise_ratio new_label = torch.randint_like( knwon_labels_expand[chosen_indice], 0, self.dn_labelbook_size) # randomly put a new one here knwon_labels_expand[chosen_indice] = new_label if dn_box_noise_scale > 0: diff = torch.zeros_like(knwon_boxes_expand) diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 diff[..., 2:] = knwon_boxes_expand[..., 2:] knwon_boxes_expand += torch.mul( (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), diff) * dn_box_noise_scale knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) input_query_label = self.label_enc(knwon_labels_expand) input_query_bbox = inverse_sigmoid(knwon_boxes_expand) # prepare mask if 'group2group' in self.dn_attn_mask_type_list: attn_mask = torch.zeros(bs, self.nheads, dn_number + self.num_queries, dn_number + self.num_queries, device=device, dtype=torch.bool) attn_mask[:, :, dn_number:, :dn_number] = True for idx, (gt_boxes_i, gt_labels_i) in enumerate( zip(gt_boxes, gt_labels)): # for batch num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask[idx, :, matchi, :si] = True if ei < dn_number: attn_mask[idx, :, matchi, ei:dn_number] = True attn_mask = attn_mask.flatten(0, 1) if 'group2group' in self.dn_attn_mask_type_list: # self.num_body_points = self.num_body_points +3 num_points = 4 attn_mask2 = torch.zeros( bs, self.nheads, dn_number + self.num_group * 4, dn_number + self.num_group * 4, device=device, dtype=torch.bool) attn_mask2[:, :, dn_number:, :dn_number] = True group_bbox_kpt = 4 for matchj in range(self.num_group * 4): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt # for each instance, they should associate their body, hand, and face bbox if sj > 0: attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True if ej < self.num_group * 4: attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True # body bbox index kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] for match_x in range(self.num_group * 4): if match_x % group_bbox_kpt in [0, 1, 2, 3]: # for each instance, they should associate their each query with # other instances' body query attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask2[idx, :, matchi, :si] = True if ei < dn_number: attn_mask2[idx, :, matchi, ei:dn_number] = True attn_mask2 = attn_mask2.flatten(0, 1) if 'group2group' in self.dn_attn_mask_type_list: num_points = 4 attn_mask3 = torch.zeros( bs, self.nheads, dn_number + self.num_group * 4, dn_number + self.num_group * 4, device=device, dtype=torch.bool) attn_mask3[:, :, dn_number:, :dn_number] = True group_bbox_kpt = 4 for matchj in range(self.num_group * 4): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt # for each instance, they should associate their body, hand, and face bbox if sj > 0: attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True if ej < self.num_group * 4: attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] for match_x in range(self.num_group * 4): if match_x % group_bbox_kpt in [0, 1, 2, 3]: # for each instance, they should associate their each query with # other instances' body query attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(dn_number): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask3[idx, :, matchi, :si] = True if ei < dn_number: attn_mask3[idx, :, matchi, ei:dn_number] = True attn_mask3 = attn_mask3.flatten(0, 1) mask_dict = { 'pad_size': dn_number, 'known_bboxs': gt_boxes_expand, 'known_labels': gt_labels_expand, 'known_keypoints': gt_keypoints_expand } return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict def dn_post_process2(self, outputs_class, outputs_coord, mask_dict): if mask_dict and mask_dict['pad_size'] > 0: output_known_class = [ outputs_class_i[:, :mask_dict['pad_size'], :] for outputs_class_i in outputs_class ] output_known_coord = [ outputs_coord_i[:, :mask_dict['pad_size'], :] for outputs_coord_i in outputs_coord ] outputs_class = [ outputs_class_i[:, mask_dict['pad_size']:, :] for outputs_class_i in outputs_class ] outputs_coord = [ outputs_coord_i[:, mask_dict['pad_size']:, :] for outputs_coord_i in outputs_coord ] mask_dict.update({ 'output_known_coord': output_known_coord, 'output_known_class': output_known_class }) return outputs_class, outputs_coord def forward(self, data_batch: NestedTensor, targets: List = None): """The forward expects a NestedTensor, which consists of: - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels It returns a dict with the following elements: - "pred_logits": the classification logits (including no-object) for all queries. Shape= [batch_size x num_queries x num_classes] - "pred_boxes": The normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These values are normalized in [0, 1], relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the unnormalized bounding box. - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of dictionnaries containing the two above keys for each decoder layer. """ if isinstance(data_batch, dict): samples, targets = self.prepare_targets(data_batch) # import pdb; pdb.set_trace() elif isinstance(data_batch, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(data_batch) else: samples = data_batch features, poss = self.backbone(samples) srcs = [] masks = [] for l, feat in enumerate(features): # len(features=3) src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) masks.append(mask) assert mask is not None if self.num_feature_levels > len(srcs): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) srcs.append(src) masks.append(mask) poss.append(pos_l) if self.dn_number > 0 or targets is not None: input_query_label, input_query_bbox, attn_mask,attn_mask2, attn_mask3, mask_dict =\ self.prepare_for_dn2(targets) else: assert targets is None input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, attn_mask2, attn_mask3) # update human boxes effective_dn_number = self.dn_number if self.training else 0 outputs_body_bbox_list = [] outputs_class = [] for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed, layer_hs) in enumerate( zip(reference[:-1], self.bbox_embed, self.class_embed, hs)): if dec_lid < self.num_box_decoder_layers: # human det layer_delta_unsig = layer_body_bbox_embed(layer_hs) layer_body_box_outputs_unsig = \ layer_delta_unsig + inverse_sigmoid(layer_ref_sig) layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid() layer_cls = layer_cls_embed(layer_hs) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) outputs_class.append(layer_cls) elif dec_lid < self.num_box_decoder_layers + 2: bs = layer_ref_sig.shape[0] # dn body bbox layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( reference_before_sigmoid_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() # norm body bbox layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ :, 0::(self.num_body_points + 4), :] # norm content query reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][ :, 0::(self.num_body_points+ 4), :] # norm position query layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( reference_before_sigmoid_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() layer_body_box_outputs_unsig = torch.cat( (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) # classfication layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) else: bs = layer_ref_sig.shape[0] # dn body bbox layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( reference_before_sigmoid_body_bbox_dn) layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() # norm body bbox layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ :, 0::(self.num_whole_body_points + 4), :] # norm content query reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][ :, 0::(self.num_whole_body_points + 4), :] # norm position query layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( reference_before_sigmoid_body_bbox_norm) layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() layer_body_box_outputs_unsig = torch.cat( (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) # classfication layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_body_bbox_list.append(layer_body_box_outputs_unsig) # update hand and face boxes outputs_lhand_bbox_list = [] outputs_rhand_bbox_list = [] outputs_face_bbox_list = [] # update keypoints boxes outputs_body_keypoints_list = [] outputs_body_keypoints_hw = [] outputs_lhand_keypoints_list = [] outputs_lhand_keypoints_hw = [] outputs_rhand_keypoints_list = [] outputs_rhand_keypoints_hw = [] outputs_face_keypoints_list = [] outputs_face_keypoints_hw = [] outputs_smpl_pose_list = [] outputs_smpl_lhand_pose_list = [] outputs_smpl_rhand_pose_list = [] outputs_smpl_expr_list = [] outputs_smpl_jaw_pose_list = [] outputs_smpl_beta_list = [] outputs_smpl_cam_list = [] outputs_smpl_kp2d_list = [] outputs_smpl_kp3d_list = [] outputs_smpl_verts_list = [] # smpl pose # body box, kps, lhand box body_index = [0, 1, 2, 3] smpl_pose_index = [ x for x in range(self.num_group * 4) if (x % 4 in body_index)] # smpl lhand lhand_index = [0, 1] smpl_lhand_pose_index = [ x for x in range(self.num_group * 4) if (x % 4 in lhand_index)] # smpl rhand rhand_index = [0, 2] smpl_rhand_pose_index = [ x for x in range(self.num_group * 4) if (x % 4 in rhand_index)] # smpl face face_index = [0, 3] smpl_face_pose_index = [ x for x in range(self.num_group * 4) if (x % 4 in face_index)] for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)): if dec_lid < self.num_box_decoder_layers: assert isinstance(layer_hs, torch.Tensor) bs = layer_hs.shape[0] layer_body_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_body_points * 3)) # [-, 900, 42] outputs_body_keypoints_list.append(layer_body_kps_res) # lhand layer_lhand_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_lhand_bbox_list.append(layer_lhand_bbox_res) layer_lhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_lhand_keypoints_list.append(layer_lhand_kps_res) # rhand layer_rhand_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_rhand_bbox_list.append(layer_rhand_bbox_res) layer_rhand_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_hand_points * 3)) # [-, 900, 42] outputs_rhand_keypoints_list.append(layer_rhand_kps_res) # face layer_face_bbox_res = layer_hs.new_zeros( (bs, self.num_queries, 4)) # [-, 900, 42] outputs_face_bbox_list.append(layer_face_bbox_res) layer_face_kps_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_face_points * 3)) # [-, 900, 42] outputs_face_keypoints_list.append(layer_face_kps_res) # smpl or smplx smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3)) smpl_rhand_pose = layer_hs.new_zeros( (bs, self.num_queries, 15 * 3)) smpl_lhand_pose = layer_hs.new_zeros( (bs, self.num_queries, 15 * 3)) smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10)) smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6)) smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10)) smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3)) # smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3)) smpl_kp3d = layer_hs.new_zeros( (bs, self.num_queries, self.num_body_points, 4)) outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) # outputs_smpl_kp2d_list.append(smpl_kp2d) outputs_smpl_kp3d_list.append(smpl_kp3d) elif dec_lid < self.num_box_decoder_layers +2: bs = layer_ref_sig.shape[0] # lhand bbox layer_hs_lhand_bbox = \ layer_hs[:, effective_dn_number:, :][:, 1::4, :] delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_sig_lhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone() layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) # rhand bbox layer_hs_rhand_bbox = \ layer_hs[:, effective_dn_number:, :][:, 2::4, :] delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_sig_rhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone() layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) # face bbox layer_hs_face_bbox = \ layer_hs[:, effective_dn_number:, :][:, 3::4, :] delta_face_bbox_xy_unsig = self.bbox_face_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_sig_face_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone() layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() outputs_face_bbox_list.append(layer_ref_sig_face_bbox) # smpl or smplx bs, _, feat_dim = layer_hs.shape smpl_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 4) smpl_lhand_pose_feats = \ layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_pose = self.smpl_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( bs, self.num_group, self.body_model_joint_num, 3, 3) smpl_lhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_rhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_jaw_pose = self.smpl_jaw_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( bs, self.num_group, 1, 3, 3) smpl_beta = self.smpl_beta_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) smpl_cam = self.smpl_cam_embed[ dec_lid - self.num_box_decoder_layers](smpl_feats) smpl_expr = self.smpl_expr_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) # smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3) leye_pose = torch.zeros_like(smpl_jaw_pose) reye_pose = torch.zeros_like(smpl_jaw_pose) if self.body_model is not None: smpl_pose_ = rotmat_to_aa(smpl_pose) # smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) # smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) leye_pose_ = rotmat_to_aa(leye_pose) reye_pose_ = rotmat_to_aa(reye_pose) pred_output = self.body_model( betas=smpl_beta.reshape(-1, 10), body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), global_orient=smpl_pose_[:, :, 0].reshape( -1, 3).unsqueeze(1), left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), leye_pose=leye_pose_, reye_pose=reye_pose_, jaw_pose=smpl_jaw_pose_.reshape(-1, 3), # expression=smpl_expr.reshape(-1, 10), expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10) ) smpl_kp3d = pred_output['joints'].reshape( bs, self.num_group, -1, 3) smpl_verts = pred_output['vertices'].reshape( bs, self.num_group, -1, 3) # pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) outputs_smpl_kp3d_list.append(smpl_kp3d) else: bs = layer_ref_sig.shape[0] # lhand bbox layer_hs_lhand_bbox = \ layer_hs[:, effective_dn_number:, :][:, 1::4, :] delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_sig_lhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone() layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) layer_ref_unsig_lhand_bbox[..., :2] +=delta_lhand_bbox_xy_unsig[..., :2] layer_ref_unsig_lhand_bbox[..., 2:] +=delta_lhand_bbox_hw_unsig layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) # rhand bbox layer_hs_rhand_bbox = \ layer_hs[:, effective_dn_number:, :][:, 2::4, :] delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_sig_rhand_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone() layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) layer_ref_unsig_rhand_bbox[..., :2] +=delta_rhand_bbox_xy_unsig[..., :2] layer_ref_unsig_rhand_bbox[..., 2:] +=delta_rhand_bbox_hw_unsig layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) # face bbox layer_hs_face_bbox = \ layer_hs[:, effective_dn_number:, :][:, 3::4, :] delta_face_bbox_xy_unsig = \ self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_sig_face_bbox = \ layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone() layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) layer_ref_unsig_face_bbox[..., :2] +=delta_face_bbox_xy_unsig[..., :2] layer_ref_unsig_face_bbox[..., 2:] +=delta_face_bbox_hw_unsig layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() outputs_face_bbox_list.append(layer_ref_sig_face_bbox) bs, _, feat_dim = layer_hs.shape smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 4) smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) ).reshape(bs, -1, feat_dim * 2) smpl_pose = self.smpl_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( bs, self.num_group, self.body_model_joint_num, 3, 3) smpl_lhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_rhand_pose = self.smpl_hand_pose_embed[ dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( -1, 6)).reshape(bs, self.num_group, 15, 3, 3) smpl_expr = self.smpl_expr_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = self.smpl_jaw_embed[ dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( bs, self.num_group, 1, 3, 3) smpl_beta = self.smpl_beta_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) smpl_cam = self.smpl_cam_embed[ dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) num_samples = smpl_beta.reshape(-1, 10).shape[0] device = smpl_beta.device leye_pose = torch.zeros_like(smpl_jaw_pose) reye_pose = torch.zeros_like(smpl_jaw_pose) if self.body_model is not None: smpl_pose_ = rotmat_to_aa(smpl_pose) smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) leye_pose_ = rotmat_to_aa(leye_pose) reye_pose_ = rotmat_to_aa(reye_pose) pred_output = self.body_model( betas=smpl_beta.reshape(-1, 10), body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), global_orient=smpl_pose_[:, :, 0].reshape( -1, 3).unsqueeze(1), left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), leye_pose=leye_pose_, reye_pose=reye_pose_, jaw_pose=smpl_jaw_pose_.reshape(-1, 3), expression=smpl_expr.reshape(-1, 10), # expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10), ) smpl_kp3d = pred_output['joints'].reshape( bs, self.num_group, -1, 3) smpl_verts = pred_output['vertices'].reshape( bs, self.num_group, -1, 3) outputs_smpl_pose_list.append(smpl_pose) outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) outputs_smpl_expr_list.append(smpl_expr) outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) outputs_smpl_beta_list.append(smpl_beta) outputs_smpl_cam_list.append(smpl_cam) outputs_smpl_kp3d_list.append(smpl_kp3d) if not self.training: outputs_smpl_verts_list.append(smpl_verts) dn_mask_dict = mask_dict if self.dn_number > 0 and dn_mask_dict is not None: outputs_class, outputs_body_bbox_list = self.dn_post_process2( outputs_class, outputs_body_bbox_list, dn_mask_dict) dn_class_input = dn_mask_dict['known_labels'] dn_bbox_input = dn_mask_dict['known_bboxs'] dn_class_pred = dn_mask_dict['output_known_class'] dn_bbox_pred = dn_mask_dict['output_known_coord'] for idx, (_out_class, _out_bbox) in enumerate(zip(outputs_class, outputs_body_bbox_list)): assert _out_class.shape[1] == _out_bbox.shape[1] out = { 'pred_logits': outputs_class[-1], 'pred_boxes': outputs_body_bbox_list[-1], 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], 'pred_face_boxes': outputs_face_bbox_list[-1], 'pred_smpl_pose': outputs_smpl_pose_list[-1], 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], 'pred_smpl_expr': outputs_smpl_expr_list[-1], 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] 'pred_smpl_cam': outputs_smpl_cam_list[-1], 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1] } if not self.training: full_pose = torch.cat((outputs_smpl_pose_list[-1], outputs_smpl_lhand_pose_list[-1], outputs_smpl_rhand_pose_list[-1], outputs_smpl_jaw_pose_list[-1]),dim=2) bs,num_q,_,_,_ = full_pose.shape full_pose = rotmat_to_aa(full_pose).reshape(bs,num_q,53*3) out = { 'pred_logits': outputs_class[-1], 'pred_boxes': outputs_body_bbox_list[-1], 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], 'pred_face_boxes': outputs_face_bbox_list[-1], 'pred_smpl_pose': outputs_smpl_pose_list[-1], 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], 'pred_smpl_expr': outputs_smpl_expr_list[-1], 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] 'pred_smpl_cam': outputs_smpl_cam_list[-1], 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1], 'pred_smpl_verts': outputs_smpl_verts_list[-1], 'pred_smpl_fullpose': full_pose } if self.dn_number > 0 and dn_mask_dict is not None: out.update({ 'dn_class_input': dn_class_input, 'dn_bbox_input': dn_bbox_input, 'dn_class_pred': dn_class_pred[-1], 'dn_bbox_pred': dn_bbox_pred[-1], 'num_tgt': dn_mask_dict['pad_size'] }) if self.aux_loss: out['aux_outputs'] = \ self._set_aux_loss( outputs_class, outputs_body_bbox_list, outputs_lhand_bbox_list, outputs_rhand_bbox_list, outputs_face_bbox_list, outputs_smpl_pose_list, outputs_smpl_rhand_pose_list, outputs_smpl_lhand_pose_list, outputs_smpl_jaw_pose_list, outputs_smpl_expr_list, outputs_smpl_beta_list, outputs_smpl_cam_list, outputs_smpl_kp3d_list ) # with key pred_logits, pred_bbox, pred_keypoints if self.dn_number > 0 and dn_mask_dict is not None: assert len(dn_class_pred[:-1]) == len( dn_bbox_pred[:-1]) == len(out['aux_outputs']) for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip( out['aux_outputs'], dn_class_pred, dn_bbox_pred): aux_out.update({ 'dn_class_input': dn_class_input, 'dn_bbox_input': dn_bbox_input, 'dn_class_pred': dn_class_pred_i, 'dn_bbox_pred': dn_bbox_pred_i, 'num_tgt': dn_mask_dict['pad_size'] }) # for encoder output if hs_enc is not None: interm_coord = ref_enc[-1] interm_class = self.transformer.enc_out_class_embed(hs_enc[-1]) interm_pose = torch.zeros_like(outputs_body_keypoints_list[0]) out['interm_outputs'] = { 'pred_logits': interm_class, 'pred_boxes': interm_coord, 'pred_keypoints': interm_pose } return out, targets, data_batch @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_body_coord, outputs_lhand_coord, outputs_rhand_coord, outputs_face_coord, outputs_smpl_pose, outputs_smpl_rhand_pose, outputs_smpl_lhand_pose, outputs_smpl_jaw_pose, outputs_smpl_expr, outputs_smpl_beta, outputs_smpl_cam, outputs_smpl_kp3d): return [{ 'pred_logits': a, 'pred_boxes': b, 'pred_lhand_boxes': c, 'pred_rhand_boxes': d, 'pred_face_boxes': e, 'pred_smpl_pose': j, 'pred_smpl_rhand_pose': k, 'pred_smpl_lhand_pose': l, 'pred_smpl_jaw_pose': m, 'pred_smpl_expr': n, 'pred_smpl_beta': o, 'pred_smpl_cam': p, 'pred_smpl_kp3d': q } for a, b, c, d, e, j, k, l, m, n, o, p, q in zip( outputs_class[:-1], outputs_body_coord[:-1], outputs_lhand_coord[:-1], outputs_rhand_coord[:-1], outputs_face_coord[:-1], outputs_smpl_pose[:-1], outputs_smpl_rhand_pose[:-1], outputs_smpl_lhand_pose[:-1], outputs_smpl_jaw_pose[:-1], outputs_smpl_expr[:-1], outputs_smpl_beta[:-1], outputs_smpl_cam[:-1], outputs_smpl_kp3d[:-1])] def prepare_targets(self, data_batch): data_batch_coco = [] instance_dict = {} img_list = data_batch['img'].float() # input_img_h, input_img_w = data_batch['image_metas'][0]['batch_input_shape'] batch_size, _, input_img_h, input_img_w = img_list.shape device = img_list.device masks = torch.ones((batch_size, input_img_h, input_img_w), dtype=torch.bool, device=device) if self.num_body_points == 17: ed_convention = 'coco' elif self.num_body_points == 14: ed_convention = 'crowdpose' # cv2.imread(data_batch['img_metas'][img_id]['image_path']).shape for img_id in range(batch_size): img_h, img_w = data_batch['img_shape'][img_id] masks[img_id, :img_h, :img_w] = 0 if not self.inference: instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ data_batch['body_bbox_size'][img_id]],dim=-1) instance_face_bbox = torch.cat([data_batch['face_bbox_center'][img_id],\ data_batch['face_bbox_size'][img_id]],dim=-1) instance_lhand_bbox = torch.cat([data_batch['lhand_bbox_center'][img_id],\ data_batch['lhand_bbox_size'][img_id]],dim=-1) instance_rhand_bbox = torch.cat([data_batch['rhand_bbox_center'][img_id],\ data_batch['rhand_bbox_size'][img_id]],dim=-1) instance_kp2d = data_batch['joint_img'][img_id].clone().float() instance_kp2d_mask = data_batch['joint_trunc'][img_id].clone().float() instance_kp2d[:,:,2:] = instance_kp2d_mask body_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'coco', approximate=True) lhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_lhand', approximate=True) rhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_rhand', approximate=True) face_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_face', approximate=True) # from util.vis_utils import show_bbox # show_bbox(img_list[img_id],instance_kp2d.cpu().numpy(),data_batch['bbox_xywh'][img_id].cpu().numpy) body_kp2d[:,:,0] = body_kp2d[:,:,0]/cfg.output_hm_shape[2] body_kp2d[:,:,1] = body_kp2d[:,:,1]/cfg.output_hm_shape[1] body_kp2d = torch.cat([body_kp2d[:,:,:2].flatten(1),body_kp2d[:,:,2]],dim=-1) lhand_kp2d[:,:,0] = lhand_kp2d[:,:,0]/cfg.output_hm_shape[2] lhand_kp2d[:,:,1] = lhand_kp2d[:,:,1]/cfg.output_hm_shape[1] lhand_kp2d = torch.cat([lhand_kp2d[:,:,:2].flatten(1),lhand_kp2d[:,:,2]],dim=-1) rhand_kp2d[:,:,0] = rhand_kp2d[:,:,0]/cfg.output_hm_shape[2] rhand_kp2d[:,:,1] = rhand_kp2d[:,:,1]/cfg.output_hm_shape[1] rhand_kp2d = torch.cat([rhand_kp2d[:,:,:2].flatten(1),rhand_kp2d[:,:,2]],dim=-1) face_kp2d[:,:,0] = face_kp2d[:,:,0]/cfg.output_hm_shape[2] face_kp2d[:,:,1] = face_kp2d[:,:,1]/cfg.output_hm_shape[1] face_kp2d = torch.cat([face_kp2d[:,:,:2].flatten(1),face_kp2d[:,:,2]],dim=-1) instance_dict = {} instance_dict['boxes'] = instance_body_bbox.float() instance_dict['face_boxes'] = instance_face_bbox.float() instance_dict['lhand_boxes'] = instance_lhand_bbox.float() instance_dict['rhand_boxes'] = instance_rhand_bbox.float() instance_dict['keypoints'] = body_kp2d.float() instance_dict['lhand_keypoints'] = lhand_kp2d.float() instance_dict['rhand_keypoints'] = rhand_kp2d.float() instance_dict['face_keypoints'] = face_kp2d.float() # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation instance_dict['area'] = instance_body_bbox[:, 2] * instance_body_bbox[:, 3] instance_dict['lhand_area'] = instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3] instance_dict['rhand_area'] = instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3] instance_dict['face_area'] = instance_face_bbox[:, 2] * instance_face_bbox[:, 3] instance_dict['labels'] = torch.ones(instance_body_bbox.shape[0], dtype=torch.long, device=device) data_batch_coco.append(instance_dict) else: instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ data_batch['body_bbox_size'][img_id]],dim=-1) instance_dict = {} # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation instance_dict['boxes'] = instance_body_bbox.float() data_batch_coco.append(instance_dict) input_img = NestedTensor(img_list, masks) return input_img, data_batch_coco def keypoints_to_scaled_bbox_bfh( self, keypoints, occ=None, body_scale=1.0, fh_scale=1.0, convention='smplx'): '''Obtain scaled bbox in xyxy format given keypoints Args: keypoints (np.ndarray): Keypoints scale (float): Bounding Box scale Returns: bbox_xyxy (np.ndarray): Bounding box in xyxy format ''' bboxs = [] # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3 if keypoints.ndim == 3: keypoints = keypoints[0] if keypoints.shape[-1] != 2: keypoints = keypoints[:, :2] for body_part in ['body', 'head', 'left_hand', 'right_hand']: if body_part == 'body': scale = body_scale kps = keypoints else: scale = fh_scale kp_id = get_keypoint_idxs_by_part(body_part, convention=convention) kps = keypoints[kp_id] if not occ is None: occ_p = occ[kp_id] if np.sum(occ_p) / len(kp_id) >= 0.1: conf = 0 # print(f'{body_part} occluded, occlusion: {np.sum(occ_p) / len(kp_id)}, skip') else: # print(f'{body_part} good, {np.sum(self_occ_p + occ_p) / len(kp_id)}') conf = 1 else: conf = 1 if body_part == 'body': conf = 1 xmin, ymin = np.amin(kps, axis=0) xmax, ymax = np.amax(kps, axis=0) width = (xmax - xmin) * scale height = (ymax - ymin) * scale x_center = 0.5 * (xmax + xmin) y_center = 0.5 * (ymax + ymin) xmin = x_center - 0.5 * width xmax = x_center + 0.5 * width ymin = y_center - 0.5 * height ymax = y_center + 0.5 * height bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32) bboxs.append(bbox) return bboxs @MODULE_BUILD_FUNCS.registe_with_name(module_name='aios_smplx_box') def build_aios_smplx_box(args, cfg): # pdb.set_trace() num_classes = args.num_classes # 2 device = torch.device(args.device) backbone = build_backbone(args) transformer = build_transformer(args) dn_labelbook_size = args.dn_labelbook_size dec_pred_class_embed_share = args.dec_pred_class_embed_share dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share if args.eval: body_model = args.body_model_test train = False else: body_model = args.body_model_train train = True model = AiOSSMPLX_Box( backbone, transformer, num_classes=num_classes, # 2 num_queries=args.num_queries, # 900 aux_loss=True, iter_update=True, query_dim=4, random_refpoints_xy=args.random_refpoints_xy, # False fix_refpoints_hw=args.fix_refpoints_hw, # -1 num_feature_levels=args.num_feature_levels, # 4 nheads=args.nheads, # 8 dec_pred_class_embed_share=dec_pred_class_embed_share, # false dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, # False # two stage two_stage_type=args.two_stage_type, # box_share two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, # False two_stage_class_embed_share=args.two_stage_class_embed_share, # False dn_number=args.dn_number if args.use_dn else 0, # 100 dn_box_noise_scale=args.dn_box_noise_scale, # 0.4 dn_label_noise_ratio=args.dn_label_noise_ratio, # 0.5 dn_batch_gt_fuse=args.dn_batch_gt_fuse, # false dn_attn_mask_type_list=args.dn_attn_mask_type_list, dn_labelbook_size=dn_labelbook_size, # 100 cls_no_bias=args.cls_no_bias, # False num_group=args.num_group, # 100 num_body_points=0, # 17 num_hand_points=0, # 17 num_face_points=0, # 17 num_box_decoder_layers=args.num_box_decoder_layers, # 2 num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, # smpl_convention=convention body_model=body_model, train=train, inference=args.inference) matcher = build_matcher(args) # prepare weight dict weight_dict = { 'loss_ce': args.cls_loss_coef, # 2 # bbox 'loss_body_bbox': args.body_bbox_loss_coef, # 5 'loss_rhand_bbox': args.rhand_bbox_loss_coef, # 5 'loss_lhand_bbox': args.lhand_bbox_loss_coef, # 5 'loss_face_bbox': args.face_bbox_loss_coef, # 5 # bbox giou 'loss_body_giou': args.body_giou_loss_coef, # 2 'loss_rhand_giou': args.rhand_giou_loss_coef, # 2 'loss_lhand_giou': args.lhand_giou_loss_coef, # 2 'loss_face_giou': args.face_giou_loss_coef, # 2 # smpl param 'loss_smpl_pose_root': args.smpl_pose_loss_root_coef, # 0 'loss_smpl_pose_body': args.smpl_pose_loss_body_coef, # 0 'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef, # 0 'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef, # 0 'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef, # 0 'loss_smpl_beta': args.smpl_beta_loss_coef, # 0 'loss_smpl_expr': args.smpl_expr_loss_coef, # smpl kp3d ra 'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef, # 0 'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef, # 0 'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef, # 0 'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef, # 0 # smpl kp3d 'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef, # 0 'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef, # 0 'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef, # 0 'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef, # 0 # smpl kp2d 'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef, # 0 'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef, # 0 'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef, # 0 'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef, # 0 } clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) if args.use_dn: weight_dict.update({ 'dn_loss_ce': args.dn_label_coef, # 0.3 'dn_loss_bbox': args.bbox_loss_coef * args.dn_bbox_coef, # 5 * 0.5 'dn_loss_giou': args.giou_loss_coef * args.dn_bbox_coef, # 2 * 0.5 }) clean_weight_dict = copy.deepcopy(weight_dict) if args.aux_loss: aux_weight_dict = {} for i in range(args.dec_layers - 1): # from 0 t 4 # ??? for k, v in clean_weight_dict.items(): if i < args.num_box_decoder_layers and ('keypoints' in k or 'oks' in k): continue if i < args.num_box_decoder_layers and k in [ 'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox', 'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou']: continue if i < args.num_hand_face_decoder_layers and k in [ 'loss_rhand_keypoints', 'loss_lhand_keypoints', 'loss_face_keypoints', 'loss_rhand_oks', 'loss_lhand_oks', 'loss_face_oks']: continue if i < args.num_box_decoder_layers and 'smpl' in k: continue aux_weight_dict.update({k + f'_{i}': v}) weight_dict.update(aux_weight_dict) if args.two_stage_type != 'no': interm_weight_dict = {} try: no_interm_box_loss = args.no_interm_box_loss except: no_interm_box_loss = False _coeff_weight_dict = { 'loss_ce': 1.0, # bbox 'loss_body_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_bbox': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_bbox': 1.0 if not no_interm_box_loss else 0.0, # bbox giou 'loss_body_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_rhand_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_lhand_giou': 1.0 if not no_interm_box_loss else 0.0, 'loss_face_giou': 1.0 if not no_interm_box_loss else 0.0, # smpl param 'loss_smpl_pose_root': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_body': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_lhand': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_rhand': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_pose_jaw': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_beta': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_expr': 1.0 if not no_interm_box_loss else 0.0, # smpl kp3d ra 'loss_smpl_body_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, # smpl kp3d 'loss_smpl_body_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, # smpl kp2d 'loss_smpl_body_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_lhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_rhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, 'loss_smpl_face_kp2d': 1.0 if not no_interm_box_loss else 0.0, } try: interm_loss_coef = args.interm_loss_coef # 1 except: interm_loss_coef = 1.0 interm_weight_dict.update({ k + f'_interm': v * interm_loss_coef * _coeff_weight_dict[k] for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k }) weight_dict.update(interm_weight_dict) interm_weight_dict.update({ k + f'_query_expand': v * interm_loss_coef * _coeff_weight_dict[k] for k, v in clean_weight_dict_wo_dn.items() }) # ??? weight_dict.update(interm_weight_dict) losses = cfg.losses if args.dn_number > 0: losses += ['dn_label', 'dn_bbox'] losses += ['matching'] criterion = SetCriterion_Box( num_classes, matcher=matcher, weight_dict=weight_dict, focal_alpha=args.focal_alpha, losses=losses, num_box_decoder_layers=args.num_box_decoder_layers, num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, num_body_points=0, num_hand_points=0, num_face_points=0, ) criterion.to(device) if args.inference: postprocessors = { 'bbox': PostProcess_SMPLX_Multi_Infer_Box( num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=0), } else: postprocessors = { 'bbox': PostProcess_SMPLX_Multi_Box( num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=0), } postprocessors_aios = { 'bbox': PostProcess_aios(num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold, num_body_points=0), } return model, criterion, postprocessors, postprocessors_aios