import math
import random
import copy
import os
from typing import Optional, List, Union
import warnings

from util.misc import inverse_sigmoid
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from .transformer_deformable import DeformableTransformerEncoderLayer, DeformableTransformerDecoderLayer
from .utils import gen_encoder_output_proposals, sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
from .ops.modules.ms_deform_attn import MSDeformAttn


class Transformer(nn.Module):

    def __init__(
            self,
            d_model=256,
            nhead=8,
            num_queries=300,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.0,
            activation='relu',
            normalize_before=False,
            return_intermediate_dec=False,
            query_dim=4,
            num_patterns=0,
            modulate_hw_attn=False,
            # for deformable encoder
            deformable_encoder=False,
            deformable_decoder=False,
            num_feature_levels=1,
            enc_n_points=4,
            dec_n_points=4,
            # init query
            learnable_tgt_init=False,
            random_refpoints_xy=False,
            # two stage
            two_stage_type='no',
            two_stage_learn_wh=False,
            two_stage_keep_all_tokens=False,
            # evo of #anchors
            dec_layer_number=None,
            rm_self_attn_layers=None,
            # for detach
            rm_detach=None,
            decoder_sa_type='sa',
            module_seq=['sa', 'ca', 'ffn'],
            # for pose
            embed_init_tgt=False,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_group=100):
        super().__init__()
        self.num_feature_levels = num_feature_levels  # 4
        self.num_encoder_layers = num_encoder_layers  # 6
        self.num_decoder_layers = num_decoder_layers  # 6
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens  # False
        self.num_queries = num_queries  # 900
        self.random_refpoints_xy = random_refpoints_xy  # False
        assert query_dim == 4
        if num_feature_levels > 1:
            assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'
        self.decoder_sa_type = decoder_sa_type  # sa
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']

        # choose encoder layer type
        if deformable_encoder:
            encoder_layer = DeformableTransformerEncoderLayer(
                d_model, dim_feedforward, dropout, activation,
                num_feature_levels, nhead, enc_n_points)
        else:
            raise NotImplementedError
            # unreachable: non-deformable encoder kept for reference
            encoder_layer = TransformerEncoderLayer(d_model, nhead,
                                                    dim_feedforward, dropout,
                                                    activation,
                                                    normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            encoder_norm,
            d_model=d_model,
            num_queries=num_queries,
            deformable_encoder=deformable_encoder,
            two_stage_type=two_stage_type)

        # choose decoder layer type
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(
                d_model, dim_feedforward, dropout, activation,
                num_feature_levels, nhead, dec_n_points,
                decoder_sa_type=decoder_sa_type, module_seq=module_seq)
        else:
            raise NotImplementedError
            # unreachable: non-deformable decoder kept for reference
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, dropout, activation,
                normalize_before, num_feature_levels=num_feature_levels)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            modulate_hw_attn=modulate_hw_attn,
            num_feature_levels=num_feature_levels,
            deformable_decoder=deformable_decoder,
            dec_layer_number=dec_layer_number,
            num_body_points=num_body_points,
            num_hand_points=num_hand_points,
            num_face_points=num_face_points,
            num_box_decoder_layers=num_box_decoder_layers,
            num_hand_face_decoder_layers=num_hand_face_decoder_layers,
            num_group=num_group,
            num_dn=num_group,
        )

        self.d_model = d_model
        self.nhead = nhead  # 8
        self.dec_layers = num_decoder_layers  # 6
        self.num_queries = num_queries  # useful for single stage model only
        self.num_patterns = num_patterns  # 0
        if not isinstance(num_patterns, int):
            warnings.warn('num_patterns should be int but {}'.format(
                type(num_patterns)))
            self.num_patterns = 0
        if self.num_patterns > 0:
            assert two_stage_type == 'no'
            self.patterns = nn.Embedding(self.num_patterns, d_model)

        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                self.level_embed = nn.Parameter(
                    torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init  # True
        assert learnable_tgt_init, 'why not learnable_tgt_init'
        self.embed_init_tgt = embed_init_tgt  # False
        if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in [
            'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            # anchor selection at the output of encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        if two_stage_type in ['early', 'combine']:
            # anchor selection at the output of backbone
            self.enc_output_backbone = nn.Linear(d_model, d_model)
            self.enc_output_norm_backbone = nn.LayerNorm(d_model)
        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # init self.refpoint_embed
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_pose_embed = None

        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[0] == num_queries, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
            else:
                assert dec_layer_number[0] == num_queries * num_patterns, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'

        self._reset_parameters()

        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            # assert len(rm_self_attn_layers) == num_decoder_layers
            print('Removing the self-attn in {} decoder layers'.format(
                rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            nn.init.constant_(self.two_stage_wh_embedding.weight,
                              math.log(0.05 / (1 - 0.05)))
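    # Worked example for get_valid_ratio below, assuming the usual DETR
    # convention that True marks padded pixels: for a (bs, 32, 64) mask whose
    # real content fills the top-left 30x40 region, valid_H=30 and valid_W=40,
    # so the method returns [40/64, 30/32] = [0.625, 0.9375] per sample,
    # i.e. the (w, h) fraction of the canvas that holds real content.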
    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)
        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2])
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    # srcs: multi-level features; refpoint_embed: initial (unsigmoided) reference boxes
    def forward(self,
                srcs,
                masks,
                refpoint_embed,
                pos_embeds,
                tgt,
                attn_mask=None,
                attn_mask2=None,
                attn_mask3=None):
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):  # for each feature level
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
                    1, 1, -1)  # level_embed[lvl]: [256]
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
                                          1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes,
                                         dtype=torch.long,
                                         device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # two stage
        if self.two_stage_type in ['early', 'combine']:
            output_memory, output_proposals = gen_encoder_output_proposals(
                src_flatten, mask_flatten, spatial_shapes)
            output_memory = self.enc_output_norm_backbone(
                self.enc_output_backbone(output_memory))
            # gather boxes
            topk = self.num_queries
            enc_outputs_class = self.encoder.class_embed[0](output_memory)
            enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0],
                                            topk,
                                            dim=1)[1]  # bs, nq
            enc_refpoint_embed = torch.gather(
                output_proposals, 1,
                enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            src_flatten = output_memory
        else:
            enc_topk_proposals = enc_refpoint_embed = None

        #########################################################
        # Begin Encoder
        #########################################################
        memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        #########################################################

        if self.two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            if self.two_stage_learn_wh:
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, input_hw)
            output_memory = self.enc_output_norm(self.enc_output(output_memory))
            enc_outputs_class_unselected = self.enc_out_class_embed(
                output_memory)  # [11531, 2] for swin
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries
            topk_proposals = torch.topk(
                enc_outputs_class_unselected.max(-1)[0], topk,
                dim=1)[1]  # bs, nq; coarse human query selection

            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()  # sigmoid

            # gather tgt
            tgt_undetach = torch.gather(
                output_memory, 1,
                topk_proposals.unsqueeze(-1).repeat(
                    1, 1, self.d_model))  # selected content query
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            else:
                tgt_ = tgt_undetach.detach()

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)  # [1000, 4]
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

        elif self.two_stage_type == 'early':
            refpoint_embed_undetach = self.enc_out_bbox_embed(
                enc_intermediate_output[-1]
            ) + enc_refpoint_embed  # unsigmoid, (bs, nq, 4)
            refpoint_embed = refpoint_embed_undetach.detach()
            tgt_undetach = enc_intermediate_output[-1]  # bs, nq, d_model
            tgt = tgt_undetach.detach()

        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, 4

            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

            # pat embed
            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat
            init_box_proposal = refpoint_embed_.sigmoid()
        else:
            raise NotImplementedError('unknown two_stage_type {}'.format(
                self.two_stage_type))

        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            tgt_mask3=attn_mask3)
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        elif self.two_stage_type in ['combine', 'early']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc+1, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers + 1
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, 4
        elif self.two_stage_type in ['enceachlayer', 'enclayer1']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc, bs, nq, 4
        else:
            hs_enc = ref_enc = None

        return hs, references, hs_enc, ref_enc, init_box_proposal


class TransformerEncoder(nn.Module):

    def __init__(
        self,
        encoder_layer,
        num_layers,
        norm=None,
        d_model=256,
        num_queries=300,
        deformable_encoder=False,
        enc_layer_share=False,
        enc_layer_dropout_prob=None,
        two_stage_type='no',
    ):
        super().__init__()
        # prepare layers
        if num_layers > 0:  # 6
            self.layers = _get_clones(
                encoder_layer, num_layers,
                layer_share=enc_layer_share)  # enc_layer_share: False
        else:
            self.layers = []
            del encoder_layer
        self.query_scale = None
        self.num_queries = num_queries  # 900
        self.deformable_encoder = deformable_encoder
        self.num_layers = num_layers  # 6
        self.norm = norm
        self.d_model = d_model

        self.enc_layer_dropout_prob = enc_layer_dropout_prob
        if enc_layer_dropout_prob is not None:
            assert isinstance(enc_layer_dropout_prob, list)
            assert len(enc_layer_dropout_prob) == num_layers
            for i in enc_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.two_stage_type = two_stage_type
        if two_stage_type in ['enceachlayer', 'enclayer1']:
            _proj_layer = nn.Linear(d_model, d_model)
            _norm_layer = nn.LayerNorm(d_model)
            if two_stage_type == 'enclayer1':
                self.enc_norm = nn.ModuleList([_norm_layer])
                self.enc_proj = nn.ModuleList([_proj_layer])
            else:
                self.enc_norm = nn.ModuleList(
                    [copy.deepcopy(_norm_layer) for i in range(num_layers - 1)])
                self.enc_proj = nn.ModuleList(
                    [copy.deepcopy(_proj_layer) for i in range(num_layers - 1)])

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
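    # Shape walkthrough for get_reference_points above (a reading of the code):
    # it emits the normalized (cx, cy) of every pixel of every level, giving a
    # tensor of shape (bs, sum(H_l*W_l), num_levels, 2). Each point is first
    # normalized by the valid (unpadded) extent of its own level, then the
    # trailing multiply by valid_ratios maps it into the valid region of every
    # level, so deformable attention samples inside real image content only.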
    def forward(self,
                src: Tensor,
                pos: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                key_padding_mask: Tensor,
                ref_token_index: Optional[Tensor] = None,
                ref_token_coord: Optional[Tensor] = None):
        """
        Input:
            - src: [bs, sum(hi*wi), 256]
            - pos: pos embed for src. [bs, sum(hi*wi), 256]
            - spatial_shapes: h, w of each level [num_level, 2]
            - level_start_index: [num_level] start index of each level in sum(hi*wi).
            - valid_ratios: [bs, num_level, 2]
            - key_padding_mask: [bs, sum(hi*wi)]
            - ref_token_index: bs, nq
            - ref_token_coord: bs, nq, 4
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Output:
            - output: [bs, sum(hi*wi), 256]
        """
        if self.two_stage_type in ['no', 'standard', 'enceachlayer', 'enclayer1']:
            assert ref_token_index is None

        output = src
        # preparation and reshape
        if self.num_layers > 0:
            if self.deformable_encoder:
                reference_points = self.get_reference_points(
                    spatial_shapes, valid_ratios, device=src.device)

        intermediate_output = []
        intermediate_ref = []
        if ref_token_index is not None:
            out_i = torch.gather(
                output, 1,
                ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
            intermediate_output.append(out_i)
            intermediate_ref.append(ref_token_coord)

        # main process
        for layer_id, layer in enumerate(self.layers):
            dropflag = False
            if self.enc_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.enc_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                if self.deformable_encoder:
                    output = layer(src=output,
                                   pos=pos,
                                   reference_points=reference_points,
                                   spatial_shapes=spatial_shapes,
                                   level_start_index=level_start_index,
                                   key_padding_mask=key_padding_mask)
                else:
                    output = layer(
                        src=output.transpose(0, 1),
                        pos=pos.transpose(0, 1),
                        key_padding_mask=key_padding_mask).transpose(0, 1)

            if ((layer_id == 0 and self.two_stage_type in ['enceachlayer', 'enclayer1'])
                    or (self.two_stage_type == 'enceachlayer')) \
                    and (layer_id != self.num_layers - 1):
                output_memory, output_proposals = gen_encoder_output_proposals(
                    output, key_padding_mask, spatial_shapes)
                output_memory = self.enc_norm[layer_id](
                    self.enc_proj[layer_id](output_memory))

                # gather boxes
                topk = self.num_queries
                enc_outputs_class = self.class_embed[layer_id](output_memory)
                ref_token_index = torch.topk(enc_outputs_class.max(-1)[0],
                                             topk,
                                             dim=1)[1]  # bs, nq
                ref_token_coord = torch.gather(
                    output_proposals, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
                output = output_memory

            # aux loss
            if (layer_id != self.num_layers - 1) and ref_token_index is not None:
                out_i = torch.gather(
                    output, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
                intermediate_output.append(out_i)
                intermediate_ref.append(ref_token_coord)

        if self.norm is not None:
            output = self.norm(output)

        if ref_token_index is not None:
            intermediate_output = torch.stack(
                intermediate_output)  # n_enc/n_enc-1, bs, \sum{hw}, d_model
            intermediate_ref = torch.stack(intermediate_ref)
        else:
            intermediate_output = intermediate_ref = None

        return output, intermediate_output, intermediate_ref


class TransformerDecoder(nn.Module):

    def __init__(
            self,
            decoder_layer,
            num_layers,
            norm=None,
            return_intermediate=False,
            d_model=256,
            query_dim=4,
            modulate_hw_attn=False,
            num_feature_levels=1,
            deformable_decoder=False,
            dec_layer_number=None,  # number of queries for each decoder layer
            dec_layer_share=False,
            dec_layer_dropout_prob=None,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_dn=100,
            num_group=100):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer,
                                      num_layers,
                                      layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate  # True
        assert return_intermediate, 'support return_intermediate only'
        self.query_dim = query_dim  # 4
        assert query_dim in [2, 4], \
            'query_dim should be 2/4 but {}'.format(query_dim)
        self.num_feature_levels = num_feature_levels  # 4

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
                                  2)  # 4//2 * 256, 256, 256, 2
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        self.num_body_points = num_body_points
        self.num_hand_points = num_hand_points
        self.num_face_points = num_face_points

        self.query_scale = None

        # aios kp
        self.bbox_embed = None
        self.class_embed = None
        self.pose_embed = None
        self.pose_hw_embed = None
        # smpl
        # self.smpl_pose_embed = None
        # self.smpl_beta_embed = None
        # self.smpl_cam_embed = None
        # smplx hand kp
        self.bbox_hand_embed = None
        self.bbox_hand_hw_embed = None
        self.pose_hand_embed = None
        self.pose_hand_hw_embed = None
        # smplx face kp
        self.bbox_face_embed = None
        self.bbox_face_hw_embed = None
        self.pose_face_embed = None
        self.pose_face_hw_embed = None
        # self.smplx_lhand_pose_embed = None
        # self.smplx_rhand_pose_embed = None
        # self.smplx_expression_embed = None
        # self.smplx_jaw_embed = None

        self.num_box_decoder_layers = num_box_decoder_layers  # 2
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder
        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers
            # assert dec_layer_number[0] ==

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            raise NotImplementedError
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.num_group = num_group
        self.rm_detach = None
        self.num_dn = num_dn

        # self.hw_body_kps = nn.Embedding(self.num_body_points, 2)
        self.hw = nn.Embedding(self.num_body_points, 2)
        self.keypoint_embed = nn.Embedding(self.num_body_points, d_model)

        # While only body keypoints are expanded, each group owns a block of
        # (num_body_points + 4) queries: offset 0 is the body box and the last
        # three offsets are the lhand/rhand/face boxes.
        self.body_kpt_index_1 = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) not in
            [0, (1 + self.num_body_points), (2 + self.num_body_points),
             (3 + self.num_body_points)]
        ]
        self.whole_body_points = \
            self.num_body_points + self.num_hand_points * 2 + self.num_face_points
        self.body_kpt_index_2 = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(1, self.num_body_points + 1)
        ]
        # Per-block offsets once hands and face are expanded
        # (defaults: 17 body, 10 per hand, 10 face points):
        #   [0]      body box
        #   [1, 18)  body kps
        #   [18]     lhand box
        #   [19, 29) lhand kps
        #   [29]     rhand box
        #   [30, 40) rhand kps
        #   [40]     face box
        #   [41, 51) face kps
        # (the first num_dn queries, e.g. [0, 100), are the dn boxes)
        self.lhand_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + 2,
                self.num_body_points + self.num_hand_points + 2)
        ]
        self.rhand_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points + 3,
                self.num_body_points + self.num_hand_points * 2 + 3)
        ]
        self.face_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points * 2 + 4,
                self.num_body_points + self.num_hand_points * 2 +
                self.num_face_points + 4)
        ]

        self.lhand_box_embed = nn.Embedding(1, d_model)
        self.rhand_box_embed = nn.Embedding(1, d_model)
        self.face_box_embed = nn.Embedding(1, d_model)
        self.hw_lhand_bbox = nn.Embedding(1, 2)
        self.hw_rhand_bbox = nn.Embedding(1, 2)
        self.hw_face_bbox = nn.Embedding(1, 2)
        self.hw_lhand_kps = nn.Embedding(self.num_hand_points, 2)
        self.hw_rhand_kps = nn.Embedding(self.num_hand_points, 2)
        self.hw_face_kps = nn.Embedding(self.num_face_points, 2)
        self.lhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
        self.rhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
        self.face_keypoint_embed = nn.Embedding(self.num_face_points, d_model)

    def forward(
            self,
            tgt,
            memory,
            tgt_mask: Optional[Tensor] = None,
            tgt_mask2: Optional[Tensor] = None,
            tgt_mask3: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
            # for memory
            level_start_index: Optional[Tensor] = None,  # num_levels
            spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
            valid_ratios: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]
        effect_num_dn = self.num_dn if self.training else 0
        inter_select_number = self.num_group

        for layer_id, layer in enumerate(self.layers):
            if self.deformable_decoder:
                if reference_points.shape[-1] == 4:
                    reference_points_input = reference_points[:, :, None] \
                        * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
                else:
                    assert reference_points.shape[-1] == 2
                    reference_points_input = \
                        reference_points[:, :, None] * valid_ratios[None, :]
                query_sine_embed = gen_sineembed_for_position(
                    reference_points_input[:, :, 0, :]
                )  # convert the positional query from box coords to a sine/cosine embedding
            else:
                query_sine_embed = gen_sineembed_for_position(
                    reference_points)  # nq, bs, 256*2
                reference_points_input = None

            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(
                output) if self.query_scale is not None else 1
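            # Positional-query pipeline, step by step (shapes assume the
            # defaults d_model=256 and query_dim=4):
            #   reference_points_input[:, :, 0, :] -> (nq, bs, 4) normalized cxcywh
            #   gen_sineembed_for_position(...)    -> (nq, bs, 512), 128 sine/cosine
            #                                         dims per box coordinate
            #   ref_point_head (MLP 512->256)      -> (nq, bs, 256)
            # query_scale stays None in this decoder, so pos_scale == 1.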
            query_pos = pos_scale * raw_query_pos
            if not self.deformable_decoder:
                query_sine_embed = query_sine_embed[
                    ..., :self.d_model] * self.query_pos_sine_scale(output)

            # modulated HW attention
            if not self.deformable_decoder and self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
                query_sine_embed[..., self.d_model // 2:] *= (
                    refHW_cond[..., 0] / reference_points[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (
                    refHW_cond[..., 1] / reference_points[..., 3]).unsqueeze(-1)

            dropflag = False
            if self.dec_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.dec_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                output = layer(
                    tgt=output,
                    tgt_query_pos=query_pos,
                    tgt_query_sine_embed=query_sine_embed,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    tgt_reference_points=reference_points_input,
                    memory=memory,  # encoder output, i.e. the content features of the encoder
                    memory_key_padding_mask=memory_key_padding_mask,
                    memory_level_start_index=level_start_index,
                    memory_spatial_shapes=spatial_shapes,
                    memory_pos=pos,  # positional query of the encoder
                    self_attn_mask=tgt_mask,
                    cross_attn_mask=memory_mask)
            intermediate.append(self.norm(output))

            # human update
            if layer_id < self.num_box_decoder_layers:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](
                    output)  # delta_x, delta_y, delta_w, delta_h
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                # update the positional query by adding the offset delta_unsig

            # kp query expansion
            if layer_id == self.num_box_decoder_layers - 1:
                dn_output = output[:effect_num_dn]  # [num_dn, bs, 256]
                dn_new_reference_points = new_reference_points[:effect_num_dn]  # [num_dn, bs, 4]
                class_unselected = self.class_embed[layer_id](
                    output)[effect_num_dn:]  # [900, bs, 2]
                topk_proposals = torch.topk(class_unselected.max(-1)[0],
                                            inter_select_number,
                                            dim=0)[1]  # top num_group humans
                # selected position queries
                new_reference_points_for_body_box = torch.gather(
                    new_reference_points[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
                # selected content queries
                new_output_for_body_box = torch.gather(
                    output[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
                bs = new_output_for_body_box.shape[1]

                # expand each per-human query to per-keypoint content queries
                new_output_for_body_keypoint = new_output_for_body_box[:, None, :, :] \
                    + self.keypoint_embed.weight[None, :, None, :]
                if self.num_body_points == 17:
                    delta_xy = self.pose_embed[-1](new_output_for_body_keypoint)[..., :2]
                else:
                    delta_xy = self.pose_embed[0](new_output_for_body_keypoint)[..., :2]
                body_keypoint_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                    delta_xy).sigmoid()  # [num_group, num_body_points, bs, 2]
                num_queries, _, bs, _ = body_keypoint_xy.shape
                body_keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze(
                    -2).repeat(num_queries, 1, bs, 1).sigmoid()
                body_keypoint_wh = body_keypoint_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_keypoint = torch.cat(
                    (body_keypoint_xy, body_keypoint_wh), dim=-1)
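                # Keypoint anchors are parameterized relative to their human box
                # (a reading of the code above, not extra machinery): xy is the
                # human box center moved by a predicted offset in logit space,
                # and wh is a learned per-keypoint fraction sigmoid(self.hw.weight)
                # of the human box wh, so every keypoint gets a full 4-d anchor.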
                # lhand box
                new_output_for_lhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.lhand_box_embed.weight[None, :, None, :]
                delta_lhand_box_xy = self.bbox_hand_embed[-1](
                    new_output_for_lhand_box)[..., :2]
                lhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                    delta_lhand_box_xy).sigmoid()  # [num_group, 1, bs, 2]
                num_queries, _, bs, _ = lhand_bbox_xy.shape
                lhand_bbox_wh_weight = self.hw_lhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                lhand_bbox_wh = lhand_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_lhand_bbox = torch.cat(
                    (lhand_bbox_xy, lhand_bbox_wh), dim=-1)

                # rhand box
                new_output_for_rhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.rhand_box_embed.weight[None, :, None, :]
                delta_rhand_box_xy = self.bbox_hand_embed[-1](
                    new_output_for_rhand_box)[..., :2]
                rhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                    delta_rhand_box_xy).sigmoid()  # [num_group, 1, bs, 2]
                num_queries, _, bs, _ = rhand_bbox_xy.shape
                rhand_bbox_wh_weight = self.hw_rhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                rhand_bbox_wh = rhand_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_rhand_bbox = torch.cat(
                    (rhand_bbox_xy, rhand_bbox_wh), dim=-1)

                # face box
                new_output_for_face_box = new_output_for_body_box[:, None, :, :] \
                    + self.face_box_embed.weight[None, :, None, :]
                delta_face_box_xy = self.bbox_face_embed[-1](
                    new_output_for_face_box)[..., :2]
                face_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                    delta_face_box_xy).sigmoid()  # [num_group, 1, bs, 2]
                num_queries, _, bs, _ = face_bbox_xy.shape
                face_bbox_wh_weight = self.hw_face_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                face_bbox_wh = face_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_face_box = torch.cat(
                    (face_bbox_xy, face_bbox_wh), dim=-1)

                output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     new_output_for_body_keypoint, new_output_for_lhand_box,
                     new_output_for_rhand_box, new_output_for_face_box),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box.unsqueeze(1),
                     new_reference_points_for_keypoint,
                     new_reference_points_for_lhand_bbox,
                     new_reference_points_for_rhand_bbox,
                     new_reference_points_for_face_box),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_new_reference_points, new_reference_points), dim=0)
                output = torch.cat((dn_output, output), dim=0)
                tgt_mask = tgt_mask2

            # human-to-keypoint / human-to-hand / human-to-face update
            if self.num_box_decoder_layers <= layer_id < self.num_box_decoder_layers + 2:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                reference_before_sigmoid_body_bbox_dn = \
                    reference_before_sigmoid[:effect_num_dn]
                reference_before_sigmoid_bbox_body_norm = \
                    reference_before_sigmoid[effect_num_dn:][0::(self.num_body_points + 4)]
                output_bbox_body_dn = output[:effect_num_dn]
                output_bbox_body_norm = output[effect_num_dn:][
                    0::(self.num_body_points + 4)]
                delta_unsig_bbox_body_dn = self.bbox_embed[layer_id](output_bbox_body_dn)
                delta_unsig_bbox_body_norm = self.bbox_embed[layer_id](output_bbox_body_norm)
                outputs_unsig_body_bbox_dn = \
                    delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn
                outputs_unsig_body_bbox_norm = \
                    delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm
                new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid()
                new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid()

                # body kps
                output_body_kpt = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.body_kpt_index_1,
                                    device=output.device))  # select kp content query
                delta_xy_body_unsig = self.pose_embed[
                    layer_id - self.num_box_decoder_layers](
                        output_body_kpt)  # offset of the kp box center
                outputs_body_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()  # select kp position query
                delta_hw_body_kp_unsig = self.pose_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_body_kpt)
                outputs_body_kp_unsig[..., :2] += delta_xy_body_unsig[..., :2]
                outputs_body_kp_unsig[..., 2:] += delta_hw_body_kp_unsig
                new_reference_points_for_body_keypoint = outputs_body_kp_unsig.sigmoid()
                bs = new_reference_points_for_body_box_norm.shape[1]

                # lhand box
                output_lhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)]
                delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 1)::(self.num_body_points + 4)].clone()
                delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig[..., :2] += delta_xy_lhand_bbox_unsig[..., :2]
                outputs_lhand_bbox_unsig[..., 2:] += delta_hw_lhand_bbox_unsig
                new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid()

                # rhand box
                output_rhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)]
                delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 2)::(self.num_body_points + 4)].clone()
                delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig[..., :2] += delta_xy_rhand_bbox_unsig[..., :2]
                outputs_rhand_bbox_unsig[..., 2:] += delta_hw_rhand_bbox_unsig
                new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid()

                # face box
                output_face_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)]
                delta_xy_face_bbox_unsig = self.bbox_face_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 3)::(self.num_body_points + 4)].clone()
                delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig[..., :2] += delta_xy_face_bbox_unsig[..., :2]
                outputs_face_bbox_unsig[..., 2:] += delta_hw_face_bbox_unsig
                new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid()

                new_reference_points_norm = torch.cat(
                    (new_reference_points_for_body_box_norm.unsqueeze(1),
                     new_reference_points_for_body_keypoint.view(
                         -1, self.num_body_points, bs, 4),
                     new_reference_points_for_lhand_box_norm.unsqueeze(1),
                     new_reference_points_for_rhand_box_norm.unsqueeze(1),
                     new_reference_points_for_face_box_norm.unsqueeze(1)),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box_dn,
                     new_reference_points_norm), dim=0)

            # hand/face box query expansion
            if layer_id == self.num_hand_face_decoder_layers - 1:
                dn_body_output = output[:effect_num_dn]
                dn_reference_points_body = new_reference_points[:effect_num_dn]

                # body box
                new_reference_points_for_body_box = \
                    new_reference_points[effect_num_dn:][0::(self.num_body_points + 4)]
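                # The strided selections here rely on the interleaved layout
                # built at the keypoint-expansion step: each of the num_group
                # humans owns a contiguous block of (num_body_points + 4)
                # queries, e.g. 21 with the default 17 body points, ordered as
                # [body box, 17 body kps, lhand box, rhand box, face box],
                # so step (num_body_points + 4) at offset 0 walks the body boxes.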
                new_output_for_body_box = output[effect_num_dn:][
                    0::(self.num_body_points + 4)]

                # body kp boxes
                new_output_body_for_body_keypoint = \
                    output[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()
                new_output_body_for_body_keypoint = new_output_body_for_body_keypoint.view(
                    self.num_group, self.num_body_points, bs, self.d_model)
                new_reference_points_for_body_keypoint = \
                    new_reference_points[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()
                new_reference_points_for_body_keypoint = \
                    new_reference_points_for_body_keypoint.view(
                        self.num_group, self.num_body_points, bs, 4)
                new_reference_points_body = \
                    torch.cat((new_reference_points_for_body_box.unsqueeze(1),
                               new_reference_points_for_body_keypoint), dim=1)
                new_body_output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     new_output_body_for_body_keypoint), dim=1)

                # lhand box content and position queries
                new_reference_points_for_lhand_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 1)::(self.num_body_points + 4)]
                new_output_for_lhand_box = output[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)]
                # lhand query expansion
                new_output_for_lhand_keypoint = new_output_for_lhand_box[:, None, :, :] \
                    + self.lhand_keypoint_embed.weight[None, :, None, :]
                # use the expanded lhand kp queries to regress the center
                # displacement relative to the lhand box
                delta_lhand_kp_xy = self.pose_hand_embed[-1](
                    new_output_for_lhand_keypoint)[..., :2]
                # absolute box center for each lhand kp box
                lhand_keypoint_xy = (inverse_sigmoid(
                    new_reference_points_for_lhand_box[..., :2][:, None]) +
                    delta_lhand_kp_xy).sigmoid()
                num_queries, _, bs, _ = lhand_keypoint_xy.shape
                lhand_keypoint_wh_weight = \
                    self.hw_lhand_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                lhand_keypoint_wh = lhand_keypoint_wh_weight * \
                    new_reference_points_for_lhand_box[..., 2:][:, None]
                new_reference_points_for_lhand_keypoint = torch.cat(
                    (lhand_keypoint_xy, lhand_keypoint_wh), dim=-1)
                new_reference_points_lhand = \
                    torch.cat((new_reference_points_for_lhand_box.unsqueeze(1),
                               new_reference_points_for_lhand_keypoint), dim=1)
                new_lhand_output = torch.cat(
                    (new_output_for_lhand_box.unsqueeze(1),
                     new_output_for_lhand_keypoint), dim=1)

                # rhand
                new_reference_points_for_rhand_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 2)::(self.num_body_points + 4)]
                new_output_for_rhand_box = output[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)]
                new_output_for_rhand_keypoint = new_output_for_rhand_box[:, None, :, :] \
                    + self.rhand_keypoint_embed.weight[None, :, None, :]
                delta_rhand_kp_xy = self.pose_hand_embed[-1](
                    new_output_for_rhand_keypoint)[..., :2]  # keep xy only, as for lhand
                rhand_keypoint_xy = (inverse_sigmoid(
                    new_reference_points_for_rhand_box[..., :2][:, None]) +
                    delta_rhand_kp_xy).sigmoid()
                num_queries, _, bs, _ = rhand_keypoint_xy.shape
                rhand_keypoint_wh_weight = \
                    self.hw_rhand_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                rhand_keypoint_wh = rhand_keypoint_wh_weight * \
                    new_reference_points_for_rhand_box[..., 2:][:, None]
                new_reference_points_for_rhand_keypoint = torch.cat(
                    (rhand_keypoint_xy, rhand_keypoint_wh), dim=-1)
                new_reference_points_rhand = \
                    torch.cat((new_reference_points_for_rhand_box.unsqueeze(1),
                               new_reference_points_for_rhand_keypoint), dim=1)
                new_rhand_output = torch.cat(
                    (new_output_for_rhand_box.unsqueeze(1),
                     new_output_for_rhand_keypoint), dim=1)

                # face
                new_reference_points_for_face_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 3)::(self.num_body_points + 4)]
                new_output_for_face_box = output[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)]
                new_output_for_face_keypoint = new_output_for_face_box[:, None, :, :] \
                    + self.face_keypoint_embed.weight[None, :, None, :]
                delta_face_kp_xy = self.pose_face_embed[-1](
                    new_output_for_face_keypoint)[..., :2]
                face_keypoint_xy = (inverse_sigmoid(
                    new_reference_points_for_face_box[..., :2][:, None]) +
                    delta_face_kp_xy).sigmoid()
                num_queries, _, bs, _ = face_keypoint_xy.shape
                face_keypoint_wh_weight = \
                    self.hw_face_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                face_keypoint_wh = face_keypoint_wh_weight * \
                    new_reference_points_for_face_box[..., 2:][:, None]
                new_reference_points_for_face_keypoint = torch.cat(
                    (face_keypoint_xy, face_keypoint_wh), dim=-1)
                new_reference_points_face = torch.cat(
                    (new_reference_points_for_face_box.unsqueeze(1),
                     new_reference_points_for_face_keypoint), dim=1)
                new_face_output = torch.cat(
                    (new_output_for_face_box.unsqueeze(1),
                     new_output_for_face_keypoint), dim=1)

                new_reference_points = torch.cat(
                    (new_reference_points_body, new_reference_points_lhand,
                     new_reference_points_rhand, new_reference_points_face),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_reference_points_body, new_reference_points), dim=0)
                output = torch.cat(
                    (new_body_output, new_lhand_output, new_rhand_output,
                     new_face_output), dim=1).flatten(0, 1)
                output = torch.cat((dn_body_output, output), dim=0)
                tgt_mask = tgt_mask3

            # whole-body update
            if layer_id >= self.num_hand_face_decoder_layers:
                reference_before_sigmoid = inverse_sigmoid(reference_points)

                # body box
                reference_before_sigmoid_body_bbox_dn = \
                    reference_before_sigmoid[:effect_num_dn]
                reference_before_sigmoid_bbox_body_norm = \
                    reference_before_sigmoid[effect_num_dn:][
                        0::(self.num_body_points + 2 * self.num_hand_points +
                            self.num_face_points + 4)]
                output_bbox_body_dn = output[:effect_num_dn]
                output_bbox_body_norm = output[effect_num_dn:][
                    0::(self.num_body_points + 2 * self.num_hand_points +
                        self.num_face_points + 4)]
                delta_unsig_bbox_body_dn = self.bbox_embed[layer_id](output_bbox_body_dn)
                delta_unsig_bbox_body_norm = self.bbox_embed[layer_id](output_bbox_body_norm)
                outputs_unsig_body_bbox_dn = \
                    delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn
                outputs_unsig_body_bbox_norm = \
                    delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm
                new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid()
                new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid()

                # body kps
                output_body_kpt = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.body_kpt_index_2,
                                    device=output.device))  # select kp content query
                delta_xy_body_unsig = self.pose_embed[
                    layer_id - self.num_box_decoder_layers](
                        output_body_kpt)  # offset of the kp box center
                outputs_body_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_2,
                                        device=output.device)).clone()  # select kp position query
                delta_hw_body_kp_unsig = self.pose_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_body_kpt)
                outputs_body_kp_unsig[..., :2] += delta_xy_body_unsig[..., :2]
                outputs_body_kp_unsig[..., 2:] += delta_hw_body_kp_unsig
                new_reference_points_for_body_keypoint = outputs_body_kp_unsig.sigmoid()
                bs = new_reference_points_for_body_box_norm.shape[1]
                new_reference_points_for_body_keypoint = \
                    new_reference_points_for_body_keypoint.view(
                        -1, self.num_body_points, bs, 4)
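                # Worked example of the whole-body stride (defaults: 17 body,
                # 10 per hand, 10 face points): each human block spans
                # 17 + 2*10 + 10 + 4 = 51 queries, with per-block offsets
                #   0 body box | 1-17 body kps | 18 lhand box | 19-28 lhand kps
                #   29 rhand box | 30-39 rhand kps | 40 face box | 41-50 face kps
                # which is exactly what the body/lhand/rhand/face index tables encode.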
                # lhand box
                output_lhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 1)::
                    (self.num_body_points + 2 * self.num_hand_points +
                     self.num_face_points + 4)]
                delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 1)::
                        (self.num_body_points + 2 * self.num_hand_points +
                         self.num_face_points + 4)].clone()
                delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig[..., :2] += delta_xy_lhand_bbox_unsig[..., :2]
                outputs_lhand_bbox_unsig[..., 2:] += delta_hw_lhand_bbox_unsig
                new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid()

                # lhand kps
                output_lhand_kpt_query = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.lhand_kpt_index,
                                    device=output.device))  # select kp content query
                delta_xy_lhand_kpt_unsig = self.pose_hand_embed[
                    layer_id - self.num_hand_face_decoder_layers](
                        output_lhand_kpt_query)  # offset of the kp box center
                outputs_lhand_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.lhand_kpt_index,
                                        device=output.device)).clone()  # select kp position query
                delta_hw_lhand_kp_unsig = self.pose_hand_hw_embed[
                    layer_id - self.num_hand_face_decoder_layers](output_lhand_kpt_query)
                outputs_lhand_kp_unsig[..., :2] += delta_xy_lhand_kpt_unsig[..., :2]
                outputs_lhand_kp_unsig[..., 2:] += delta_hw_lhand_kp_unsig
                new_reference_points_for_lhand_keypoint = outputs_lhand_kp_unsig.sigmoid()
                bs = new_reference_points_for_lhand_box_norm.shape[1]
                new_reference_points_for_lhand_keypoint = \
                    new_reference_points_for_lhand_keypoint.view(
                        -1, self.num_hand_points, bs, 4)

                # rhand box
                output_rhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + self.num_hand_points + 2)::
                    (self.num_body_points + 2 * self.num_hand_points +
                     self.num_face_points + 4)]
                delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + self.num_hand_points + 2)::
                        (self.num_body_points + 2 * self.num_hand_points +
                         self.num_face_points + 4)].clone()
                delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig[..., :2] += delta_xy_rhand_bbox_unsig[..., :2]
                outputs_rhand_bbox_unsig[..., 2:] += delta_hw_rhand_bbox_unsig
                new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid()
                # rhand kps
                output_rhand_kpt_query = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.rhand_kpt_index,
                                    device=output.device))  # select kp content query
                delta_xy_rhand_kpt_unsig = self.pose_hand_embed[
                    layer_id - self.num_hand_face_decoder_layers](
                        output_rhand_kpt_query)  # offset of the kp box center
                outputs_rhand_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.rhand_kpt_index,
                                        device=output.device)).clone()  # select kp position query
                delta_hw_rhand_kp_unsig = self.pose_hand_hw_embed[
                    layer_id - self.num_hand_face_decoder_layers](output_rhand_kpt_query)
                outputs_rhand_kp_unsig[..., :2] += delta_xy_rhand_kpt_unsig[..., :2]
                outputs_rhand_kp_unsig[..., 2:] += delta_hw_rhand_kp_unsig
                new_reference_points_for_rhand_keypoint = outputs_rhand_kp_unsig.sigmoid()
                bs = new_reference_points_for_rhand_box_norm.shape[1]
                new_reference_points_for_rhand_keypoint = \
                    new_reference_points_for_rhand_keypoint.view(
                        -1, self.num_hand_points, bs, 4)

                # face box
                output_face_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 2 * self.num_hand_points + 3)::
                    (self.num_body_points + 2 * self.num_hand_points +
                     self.num_face_points + 4)]
                delta_xy_face_bbox_unsig = self.bbox_face_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 2 * self.num_hand_points + 3)::
                        (self.num_body_points + 2 * self.num_hand_points +
                         self.num_face_points + 4)].clone()
                delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig[..., :2] += delta_xy_face_bbox_unsig[..., :2]
                outputs_face_bbox_unsig[..., 2:] += delta_hw_face_bbox_unsig
                new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid()

                # face kps
                output_face_kpt_query = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.face_kpt_index,
                                    device=output.device))  # select kp content query
                delta_xy_face_kpt_unsig = self.pose_face_embed[
                    layer_id - self.num_hand_face_decoder_layers](
                        output_face_kpt_query)  # offset of the kp box center
                outputs_face_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.face_kpt_index,
                                        device=output.device)).clone()  # select kp position query
                delta_hw_face_kp_unsig = self.pose_face_hw_embed[
                    layer_id - self.num_hand_face_decoder_layers](output_face_kpt_query)
                outputs_face_kp_unsig[..., :2] += delta_xy_face_kpt_unsig[..., :2]
                outputs_face_kp_unsig[..., 2:] += delta_hw_face_kp_unsig
                new_reference_points_for_face_keypoint = outputs_face_kp_unsig.sigmoid()
                bs = new_reference_points_for_face_box_norm.shape[1]
                new_reference_points_for_face_keypoint = \
                    new_reference_points_for_face_keypoint.view(
                        -1, self.num_face_points, bs, 4)

                new_reference_points_norm = torch.cat(
                    (new_reference_points_for_body_box_norm.unsqueeze(1),
                     new_reference_points_for_body_keypoint,
                     new_reference_points_for_lhand_box_norm.unsqueeze(1),
                     new_reference_points_for_lhand_keypoint,
                     new_reference_points_for_rhand_box_norm.unsqueeze(1),
                     new_reference_points_for_rhand_keypoint,
                     new_reference_points_for_face_box_norm.unsqueeze(1),
                     new_reference_points_for_face_keypoint),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box_dn,
                     new_reference_points_norm), dim=0)

            if self.rm_detach and 'dec' in self.rm_detach:
                reference_points = new_reference_points
            else:
                reference_points = new_reference_points.detach()
            ref_points.append(new_reference_points)

        return [[itm_out.transpose(0, 1) for itm_out in intermediate],
                [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]]


def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    if args.modelname == 'aios_smplx_box':
        return Transformer_Box(
            d_model=args.hidden_dim,
            dropout=args.dropout,
            nhead=args.nheads,
            num_queries=args.num_queries,
            dim_feedforward=args.dim_feedforward,
            num_encoder_layers=args.enc_layers,
            num_decoder_layers=args.dec_layers,
            normalize_before=args.pre_norm,
            return_intermediate_dec=True,
            query_dim=args.query_dim,
            activation=args.transformer_activation,
            num_patterns=args.num_patterns,
            modulate_hw_attn=True,
            deformable_encoder=True,
            deformable_decoder=True,
            num_feature_levels=args.num_feature_levels,
            enc_n_points=args.enc_n_points,
            dec_n_points=args.dec_n_points,
            learnable_tgt_init=True,
            random_refpoints_xy=args.random_refpoints_xy,
            two_stage_type=args.two_stage_type,
            two_stage_learn_wh=args.two_stage_learn_wh,
            two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
            dec_layer_number=args.dec_layer_number,
            rm_self_attn_layers=args.rm_self_attn_layers,
            rm_detach=args.rm_detach,
            decoder_sa_type=args.decoder_sa_type,
            module_seq=args.decoder_module_seq,
            embed_init_tgt=args.embed_init_tgt,
            num_body_points=args.num_body_points,
            num_hand_points=args.num_hand_points,
            num_face_points=args.num_face_points,
            num_box_decoder_layers=args.num_box_decoder_layers,
            num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
            num_group=args.num_group)
    elif args.modelname == 'aios_smplx':
        return Transformer(
            d_model=args.hidden_dim,
            dropout=args.dropout,
            nhead=args.nheads,
            num_queries=args.num_queries,
            dim_feedforward=args.dim_feedforward,
            num_encoder_layers=args.enc_layers,
            num_decoder_layers=args.dec_layers,
            normalize_before=args.pre_norm,
            return_intermediate_dec=True,
            query_dim=args.query_dim,
            activation=args.transformer_activation,
            num_patterns=args.num_patterns,
            modulate_hw_attn=True,
            deformable_encoder=True,
            deformable_decoder=True,
            num_feature_levels=args.num_feature_levels,
            enc_n_points=args.enc_n_points,
            dec_n_points=args.dec_n_points,
            learnable_tgt_init=True,
            random_refpoints_xy=args.random_refpoints_xy,
            two_stage_type=args.two_stage_type,
            two_stage_learn_wh=args.two_stage_learn_wh,
            two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
            dec_layer_number=args.dec_layer_number,
            rm_self_attn_layers=args.rm_self_attn_layers,
            rm_detach=args.rm_detach,
            decoder_sa_type=args.decoder_sa_type,
            module_seq=args.decoder_module_seq,
            embed_init_tgt=args.embed_init_tgt,
            num_body_points=args.num_body_points,
            num_hand_points=args.num_hand_points,
            num_face_points=args.num_face_points,
            num_box_decoder_layers=args.num_box_decoder_layers,
            num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
            num_group=args.num_group)
    else:
        raise ValueError('Wrong Transformer type')


class TransformerDecoder_Box(nn.Module):

    def __init__(
            self,
            decoder_layer,
            num_layers,
            norm=None,
            return_intermediate=False,
            d_model=256,
            query_dim=4,
            modulate_hw_attn=False,
            num_feature_levels=1,
            deformable_decoder=False,
            dec_layer_number=None,  # number of queries for each decoder layer
            dec_layer_share=False,
            dec_layer_dropout_prob=None,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_body_points=0,
            num_hand_points=0,
            num_face_points=0,
            num_dn=100,
            num_group=100):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer,
                                      num_layers,
                                      layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate  # True
        assert return_intermediate, 'support return_intermediate only'
        self.query_dim = query_dim  # 4
        assert query_dim in [2, 4], \
            'query_dim should be 2/4 but {}'.format(query_dim)
        self.num_feature_levels = num_feature_levels  # 4

        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
                                  2)  # 4//2 * 256, 256, 256, 2
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        self.num_body_points = 0
        self.num_hand_points = 0
        self.num_face_points = 0

        self.query_scale = None

        # aios kp
        self.bbox_embed = None
        self.class_embed = None
        self.bbox_hand_embed = None
        self.bbox_hand_hw_embed = None
        # smplx face kp
        self.bbox_face_embed = None
        self.bbox_face_hw_embed = None

        self.num_box_decoder_layers = num_box_decoder_layers  # 2
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder
        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None
        self.box_pred_damping = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers
            # assert dec_layer_number[0] ==

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            raise NotImplementedError
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.num_group = num_group
        self.rm_detach = None
        self.num_dn = num_dn

        # (The keypoint embeddings and body/hand/face keypoint index tables of
        # TransformerDecoder are not needed here: this box-only decoder keeps
        # num_body/hand/face_points at 0 and only expands boxes.)
        # self.whole_body_points = \
        #     self.num_body_points + self.num_hand_points * 2 + self.num_face_points
        # self.body_kpt_index_2 = [
        #     x for x in range(self.num_group * (self.whole_body_points + 4))
        #     if (x % (self.whole_body_points + 4) in range(1, self.num_body_points + 1))
        # ]
        # Per-group query layout (whole-body variant with 17 body / 10 hand /
        # 10 face points; this box-only decoder uses 0 keypoints per part):
        # [0-99]: dn bbox;
        # [0, 1]: body box;
        # [1, 18]: body kps;
        # [18, 19]: lhand box
        # [19, 29]: lhand kps
        # [29, 30]: rhand box
        # [30, 40]: rhand kps
        # [40, 41]: face bbox
        # [41, 51]: face kps
        # self.lhand_kpt_index = [
        #     x for x in range(self.num_group * (self.whole_body_points + 4))
        #     if (x % (self.whole_body_points + 4) in range(
        #         self.num_body_points + 2,
        #         self.num_body_points + self.num_hand_points + 2))]
        # self.rhand_kpt_index = [
        #     x for x in range(self.num_group * (self.whole_body_points + 4))
        #     if (x % (self.whole_body_points + 4) in range(
        #         self.num_body_points + self.num_hand_points + 3,
        #         self.num_body_points + self.num_hand_points * 2 + 3))
        # ]
        # self.face_kpt_index = [
        #     x for x in range(self.num_group * (self.whole_body_points + 4))
        #     if (x % (self.whole_body_points + 4) in range(
        #         self.num_body_points + self.num_hand_points * 2 + 4,
        #         self.num_body_points + self.num_hand_points * 2 +
        #         self.num_face_points + 4))
        # ]
        self.lhand_box_embed = nn.Embedding(1, d_model)
        self.rhand_box_embed = nn.Embedding(1, d_model)
        self.face_box_embed = nn.Embedding(1, d_model)
        self.hw_lhand_bbox = nn.Embedding(1, 2)
        self.hw_rhand_bbox = nn.Embedding(1, 2)
        self.hw_face_bbox = nn.Embedding(1, 2)
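
    # NOTE (added, illustrative only): a worked example of the group layout
    # sketched in the comments above, specialized to this box-only decoder
    # where num_body_points/num_hand_points/num_face_points are hard-coded
    # to 0. Each group then occupies stride = num_body_points + 4 = 4
    # consecutive queries ordered [body box, lhand box, rhand box, face box],
    # so for group g (after the effect_num_dn denoising queries) the body box
    # sits at g * 4 + 0, the lhand box at g * 4 + 1, the rhand box at
    # g * 4 + 2, and the face box at g * 4 + 3.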

    def forward(
            self,
            tgt,
            memory,
            tgt_mask: Optional[Tensor] = None,
            tgt_mask2: Optional[Tensor] = None,
            tgt_mask3: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
            # for memory
            level_start_index: Optional[Tensor] = None,  # num_levels
            spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
            valid_ratios: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]
        effect_num_dn = self.num_dn if self.training else 0
        inter_select_number = self.num_group
        for layer_id, layer in enumerate(self.layers):
            if self.deformable_decoder:
                if reference_points.shape[-1] == 4:
                    reference_points_input = reference_points[:, :, None] \
                        * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
                else:
                    assert reference_points.shape[-1] == 2
                    reference_points_input = reference_points[:, :, None] \
                        * valid_ratios[None, :]
                query_sine_embed = gen_sineembed_for_position(
                    reference_points_input[:, :, 0, :]
                )  # convert the position query from a box to a sine/cosine embedding
            else:
                query_sine_embed = gen_sineembed_for_position(
                    reference_points)  # nq, bs, 256*2
                reference_points_input = None
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(
                output) if self.query_scale is not None else 1  # pos_scale is 1 unless a query_scale head is attached
            query_pos = pos_scale * raw_query_pos
            if not self.deformable_decoder:
                query_sine_embed = query_sine_embed[
                    ..., :self.d_model] * self.query_pos_sine_scale(output)
            # modulated HW attentions
            if not self.deformable_decoder and self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
                query_sine_embed[..., self.d_model // 2:] *= (
                    refHW_cond[..., 0] / reference_points[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (
                    refHW_cond[..., 1] / reference_points[..., 3]).unsqueeze(-1)
            dropflag = False
            if self.dec_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.dec_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                output = layer(
                    tgt=output,
                    tgt_query_pos=query_pos,
                    tgt_query_sine_embed=query_sine_embed,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    tgt_reference_points=reference_points_input,
                    memory=memory,  # encoder output, i.e. the content queries from the encoder
                    memory_key_padding_mask=memory_key_padding_mask,
                    memory_level_start_index=level_start_index,
                    memory_spatial_shapes=spatial_shapes,
                    memory_pos=pos,  # position query of the encoder
                    self_attn_mask=tgt_mask,
                    cross_attn_mask=memory_mask)
            intermediate.append(self.norm(output))
            # human (box) update
            if layer_id < self.num_box_decoder_layers:
                # reference_points: [100*(17+20*2+72) 4, 4]
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](
                    output)  # delta_x, delta_y, delta_w, delta_h
                outputs_unsig = delta_unsig + reference_before_sigmoid
                # update the positional query by adding the offset delta_unsig
                new_reference_points = outputs_unsig.sigmoid()
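            # NOTE (added, illustrative only): the update above is the usual
            # DETR-style iterative refinement, new_ref = sigmoid(
            # inverse_sigmoid(ref) + MLP(feat)). A sketch with hypothetical
            # tensors (names not part of the model):
            #
            #   ref = torch.rand(nq, bs, 4)         # anchors in (0, 1), cxcywh
            #   delta = bbox_head(feat)             # (nq, bs, 4), unbounded
            #   new_ref = (inverse_sigmoid(ref) + delta).sigmoid()
            #
            # i.e. offsets are predicted in logit space and squashed back to
            # normalized coordinates, keeping every anchor inside the image.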
            # query expansion: each selected body query spawns hand/face queries
            if layer_id == self.num_box_decoder_layers - 1:
                dn_output = output[:effect_num_dn]  # [100, -, 256]
                dn_new_reference_points = new_reference_points[:effect_num_dn]  # [100, -, 4]
                class_unselected = self.class_embed[layer_id](
                    output)[effect_num_dn:]  # [900, -, 2]
                topk_proposals = torch.topk(class_unselected.max(-1)[0],
                                            inter_select_number, dim=0)[1]  # 100
                # keep the top inter_select_number position queries
                new_reference_points_for_body_box = torch.gather(
                    new_reference_points[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # selected position query
                # and the matching output features
                new_output_for_body_box = torch.gather(
                    output[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(
                        1, 1, self.d_model))  # selected content query
                bs = new_output_for_body_box.shape[1]
                # for lhand bbox
                new_output_for_lhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.lhand_box_embed.weight[None, :, None, :]
                delta_lhand_box_xy = self.bbox_hand_embed[-1](
                    new_output_for_lhand_box)[..., :2]
                lhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                 delta_lhand_box_xy).sigmoid()  # [100, 14, -, 2]
                num_queries, _, bs, _ = lhand_bbox_xy.shape
                lhand_bbox_wh_weight = self.hw_lhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                lhand_bbox_wh = lhand_bbox_wh_weight \
                    * new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_lhand_bbox = torch.cat(
                    (lhand_bbox_xy, lhand_bbox_wh), dim=-1)
                # for rhand bbox
                new_output_for_rhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.rhand_box_embed.weight[None, :, None, :]
                delta_rhand_box_xy = self.bbox_hand_embed[-1](
                    new_output_for_rhand_box)[..., :2]
                rhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                 delta_rhand_box_xy).sigmoid()  # [100, 14, -, 2]
                num_queries, _, bs, _ = rhand_bbox_xy.shape
                rhand_bbox_wh_weight = self.hw_rhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                rhand_bbox_wh = rhand_bbox_wh_weight \
                    * new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_rhand_bbox = torch.cat(
                    (rhand_bbox_xy, rhand_bbox_wh), dim=-1)
                # for face bbox
                new_output_for_face_box = new_output_for_body_box[:, None, :, :] \
                    + self.face_box_embed.weight[None, :, None, :]
                delta_face_box_xy = self.bbox_face_embed[-1](
                    new_output_for_face_box)[..., :2]
                face_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                delta_face_box_xy).sigmoid()  # [100, 14, -, 2]
                num_queries, _, bs, _ = face_bbox_xy.shape
                face_bbox_wh_weight = self.hw_face_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                face_bbox_wh = face_bbox_wh_weight \
                    * new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_face_box = torch.cat(
                    (face_bbox_xy, face_bbox_wh), dim=-1)
                output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     new_output_for_lhand_box, new_output_for_rhand_box,
                     new_output_for_face_box),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box.unsqueeze(1),
                     new_reference_points_for_lhand_bbox,
                     new_reference_points_for_rhand_bbox,
                     new_reference_points_for_face_box),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_new_reference_points, new_reference_points), dim=0)
                output = torch.cat((dn_output, output), dim=0)
                tgt_mask = tgt_mask2
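            # NOTE (added, illustrative only): after this expansion the
            # flattened query sequence interleaves the four part slots per
            # group, which is why the branch below reads body/lhand/rhand/face
            # queries back out with strided slices such as
            # output[effect_num_dn:][0::(self.num_body_points + 4)].
            # With the defaults here (num_body_points == 0, stride 4):
            #
            #   body -> 0::4, lhand -> 1::4, rhand -> 2::4, face -> 3::4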
            # human-to-part updates: body, hand and face boxes are each
            # refined from their own slot of the expanded query sequence
            if layer_id >= self.num_box_decoder_layers:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                reference_before_sigmoid_body_bbox_dn = \
                    reference_before_sigmoid[:effect_num_dn]
                reference_before_sigmoid_bbox_body_norm = \
                    reference_before_sigmoid[effect_num_dn:][0::(self.num_body_points + 4)]
                output_bbox_body_dn = output[:effect_num_dn]
                output_bbox_body_norm = output[effect_num_dn:][
                    0::(self.num_body_points + 4)]
                delta_unsig_bbox_body_dn = self.bbox_embed[layer_id](
                    output_bbox_body_dn)
                delta_unsig_bbox_body_norm = self.bbox_embed[layer_id](
                    output_bbox_body_norm)
                outputs_unsig_body_bbox_dn = \
                    delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn
                outputs_unsig_body_bbox_norm = \
                    delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm
                new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid()
                new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid()
                # lhand box
                output_lhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)]
                delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig = reference_before_sigmoid[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)].clone()
                delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig[..., :2] += delta_xy_lhand_bbox_unsig[..., :2]
                outputs_lhand_bbox_unsig[..., 2:] += delta_hw_lhand_bbox_unsig
                new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid()
                # rhand box
                output_rhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)]
                delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig = reference_before_sigmoid[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)].clone()
                delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig[..., :2] += delta_xy_rhand_bbox_unsig[..., :2]
                outputs_rhand_bbox_unsig[..., 2:] += delta_hw_rhand_bbox_unsig
                new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid()
                # face box
                output_face_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)]
                delta_xy_face_bbox_unsig = self.bbox_face_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig = reference_before_sigmoid[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)].clone()
                delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig[..., :2] += delta_xy_face_bbox_unsig[..., :2]
                outputs_face_bbox_unsig[..., 2:] += delta_hw_face_bbox_unsig
                new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid()
                new_reference_points_norm = torch.cat(
                    (new_reference_points_for_body_box_norm.unsqueeze(1),
                     new_reference_points_for_lhand_box_norm.unsqueeze(1),
                     new_reference_points_for_rhand_box_norm.unsqueeze(1),
                     new_reference_points_for_face_box_norm.unsqueeze(1)),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box_dn,
                     new_reference_points_norm),
                    dim=0)
            if self.rm_detach and 'dec' in self.rm_detach:
                reference_points = new_reference_points
            else:
                reference_points = new_reference_points.detach()
            ref_points.append(new_reference_points)
        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
        ]
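

# Illustrative helper (an addition for clarity, not used by the model): a
# hypothetical function making the strided slicing in TransformerDecoder_Box
# explicit. It assumes the per-group layout used above, i.e.
# [body box, <body kps>, lhand box, rhand box, face box] with the denoising
# queries prepended, which matches this box-only decoder's defaults.
def _box_group_slot_indices(num_group, num_body_points=0, num_dn=0):
    """Return the flat query index of each part slot for every group."""
    stride = num_body_points + 4  # one slot per part box plus the body keypoints

    def slot(offset):
        return [num_dn + g * stride + offset for g in range(num_group)]

    return {
        'body': slot(0),
        'lhand': slot(num_body_points + 1),
        'rhand': slot(num_body_points + 2),
        'face': slot(num_body_points + 3),
    }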


class Transformer_Box(nn.Module):

    def __init__(
            self,
            d_model=256,
            nhead=8,
            num_queries=300,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.0,
            activation='relu',
            normalize_before=False,
            return_intermediate_dec=False,
            query_dim=4,
            num_patterns=0,
            modulate_hw_attn=False,
            # for deformable encoder
            deformable_encoder=False,
            deformable_decoder=False,
            num_feature_levels=1,
            enc_n_points=4,
            dec_n_points=4,
            # init query
            learnable_tgt_init=False,
            random_refpoints_xy=False,
            # two stage
            two_stage_type='no',
            two_stage_learn_wh=False,
            two_stage_keep_all_tokens=False,
            # evo of #anchors
            dec_layer_number=None,
            rm_self_attn_layers=None,
            # for detach
            rm_detach=None,
            decoder_sa_type='sa',
            module_seq=['sa', 'ca', 'ffn'],
            # for pose
            embed_init_tgt=False,
            num_body_points=0,
            num_hand_points=0,
            num_face_points=0,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_group=100):
        super().__init__()
        # pdb.set_trace()
        self.num_feature_levels = num_feature_levels  # 4
        self.num_encoder_layers = num_encoder_layers  # 6
        self.num_decoder_layers = num_decoder_layers  # 6
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens  # False
        self.num_queries = num_queries  # 900
        self.random_refpoints_xy = random_refpoints_xy  # False
        assert query_dim == 4
        if num_feature_levels > 1:
            assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'
        self.decoder_sa_type = decoder_sa_type  # sa
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
        # choose encoder layer type
        if deformable_encoder:
            encoder_layer = DeformableTransformerEncoderLayer(
                d_model, dim_feedforward, dropout, activation,
                num_feature_levels, nhead, enc_n_points)
        else:
            raise NotImplementedError
            encoder_layer = TransformerEncoderLayer(d_model, nhead,
                                                    dim_feedforward, dropout,
                                                    activation,
                                                    normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            encoder_norm,
            d_model=d_model,
            num_queries=num_queries,
            deformable_encoder=deformable_encoder,
            two_stage_type=two_stage_type)
        # choose decoder layer type
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(
                d_model, dim_feedforward, dropout, activation,
                num_feature_levels, nhead, dec_n_points,
                decoder_sa_type=decoder_sa_type,
                module_seq=module_seq)
        else:
            raise NotImplementedError
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, dropout, activation,
                normalize_before, num_feature_levels=num_feature_levels)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder_Box(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            modulate_hw_attn=modulate_hw_attn,
            num_feature_levels=num_feature_levels,
            deformable_decoder=deformable_decoder,
            dec_layer_number=dec_layer_number,
            num_body_points=num_body_points,
            num_hand_points=num_hand_points,
            num_face_points=num_face_points,
            num_box_decoder_layers=num_box_decoder_layers,
            num_hand_face_decoder_layers=num_hand_face_decoder_layers,
            num_group=num_group,
            num_dn=num_group,
        )
        self.d_model = d_model
        self.nhead = nhead  # 8
        self.dec_layers = num_decoder_layers  # 6
        self.num_queries = num_queries  # useful for single-stage models only
        self.num_patterns = num_patterns  # 0
        if not isinstance(num_patterns, int):
            warnings.warn('num_patterns should be int but {}'.format(
                type(num_patterns)))
            self.num_patterns = 0
        if self.num_patterns > 0:
            assert two_stage_type == 'no'
            self.patterns = nn.Embedding(self.num_patterns, d_model)
        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                self.level_embed = nn.Parameter(
                    torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None
        self.learnable_tgt_init = learnable_tgt_init  # true
        assert learnable_tgt_init, 'why not learnable_tgt_init'
        self.embed_init_tgt = embed_init_tgt  # false
        if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None
        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in [
            'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            # anchor selection at the output of the encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                # import pdb; pdb.set_trace()
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        if two_stage_type in ['early', 'combine']:
            # anchor selection at the output of the backbone
            self.enc_output_backbone = nn.Linear(d_model, d_model)
            self.enc_output_norm_backbone = nn.LayerNorm(d_model)
        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # init self.refpoint_embed
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_pose_embed = None
        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[0] == num_queries, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
            else:
                assert dec_layer_number[0] == num_queries * num_patterns, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'
        self._reset_parameters()
        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            # assert len(rm_self_attn_layers) == num_decoder_layers
            print('Removing the self-attn in {} decoder layers'.format(
                rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()
        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            nn.init.constant_(self.two_stage_wh_embedding.weight,
                              math.log(0.05 / (1 - 0.05)))

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
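
    # NOTE (added, illustrative only): `mask` above is the padding mask from
    # the backbone, where True marks padded pixels. For a hypothetical
    # (1, 4, 4) mask whose right half is padding, valid_W == 2 and
    # get_valid_ratio returns [[0.5, 1.0]] (w-ratio, h-ratio), the fraction of
    # each axis holding real image content; the deformable attention uses it
    # to rescale reference points per feature level.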

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)
        if self.random_refpoints_xy:
            # import pdb; pdb.set_trace()
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2])
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    # srcs: multi-scale backbone features; refpoint_embed/tgt: optional
    # denoising queries prepended to the matching queries
    def forward(self,
                srcs,
                masks,
                refpoint_embed,
                pos_embeds,
                tgt,
                attn_mask=None,
                attn_mask2=None,
                attn_mask3=None):
        # pdb.set_trace()
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):  # for each feature level
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
                    1, 1, -1)  # level_embed[lvl]: [256]
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
                                          1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes,
                                         dtype=torch.long,
                                         device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        # two stage
        if self.two_stage_type in ['early', 'combine']:
            output_memory, output_proposals = gen_encoder_output_proposals(
                src_flatten, mask_flatten, spatial_shapes)
            output_memory = self.enc_output_norm_backbone(
                self.enc_output_backbone(output_memory))
            # gather boxes
            topk = self.num_queries
            enc_outputs_class = self.encoder.class_embed[0](output_memory)
            enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0],
                                            topk, dim=1)[1]  # bs, nq
            enc_refpoint_embed = torch.gather(
                output_proposals, 1,
                enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            src_flatten = output_memory
        else:
            enc_topk_proposals = enc_refpoint_embed = None
        #########################################################
        # Begin Encoder
        #########################################################
        memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        #########################################################
        if self.two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            if self.two_stage_learn_wh:
                # import pdb; pdb.set_trace()
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, input_hw)
            output_memory = self.enc_output_norm(self.enc_output(output_memory))
            enc_outputs_class_unselected = self.enc_out_class_embed(
                output_memory)  # [11531, 2] for swin
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries
            topk_proposals = torch.topk(
                enc_outputs_class_unselected.max(-1)[0], topk,
                dim=1)[1]  # bs, nq; coarse human query selection
            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()  # sigmoid
            # gather tgt
            tgt_undetach = torch.gather(
                output_memory, 1,
                topk_proposals.unsqueeze(-1).repeat(
                    1, 1, self.d_model))  # selected content query
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            else:
                tgt_ = tgt_undetach.detach()
            if refpoint_embed is not None:
                # import pdb; pdb.set_trace()
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)  # [1000, 4]
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
        elif self.two_stage_type == 'early':
            refpoint_embed_undetach = self.enc_out_bbox_embed(
                enc_intermediate_output[-1]
            ) + enc_refpoint_embed  # unsigmoid, (bs, nq, 4)
            refpoint_embed = refpoint_embed_undetach.detach()
            tgt_undetach = enc_intermediate_output[-1]  # bs, nq, d_model
            tgt = tgt_undetach.detach()
        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, 4
            if refpoint_embed is not None:
                # import pdb; pdb.set_trace()
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
            # pattern embed
            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat
            init_box_proposal = refpoint_embed_.sigmoid()
        else:
            raise NotImplementedError('unknown two_stage_type {}'.format(
                self.two_stage_type))
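        # NOTE (added, illustrative only): query-initialization recap for the
        # 'standard' two-stage path, with hypothetical names. The top
        # self.num_queries encoder tokens, ranked by max class logit, supply
        # both the anchors and (unless embed_init_tgt is set) the content
        # queries:
        #
        #   scores = enc_outputs_class_unselected.max(-1)[0]   # (bs, sum_hw)
        #   topk_idx = scores.topk(self.num_queries, dim=1)[1]
        #   anchors = enc_outputs_coord_unselected.gather(
        #       1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))
        #
        # Denoising queries passed in via `refpoint_embed`/`tgt` are simply
        # concatenated in front of these matching queries.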
        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            tgt_mask3=attn_mask3)
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
                # import pdb; pdb.set_trace()
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        elif self.two_stage_type in ['combine', 'early']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc+1, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers + 1
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, 4
        elif self.two_stage_type in ['enceachlayer', 'enclayer1']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc, bs, nq, 4
        else:
            hs_enc = ref_enc = None
        return hs, references, hs_enc, ref_enc, init_box_proposal
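

# Usage sketch (added, illustrative only; assumes a prediction head defined
# elsewhere in the repo consumes these outputs, as in other DETR variants):
#
#   hs, references, hs_enc, ref_enc, init_box_proposal = transformer(
#       srcs, masks, refpoint_embed, pos_embeds, tgt,
#       attn_mask=attn_mask, attn_mask2=attn_mask2, attn_mask3=attn_mask3)
#   # hs[-1]: content queries from the last decoder layer (bs, nq, d_model)
#   # references[-1]: the matching refined anchors in sigmoid space (bs, nq, 4)
#   # hs_enc / ref_enc: encoder-side selections for the two-stage aux losses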