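"""Deformable-DETR-style Transformer for whole-body pose estimation.

Editorial summary of the code below: the encoder follows the two-stage
Deformable-DETR/DINO recipe, and the decoder progressively expands each
selected human query into per-part queries (body box, body keypoints,
left/right-hand and face boxes, then hand/face keypoints).
"""
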
import copy
import math
import os
import random
import warnings
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from util.misc import inverse_sigmoid
from .transformer_deformable import DeformableTransformerEncoderLayer, DeformableTransformerDecoderLayer
from .utils import gen_encoder_output_proposals, sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
from .ops.modules.ms_deform_attn import MSDeformAttn


class Transformer(nn.Module):

    def __init__(
            self,
            d_model=256,
            nhead=8,
            num_queries=300,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.0,
            activation='relu',
            normalize_before=False,
            return_intermediate_dec=False,
            query_dim=4,
            num_patterns=0,
            modulate_hw_attn=False,
            # for deformable encoder
            deformable_encoder=False,
            deformable_decoder=False,
            num_feature_levels=1,
            enc_n_points=4,
            dec_n_points=4,
            # init query
            learnable_tgt_init=False,
            random_refpoints_xy=False,
            # two stage
            two_stage_type='no',
            two_stage_learn_wh=False,
            two_stage_keep_all_tokens=False,
            # evolution of #anchors
            dec_layer_number=None,
            rm_self_attn_layers=None,
            # for detach
            rm_detach=None,
            decoder_sa_type='sa',
            module_seq=['sa', 'ca', 'ffn'],
            # for pose
            embed_init_tgt=False,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_group=100):
        super().__init__()
        self.num_feature_levels = num_feature_levels  # 4
        self.num_encoder_layers = num_encoder_layers  # 6
        self.num_decoder_layers = num_decoder_layers  # 6
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens  # False
        self.num_queries = num_queries  # 900
        self.random_refpoints_xy = random_refpoints_xy  # False
        assert query_dim == 4
        if num_feature_levels > 1:
            assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'
        self.decoder_sa_type = decoder_sa_type  # 'sa'
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
        # choose encoder layer type; only the deformable encoder is implemented
        if deformable_encoder:
            encoder_layer = DeformableTransformerEncoderLayer(
                d_model, dim_feedforward, dropout, activation,
                num_feature_levels, nhead, enc_n_points)
        else:
            raise NotImplementedError('only deformable_encoder is supported')
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            encoder_norm,
            d_model=d_model,
            num_queries=num_queries,
            deformable_encoder=deformable_encoder,
            two_stage_type=two_stage_type)
        # choose decoder layer type; only the deformable decoder is implemented
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(
                d_model,
                dim_feedforward,
                dropout,
                activation,
                num_feature_levels,
                nhead,
                dec_n_points,
                decoder_sa_type=decoder_sa_type,
                module_seq=module_seq)
        else:
            raise NotImplementedError('only deformable_decoder is supported')
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            modulate_hw_attn=modulate_hw_attn,
            num_feature_levels=num_feature_levels,
            deformable_decoder=deformable_decoder,
            dec_layer_number=dec_layer_number,
            num_body_points=num_body_points,
            num_hand_points=num_hand_points,
            num_face_points=num_face_points,
            num_box_decoder_layers=num_box_decoder_layers,
            num_hand_face_decoder_layers=num_hand_face_decoder_layers,
            num_group=num_group,
            num_dn=num_group,
        )
        self.d_model = d_model
        self.nhead = nhead  # 8
        self.dec_layers = num_decoder_layers  # 6
        self.num_queries = num_queries  # useful for single-stage models only
        self.num_patterns = num_patterns  # 0
        if not isinstance(num_patterns, int):
            warnings.warn('num_patterns should be int but got {}'.format(
                type(num_patterns)))
            self.num_patterns = 0
        if self.num_patterns > 0:
            assert two_stage_type == 'no'
            self.patterns = nn.Embedding(self.num_patterns, d_model)
        if num_feature_levels > 1:
            if self.num_encoder_layers > 0:
                self.level_embed = nn.Parameter(
                    torch.Tensor(num_feature_levels, d_model))
            else:
                self.level_embed = None
        self.learnable_tgt_init = learnable_tgt_init  # True
        assert learnable_tgt_init, 'learnable_tgt_init must be True'
        self.embed_init_tgt = embed_init_tgt  # False
        if (two_stage_type != 'no' and embed_init_tgt) or two_stage_type == 'no':
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None
        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in [
            'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            # anchor selection at the output of the encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        if two_stage_type in ['early', 'combine']:
            # anchor selection at the output of the backbone
            self.enc_output_backbone = nn.Linear(d_model, d_model)
            self.enc_output_norm_backbone = nn.LayerNorm(d_model)
        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # init self.refpoint_embed
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_pose_embed = None
        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[0] == num_queries, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
            else:
                assert dec_layer_number[0] == num_queries * num_patterns, \
                    f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'
        self._reset_parameters()
        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            # assert len(rm_self_attn_layers) == num_decoder_layers
            print('Removing the self-attn in {} decoder layers'.format(
                rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()
        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            # start the learned wh at inverse_sigmoid(0.05), i.e. proposals
            # are initialized to 5% of the image size
            nn.init.constant_(self.two_stage_wh_embedding.weight,
                              math.log(0.05 / (1 - 0.05)))

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
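
    # Illustrative example (editorial): for a (bs, 32, 32) padding mask where
    # True marks padded pixels and only the top-left 16x24 region of a sample
    # is valid, that sample's valid_ratio is (24 / 32, 16 / 32) = (0.75, 0.5),
    # ordered as (w_ratio, h_ratio).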

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)
        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2])
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def forward(self,
                srcs,
                masks,
                refpoint_embed,
                pos_embeds,
                tgt,
                attn_mask=None,
                attn_mask2=None,
                attn_mask3=None):
        """
        Input:
            - srcs: list of multi-level features, each [bs, c, hi, wi]
            - masks: list of padding masks, each [bs, hi, wi]
            - refpoint_embed: denoising reference points (or None)
            - pos_embeds: list of positional embeddings, each [bs, c, hi, wi]
            - tgt: denoising content queries (or None)
            - attn_mask / attn_mask2 / attn_mask3: decoder self-attn masks for
              the three successive query layouts
        """
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):  # for each feature level
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
                    1, 1, -1)  # level_embed[lvl]: [256]
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
                                          1)  # bs, \sum{hw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes,
                                         dtype=torch.long,
                                         device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        # two stage
        if self.two_stage_type in ['early', 'combine']:
            output_memory, output_proposals = gen_encoder_output_proposals(
                src_flatten, mask_flatten, spatial_shapes)
            output_memory = self.enc_output_norm_backbone(
                self.enc_output_backbone(output_memory))
            # gather boxes
            topk = self.num_queries
            enc_outputs_class = self.encoder.class_embed[0](output_memory)
            enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0],
                                            topk,
                                            dim=1)[1]  # bs, nq
            enc_refpoint_embed = torch.gather(
                output_proposals, 1,
                enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            src_flatten = output_memory
        else:
            enc_topk_proposals = enc_refpoint_embed = None
        #########################################################
        # Begin Encoder
        #########################################################
        memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        #########################################################
        if self.two_stage_type in ['standard', 'combine', 'enceachlayer', 'enclayer1']:
            if self.two_stage_learn_wh:
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, input_hw)
            output_memory = self.enc_output_norm(self.enc_output(output_memory))
            enc_outputs_class_unselected = self.enc_out_class_embed(
                output_memory)  # [11531, 2] for swin
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries
            topk_proposals = torch.topk(
                enc_outputs_class_unselected.max(-1)[0], topk,
                dim=1)[1]  # bs, nq; coarse human query selection
            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()
            # gather tgt
            tgt_undetach = torch.gather(
                output_memory, 1,
                topk_proposals.unsqueeze(-1).repeat(
                    1, 1, self.d_model))  # selected content queries
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            else:
                tgt_ = tgt_undetach.detach()
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)  # [1000, 4]
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
        elif self.two_stage_type == 'early':
            refpoint_embed_undetach = self.enc_out_bbox_embed(
                enc_intermediate_output[-1]
            ) + enc_refpoint_embed  # unsigmoid, (bs, nq, 4)
            refpoint_embed = refpoint_embed_undetach.detach()
            tgt_undetach = enc_intermediate_output[-1]  # bs, nq, d_model
            tgt = tgt_undetach.detach()
        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # nq, bs, 4
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
            # pattern embed
            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat
            init_box_proposal = refpoint_embed_.sigmoid()
        else:
            raise NotImplementedError('unknown two_stage_type {}'.format(
                self.two_stage_type))
        #########################################################
        # Begin Decoder
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            tgt_mask3=attn_mask3)
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################
        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        elif self.two_stage_type in ['combine', 'early']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc+1, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers + 1
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, 4
        elif self.two_stage_type in ['enceachlayer', 'enclayer1']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc, bs, nq, 4
        else:
            hs_enc = ref_enc = None
        return hs, references, hs_enc, ref_enc, init_box_proposal
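
    # Minimal usage sketch (editorial; hypothetical shapes). The full model
    # attaches prediction heads (e.g. enc_out_class_embed, enc_out_bbox_embed
    # and the decoder's bbox/pose embeds) externally before calling forward,
    # so this is not runnable standalone:
    #   transformer = Transformer(d_model=256, nhead=8, num_queries=900,
    #                             deformable_encoder=True, deformable_decoder=True,
    #                             num_feature_levels=4, two_stage_type='standard',
    #                             learnable_tgt_init=True, embed_init_tgt=True)
    #   # srcs / masks / pos_embeds: per-level lists of (bs, 256, hi, wi)
    #   # features, (bs, hi, wi) padding masks and positional encodings
    #   hs, refs, hs_enc, ref_enc, init_proposal = transformer(
    #       srcs, masks, None, pos_embeds, None)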


class TransformerEncoder(nn.Module):

    def __init__(
        self,
        encoder_layer,
        num_layers,
        norm=None,
        d_model=256,
        num_queries=300,
        deformable_encoder=False,
        enc_layer_share=False,
        enc_layer_dropout_prob=None,
        two_stage_type='no',
    ):
        super().__init__()
        # prepare layers
        if num_layers > 0:  # 6
            self.layers = _get_clones(
                encoder_layer, num_layers,
                layer_share=enc_layer_share)  # enc_layer_share is False
        else:
            self.layers = []
            del encoder_layer
        self.query_scale = None
        self.num_queries = num_queries  # 900
        self.deformable_encoder = deformable_encoder
        self.num_layers = num_layers  # 6
        self.norm = norm
        self.d_model = d_model
        self.enc_layer_dropout_prob = enc_layer_dropout_prob
        if enc_layer_dropout_prob is not None:
            assert isinstance(enc_layer_dropout_prob, list)
            assert len(enc_layer_dropout_prob) == num_layers
            for i in enc_layer_dropout_prob:
                assert 0.0 <= i <= 1.0
        self.two_stage_type = two_stage_type
        if two_stage_type in ['enceachlayer', 'enclayer1']:
            _proj_layer = nn.Linear(d_model, d_model)
            _norm_layer = nn.LayerNorm(d_model)
            if two_stage_type == 'enclayer1':
                self.enc_norm = nn.ModuleList([_norm_layer])
                self.enc_proj = nn.ModuleList([_proj_layer])
            else:
                self.enc_norm = nn.ModuleList([
                    copy.deepcopy(_norm_layer) for i in range(num_layers - 1)
                ])
                self.enc_proj = nn.ModuleList([
                    copy.deepcopy(_proj_layer) for i in range(num_layers - 1)
                ])

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5,
                               H_ - 0.5,
                               H_,
                               dtype=torch.float32,
                               device=device),
                torch.linspace(0.5,
                               W_ - 0.5,
                               W_,
                               dtype=torch.float32,
                               device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
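
    # Illustrative example (editorial): for spatial_shapes = [(2, 2)] and
    # fully valid masks (valid_ratios == 1), the per-level references are the
    # normalized pixel centers (0.25, 0.25), (0.75, 0.25), (0.25, 0.75),
    # (0.75, 0.75), broadcast to every batch element.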

    def forward(self,
                src: Tensor,
                pos: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                key_padding_mask: Tensor,
                ref_token_index: Optional[Tensor] = None,
                ref_token_coord: Optional[Tensor] = None):
        """
        Input:
            - src: [bs, sum(hi*wi), 256]
            - pos: pos embed for src. [bs, sum(hi*wi), 256]
            - spatial_shapes: h, w of each level [num_level, 2]
            - level_start_index: [num_level] start index of each level in sum(hi*wi)
            - valid_ratios: [bs, num_level, 2]
            - key_padding_mask: [bs, sum(hi*wi)]
            - ref_token_index: bs, nq
            - ref_token_coord: bs, nq, 4
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Output:
            - output: [bs, sum(hi*wi), 256]
        """
        if self.two_stage_type in ['no', 'standard', 'enceachlayer', 'enclayer1']:
            assert ref_token_index is None
        output = src
        # preparation and reshape
        if self.num_layers > 0:
            if self.deformable_encoder:
                reference_points = self.get_reference_points(
                    spatial_shapes, valid_ratios, device=src.device)
        intermediate_output = []
        intermediate_ref = []
        if ref_token_index is not None:
            out_i = torch.gather(
                output, 1,
                ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
            intermediate_output.append(out_i)
            intermediate_ref.append(ref_token_coord)
        # main process
        for layer_id, layer in enumerate(self.layers):
            dropflag = False
            if self.enc_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.enc_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                if self.deformable_encoder:
                    output = layer(src=output,
                                   pos=pos,
                                   reference_points=reference_points,
                                   spatial_shapes=spatial_shapes,
                                   level_start_index=level_start_index,
                                   key_padding_mask=key_padding_mask)
                else:
                    output = layer(
                        src=output.transpose(0, 1),
                        pos=pos.transpose(0, 1),
                        key_padding_mask=key_padding_mask).transpose(0, 1)
            if ((layer_id == 0 and self.two_stage_type in ['enceachlayer', 'enclayer1'])
                    or self.two_stage_type == 'enceachlayer') \
                    and (layer_id != self.num_layers - 1):
                output_memory, output_proposals = gen_encoder_output_proposals(
                    output, key_padding_mask, spatial_shapes)
                output_memory = self.enc_norm[layer_id](
                    self.enc_proj[layer_id](output_memory))
                # gather boxes
                topk = self.num_queries
                enc_outputs_class = self.class_embed[layer_id](output_memory)
                ref_token_index = torch.topk(enc_outputs_class.max(-1)[0],
                                             topk,
                                             dim=1)[1]  # bs, nq
                ref_token_coord = torch.gather(
                    output_proposals, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
                output = output_memory
            # aux loss
            if (layer_id != self.num_layers - 1) and ref_token_index is not None:
                out_i = torch.gather(
                    output, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
                intermediate_output.append(out_i)
                intermediate_ref.append(ref_token_coord)
        if self.norm is not None:
            output = self.norm(output)
        if ref_token_index is not None:
            intermediate_output = torch.stack(
                intermediate_output)  # n_enc or n_enc-1, bs, nq, d_model
            intermediate_ref = torch.stack(intermediate_ref)
        else:
            intermediate_output = intermediate_ref = None
        return output, intermediate_output, intermediate_ref


class TransformerDecoder(nn.Module):

    def __init__(
            self,
            decoder_layer,
            num_layers,
            norm=None,
            return_intermediate=False,
            d_model=256,
            query_dim=4,
            modulate_hw_attn=False,
            num_feature_levels=1,
            deformable_decoder=False,
            dec_layer_number=None,  # number of queries for each decoder layer
            dec_layer_share=False,
            dec_layer_dropout_prob=None,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_dn=100,
            num_group=100):
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer,
                                      num_layers,
                                      layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate  # True
        assert return_intermediate, 'support return_intermediate only'
        self.query_dim = query_dim  # 4
        assert query_dim in [2, 4], \
            'query_dim should be 2 or 4 but got {}'.format(query_dim)
        self.num_feature_levels = num_feature_levels  # 4
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
                                  2)  # 4//2 * 256, 256, 256, 2
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None
        self.num_body_points = num_body_points
        self.num_hand_points = num_hand_points
        self.num_face_points = num_face_points
        self.query_scale = None
        # aios kp
        self.bbox_embed = None
        self.class_embed = None
        self.pose_embed = None
        self.pose_hw_embed = None
        # smpl
        # self.smpl_pose_embed = None
        # self.smpl_beta_embed = None
        # self.smpl_cam_embed = None
        # smplx hand kp
        self.bbox_hand_embed = None
        self.bbox_hand_hw_embed = None
        self.pose_hand_embed = None
        self.pose_hand_hw_embed = None
        # smplx face kp
        self.bbox_face_embed = None
        self.bbox_face_hw_embed = None
        self.pose_face_embed = None
        self.pose_face_hw_embed = None
        # self.smplx_lhand_pose_embed = None
        # self.smplx_rhand_pose_embed = None
        # self.smplx_expression_embed = None
        # self.smplx_jaw_embed = None
        self.num_box_decoder_layers = num_box_decoder_layers  # 2
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder
        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None
        self.box_pred_damping = None
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers
            # assert dec_layer_number[0] ==
        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            # per-layer decoder dropout is not supported
            raise NotImplementedError
        self.num_group = num_group
        self.rm_detach = None
        self.num_dn = num_dn
        # self.hw_body_kps = nn.Embedding(self.num_body_points, 2)
        self.hw = nn.Embedding(self.num_body_points, 2)
        self.keypoint_embed = nn.Embedding(self.num_body_points, d_model)
        self.body_kpt_index_1 = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) not in
            [0, 1 + self.num_body_points, 2 + self.num_body_points,
             3 + self.num_body_points]
        ]
        self.whole_body_points = \
            self.num_body_points + self.num_hand_points * 2 + self.num_face_points
        self.body_kpt_index_2 = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(1, self.num_body_points + 1)
        ]
        # Query layout after full expansion. The first num_dn queries are
        # denoising (dn) boxes; after them, each human group spans
        # whole_body_points + 4 slots (half-open ranges, defaults 17/10/10):
        #   [0, 1): body box
        #   [1, 18): body kps
        #   [18, 19): lhand box
        #   [19, 29): lhand kps
        #   [29, 30): rhand box
        #   [30, 40): rhand kps
        #   [40, 41): face box
        #   [41, 51): face kps
        self.lhand_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + 2,
                self.num_body_points + self.num_hand_points + 2)
        ]
        self.rhand_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points + 3,
                self.num_body_points + self.num_hand_points * 2 + 3)
        ]
        self.face_kpt_index = [
            x for x in range(self.num_group * (self.whole_body_points + 4))
            if x % (self.whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points * 2 + 4,
                self.num_body_points + self.num_hand_points * 2
                + self.num_face_points + 4)
        ]
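        # Concrete example (editorial, with the defaults 17/10/10): each group
        # spans 17 + 2 * 10 + 10 + 4 = 51 slots, so lhand_kpt_index picks
        # slots 19-28 of every group, i.e. indices 19..28, 70..79, 121..130,
        # and so on across num_group groups.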
        self.lhand_box_embed = nn.Embedding(1, d_model)
        self.rhand_box_embed = nn.Embedding(1, d_model)
        self.face_box_embed = nn.Embedding(1, d_model)
        self.hw_lhand_bbox = nn.Embedding(1, 2)
        self.hw_rhand_bbox = nn.Embedding(1, 2)
        self.hw_face_bbox = nn.Embedding(1, 2)
        self.hw_lhand_kps = nn.Embedding(self.num_hand_points, 2)
        self.hw_rhand_kps = nn.Embedding(self.num_hand_points, 2)
        self.hw_face_kps = nn.Embedding(self.num_face_points, 2)
        self.lhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
        self.rhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
        self.face_keypoint_embed = nn.Embedding(self.num_face_points, d_model)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        tgt_mask2: Optional[Tensor] = None,
        tgt_mask3: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]
        effect_num_dn = self.num_dn if self.training else 0
        inter_select_number = self.num_group
        for layer_id, layer in enumerate(self.layers):
            if self.deformable_decoder:
                if reference_points.shape[-1] == 4:
                    reference_points_input = reference_points[:, :, None] \
                        * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
                else:
                    assert reference_points.shape[-1] == 2
                    reference_points_input = reference_points[:, :, None] \
                        * valid_ratios[None, :]
                query_sine_embed = gen_sineembed_for_position(
                    reference_points_input[:, :, 0, :]
                )  # convert the positional query from a box to a sine/cosine embedding
            else:
                query_sine_embed = gen_sineembed_for_position(
                    reference_points)  # nq, bs, 256*2
                reference_points_input = None
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(
                output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos
            if not self.deformable_decoder:
                query_sine_embed = query_sine_embed[
                    ..., :self.d_model] * self.query_pos_sine_scale(output)
            # modulated HW attention
            if not self.deformable_decoder and self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
                query_sine_embed[..., self.d_model // 2:] *= (
                    refHW_cond[..., 0] /
                    reference_points[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (
                    refHW_cond[..., 1] /
                    reference_points[..., 3]).unsqueeze(-1)
            dropflag = False
            if self.dec_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.dec_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                output = layer(
                    tgt=output,
                    tgt_query_pos=query_pos,
                    tgt_query_sine_embed=query_sine_embed,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    tgt_reference_points=reference_points_input,
                    memory=memory,  # encoder output, i.e. the encoder content queries
                    memory_key_padding_mask=memory_key_padding_mask,
                    memory_level_start_index=level_start_index,
                    memory_spatial_shapes=spatial_shapes,
                    memory_pos=pos,  # positional queries of the encoder
                    self_attn_mask=tgt_mask,
                    cross_attn_mask=memory_mask)
            intermediate.append(self.norm(output))
            # human (box) update in the first decoder layers
            if layer_id < self.num_box_decoder_layers:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](
                    output)  # delta_x, delta_y, delta_w, delta_h
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid(
                )  # update the positional query by adding the offset delta_unsig
            # keypoint query expansion
            if layer_id == self.num_box_decoder_layers - 1:
                dn_output = output[:effect_num_dn]  # [num_dn, bs, 256]
                dn_new_reference_points = new_reference_points[:effect_num_dn]  # [num_dn, bs, 4]
                class_unselected = self.class_embed[layer_id](output)[
                    effect_num_dn:]  # [900, bs, 2]
                topk_proposals = torch.topk(class_unselected.max(-1)[0],
                                            inter_select_number,
                                            dim=0)[1]  # num_group
                # selected positional queries: keep the top num_group humans
                new_reference_points_for_body_box = torch.gather(
                    new_reference_points[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
                # selected output features
                new_output_for_body_box = torch.gather(
                    output[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(
                        1, 1, self.d_model))  # selected content queries
                bs = new_output_for_body_box.shape[1]
                # expand each per-human query into per-keypoint queries
                new_output_for_body_keypoint = new_output_for_body_box[:, None, :, :] \
                    + self.keypoint_embed.weight[None, :, None, :]  # keypoint content queries
                if self.num_body_points == 17:
                    delta_xy = self.pose_embed[-1](new_output_for_body_keypoint)[..., :2]
                else:
                    delta_xy = self.pose_embed[0](new_output_for_body_keypoint)[..., :2]
                body_keypoint_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                    delta_xy).sigmoid()  # [num_group, n_kp, bs, 2]
                num_queries, _, bs, _ = body_keypoint_xy.shape
                body_keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze(
                    -2).repeat(num_queries, 1, bs, 1).sigmoid()
                body_keypoint_wh = body_keypoint_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_keypoint = torch.cat(
                    (body_keypoint_xy, body_keypoint_wh), dim=-1)
                # lhand box
                new_output_for_lhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.lhand_box_embed.weight[None, :, None, :]
                delta_lhand_box_xy = self.bbox_hand_embed[-1](new_output_for_lhand_box)[..., :2]
                lhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                 delta_lhand_box_xy).sigmoid()
                num_queries, _, bs, _ = lhand_bbox_xy.shape
                lhand_bbox_wh_weight = self.hw_lhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                lhand_bbox_wh = lhand_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_lhand_bbox = torch.cat(
                    (lhand_bbox_xy, lhand_bbox_wh), dim=-1)
                # rhand box
                new_output_for_rhand_box = new_output_for_body_box[:, None, :, :] \
                    + self.rhand_box_embed.weight[None, :, None, :]
                delta_rhand_box_xy = self.bbox_hand_embed[-1](new_output_for_rhand_box)[..., :2]
                rhand_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                 delta_rhand_box_xy).sigmoid()
                num_queries, _, bs, _ = rhand_bbox_xy.shape
                rhand_bbox_wh_weight = self.hw_rhand_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                rhand_bbox_wh = rhand_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_rhand_bbox = torch.cat(
                    (rhand_bbox_xy, rhand_bbox_wh), dim=-1)
                # face box
                new_output_for_face_box = new_output_for_body_box[:, None, :, :] \
                    + self.face_box_embed.weight[None, :, None, :]
                delta_face_box_xy = self.bbox_face_embed[-1](new_output_for_face_box)[..., :2]
                face_bbox_xy = (inverse_sigmoid(
                    new_reference_points_for_body_box[..., :2][:, None]) +
                                delta_face_box_xy).sigmoid()
                num_queries, _, bs, _ = face_bbox_xy.shape
                face_bbox_wh_weight = self.hw_face_bbox.weight.unsqueeze(
                    0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
                face_bbox_wh = face_bbox_wh_weight * \
                    new_reference_points_for_body_box[..., 2:][:, None]
                new_reference_points_for_face_box = torch.cat(
                    (face_bbox_xy, face_bbox_wh), dim=-1)
                output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     new_output_for_body_keypoint,
                     new_output_for_lhand_box,
                     new_output_for_rhand_box,
                     new_output_for_face_box),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box.unsqueeze(1),
                     new_reference_points_for_keypoint,
                     new_reference_points_for_lhand_bbox,
                     new_reference_points_for_rhand_bbox,
                     new_reference_points_for_face_box), dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_new_reference_points, new_reference_points), dim=0)
                output = torch.cat((dn_output, output), dim=0)
                tgt_mask = tgt_mask2
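                # After this expansion every human contributes
                # num_body_points + 4 queries (body box, body keypoints, and
                # lhand/rhand/face boxes): 21 slots per group with the
                # defaults; tgt_mask2 must match this layout.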
            # human-to-part update: body keypoints plus hand and face boxes
            if self.num_box_decoder_layers <= layer_id < self.num_box_decoder_layers + 2:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                reference_before_sigmoid_body_bbox_dn = \
                    reference_before_sigmoid[:effect_num_dn]
                reference_before_sigmoid_bbox_body_norm = \
                    reference_before_sigmoid[effect_num_dn:][0::(self.num_body_points + 4)]
                output_bbox_body_dn = output[:effect_num_dn]
                output_bbox_body_norm = output[effect_num_dn:][
                    0::(self.num_body_points + 4)]
                delta_unsig_bbox_body_dn = self.bbox_embed[layer_id](output_bbox_body_dn)
                delta_unsig_bbox_body_norm = self.bbox_embed[layer_id](output_bbox_body_norm)
                outputs_unsig_body_bbox_dn = \
                    delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn
                outputs_unsig_body_bbox_norm = \
                    delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm
                new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid()
                new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid()
                # body kps
                output_body_kpt = output[effect_num_dn:].index_select(
                    0, torch.tensor(self.body_kpt_index_1,
                                    device=output.device))  # select kp content queries
                delta_xy_body_unsig = self.pose_embed[
                    layer_id - self.num_box_decoder_layers](output_body_kpt)  # kp box-center offsets
                outputs_body_kp_unsig = \
                    reference_before_sigmoid[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()  # select kp positional queries
                delta_hw_body_kp_unsig = self.pose_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_body_kpt)
                outputs_body_kp_unsig[..., :2] += delta_xy_body_unsig[..., :2]
                outputs_body_kp_unsig[..., 2:] += delta_hw_body_kp_unsig
                new_reference_points_for_body_keypoint = outputs_body_kp_unsig.sigmoid()
                bs = new_reference_points_for_body_box_norm.shape[1]
                # lhand box
                output_lhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)]
                delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 1)::(self.num_body_points + 4)].clone()
                delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_lhand_bbox_query)
                outputs_lhand_bbox_unsig[..., :2] += delta_xy_lhand_bbox_unsig[..., :2]
                outputs_lhand_bbox_unsig[..., 2:] += delta_hw_lhand_bbox_unsig
                new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid()
                # rhand box
                output_rhand_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)]
                delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 2)::(self.num_body_points + 4)].clone()
                delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_rhand_bbox_query)
                outputs_rhand_bbox_unsig[..., :2] += delta_xy_rhand_bbox_unsig[..., :2]
                outputs_rhand_bbox_unsig[..., 2:] += delta_hw_rhand_bbox_unsig
                new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid()
                # face box
                output_face_bbox_query = output[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)]
                delta_xy_face_bbox_unsig = self.bbox_face_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig = \
                    reference_before_sigmoid[effect_num_dn:][
                        (self.num_body_points + 3)::(self.num_body_points + 4)].clone()
                delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[
                    layer_id - self.num_box_decoder_layers](output_face_bbox_query)
                outputs_face_bbox_unsig[..., :2] += delta_xy_face_bbox_unsig[..., :2]
                outputs_face_bbox_unsig[..., 2:] += delta_hw_face_bbox_unsig
                new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid()
                new_reference_points_norm = torch.cat(
                    (new_reference_points_for_body_box_norm.unsqueeze(1),
                     new_reference_points_for_body_keypoint.view(
                         -1, self.num_body_points, bs, 4),
                     new_reference_points_for_lhand_box_norm.unsqueeze(1),
                     new_reference_points_for_rhand_box_norm.unsqueeze(1),
                     new_reference_points_for_face_box_norm.unsqueeze(1)),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box_dn,
                     new_reference_points_norm), dim=0)
            # hand and face keypoint query expansion
            if layer_id == self.num_hand_face_decoder_layers - 1:
                dn_body_output = output[:effect_num_dn]
                dn_reference_points_body = new_reference_points[:effect_num_dn]
                # body box
                new_reference_points_for_body_box = \
                    new_reference_points[effect_num_dn:][0::(self.num_body_points + 4)]
                new_output_for_body_box = output[effect_num_dn:][
                    0::(self.num_body_points + 4)]
                # body kp boxes
                new_output_body_for_body_keypoint = \
                    output[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()
                new_output_body_for_body_keypoint = new_output_body_for_body_keypoint.view(
                    self.num_group, self.num_body_points, bs, self.d_model)
                new_reference_points_for_body_keypoint = \
                    new_reference_points[effect_num_dn:].index_select(
                        0, torch.tensor(self.body_kpt_index_1,
                                        device=output.device)).clone()
                new_reference_points_for_body_keypoint = \
                    new_reference_points_for_body_keypoint.view(
                        self.num_group, self.num_body_points, bs, 4)
                new_reference_points_body = \
                    torch.cat((new_reference_points_for_body_box.unsqueeze(1),
                               new_reference_points_for_body_keypoint), dim=1)
                new_body_output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     new_output_body_for_body_keypoint), dim=1)
                # lhand box content and positional queries
                new_reference_points_for_lhand_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 1)::(self.num_body_points + 4)]
                new_output_for_lhand_box = output[effect_num_dn:][
                    (self.num_body_points + 1)::(self.num_body_points + 4)]
                # lhand keypoint query expansion
                new_output_for_lhand_keypoint = new_output_for_lhand_box[:, None, :, :] \
                    + self.lhand_keypoint_embed.weight[None, :, None, :]
                # regress each lhand kp center as a displacement relative to
                # the lhand box center
                delta_lhand_kp_xy = self.pose_hand_embed[-1](new_output_for_lhand_keypoint)[..., :2]
                # absolute box centers for the lhand kp boxes
                lhand_keypoint_xy = (
                    inverse_sigmoid(new_reference_points_for_lhand_box[..., :2][:, None])
                    + delta_lhand_kp_xy).sigmoid()
                num_queries, _, bs, _ = lhand_keypoint_xy.shape
                lhand_keypoint_wh_weight = \
                    self.hw_lhand_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                lhand_keypoint_wh = lhand_keypoint_wh_weight * \
                    new_reference_points_for_lhand_box[..., 2:][:, None]
                new_reference_points_for_lhand_keypoint = torch.cat(
                    (lhand_keypoint_xy, lhand_keypoint_wh), dim=-1)
                new_reference_points_lhand = \
                    torch.cat((new_reference_points_for_lhand_box.unsqueeze(1),
                               new_reference_points_for_lhand_keypoint), dim=1)
                new_lhand_output = torch.cat(
                    (new_output_for_lhand_box.unsqueeze(1),
                     new_output_for_lhand_keypoint), dim=1)
                # rhand
                new_reference_points_for_rhand_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 2)::(self.num_body_points + 4)]
                new_output_for_rhand_box = output[effect_num_dn:][
                    (self.num_body_points + 2)::(self.num_body_points + 4)]
                new_output_for_rhand_keypoint = new_output_for_rhand_box[:, None, :, :] \
                    + self.rhand_keypoint_embed.weight[None, :, None, :]
                delta_rhand_kp_xy = self.pose_hand_embed[-1](
                    new_output_for_rhand_keypoint)[..., :2]
                rhand_keypoint_xy = (
                    inverse_sigmoid(new_reference_points_for_rhand_box[..., :2][:, None])
                    + delta_rhand_kp_xy).sigmoid()
                num_queries, _, bs, _ = rhand_keypoint_xy.shape
                rhand_keypoint_wh_weight = \
                    self.hw_rhand_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                rhand_keypoint_wh = rhand_keypoint_wh_weight * \
                    new_reference_points_for_rhand_box[..., 2:][:, None]
                new_reference_points_for_rhand_keypoint = torch.cat(
                    (rhand_keypoint_xy, rhand_keypoint_wh), dim=-1)
                new_reference_points_rhand = \
                    torch.cat((new_reference_points_for_rhand_box.unsqueeze(1),
                               new_reference_points_for_rhand_keypoint), dim=1)
                new_rhand_output = torch.cat(
                    (new_output_for_rhand_box.unsqueeze(1),
                     new_output_for_rhand_keypoint), dim=1)
                # face
                new_reference_points_for_face_box = \
                    new_reference_points[effect_num_dn:][
                        (self.num_body_points + 3)::(self.num_body_points + 4)]
                new_output_for_face_box = output[effect_num_dn:][
                    (self.num_body_points + 3)::(self.num_body_points + 4)]
                new_output_for_face_keypoint = new_output_for_face_box[:, None, :, :] \
                    + self.face_keypoint_embed.weight[None, :, None, :]
                delta_face_kp_xy = self.pose_face_embed[-1](
                    new_output_for_face_keypoint)[..., :2]
                face_keypoint_xy = (
                    inverse_sigmoid(new_reference_points_for_face_box[..., :2][:, None])
                    + delta_face_kp_xy).sigmoid()
                num_queries, _, bs, _ = face_keypoint_xy.shape
                face_keypoint_wh_weight = \
                    self.hw_face_kps.weight.unsqueeze(0).unsqueeze(-2).repeat(
                        num_queries, 1, bs, 1).sigmoid()
                face_keypoint_wh = face_keypoint_wh_weight * \
                    new_reference_points_for_face_box[..., 2:][:, None]
                new_reference_points_for_face_keypoint = torch.cat(
                    (face_keypoint_xy, face_keypoint_wh), dim=-1)
                new_reference_points_face = torch.cat(
                    (new_reference_points_for_face_box.unsqueeze(1),
                     new_reference_points_for_face_keypoint), dim=1)
                new_face_output = torch.cat(
                    (new_output_for_face_box.unsqueeze(1),
                     new_output_for_face_keypoint), dim=1)
                new_reference_points = torch.cat(
                    (new_reference_points_body,
                     new_reference_points_lhand,
                     new_reference_points_rhand,
                     new_reference_points_face), dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_reference_points_body, new_reference_points), dim=0)
                output = torch.cat(
                    (new_body_output,
                     new_lhand_output,
                     new_rhand_output,
                     new_face_output), dim=1).flatten(0, 1)
                output = torch.cat((dn_body_output, output), dim=0)
                tgt_mask = tgt_mask3
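                # After this second expansion every human contributes
                # whole_body_points + 4 queries (box + keypoints for body,
                # both hands and face): 51 slots per group with the defaults;
                # tgt_mask3 must match this layout.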
if layer_id >= self.num_hand_face_decoder_layers: | |
reference_before_sigmoid = inverse_sigmoid(reference_points) | |
# body box | |
reference_before_sigmoid_body_bbox_dn = \ | |
reference_before_sigmoid[:effect_num_dn] | |
reference_before_sigmoid_bbox_body_norm = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
0::(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
output_bbox_body_dn=output[:effect_num_dn] | |
output_bbox_body_norm = output[effect_num_dn:][ | |
0::(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
delta_unsig_bbox_body_dn = self.bbox_embed[ | |
layer_id](output_bbox_body_dn) | |
delta_unsig_bbox_body_norm = self.bbox_embed[ | |
layer_id](output_bbox_body_norm) | |
outputs_unsig_body_bbox_dn = \ | |
delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn | |
outputs_unsig_body_bbox_norm = \ | |
delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm | |
new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid() | |
new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid() | |
# body kps | |
output_body_kpt=output[effect_num_dn:].index_select( | |
0,torch.tensor(self.body_kpt_index_2,device=output.device)) # select kp center content query | |
delta_xy_body_unsig = self.pose_embed[ | |
layer_id-self.num_box_decoder_layers](output_body_kpt) # offset of kp bbox center | |
outputs_body_kp_unsig = \ | |
reference_before_sigmoid[effect_num_dn:].index_select( | |
0, torch.tensor(self.body_kpt_index_2, device=output.device)).clone() # select kp position query | |
delta_hw_body_kp_unsig = self.pose_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_body_kpt) | |
outputs_body_kp_unsig[..., :2] += delta_xy_body_unsig[..., :2] | |
outputs_body_kp_unsig[..., 2:] += delta_hw_body_kp_unsig | |
new_reference_points_for_body_keypoint = outputs_body_kp_unsig.sigmoid() | |
bs=new_reference_points_for_body_box_norm.shape[1] | |
new_reference_points_for_body_keypoint = \ | |
new_reference_points_for_body_keypoint.view(-1,self.num_body_points,bs,4) | |
# lhand bbox | |
output_lhand_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + 1):: | |
(self.num_body_points + 2 * self.num_hand_points + self.num_face_points + 4)] | |
delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[ | |
layer_id-self.num_box_decoder_layers](output_lhand_bbox_query) | |
outputs_lhand_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + 1):: | |
(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)].clone() | |
delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_lhand_bbox_query) | |
outputs_lhand_bbox_unsig[..., :2] +=delta_xy_lhand_bbox_unsig[..., :2] | |
outputs_lhand_bbox_unsig[..., 2:] +=delta_hw_lhand_bbox_unsig | |
new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid() | |
# output_bbox_lhand_norm = output[effect_num_dn:][ | |
# (self.num_body_points + 1):: | |
# (self.num_body_points + 2 * self.num_hand_points + self.num_face_points + 4)] | |
# reference_before_sigmoid_bbox_lhand_norm = \ | |
# reference_before_sigmoid[effect_num_dn:][ | |
# (self.num_body_points + 1):: | |
# (self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
# delta_unsig_bbox_lhand_norm = self.bbox_hand_embed[ | |
# layer_id-self.num_box_decoder_layers](output_bbox_lhand_norm) | |
# outputs_unsig_lhand_bbox_norm = \ | |
# delta_unsig_bbox_lhand_norm + reference_before_sigmoid_bbox_lhand_norm | |
# new_reference_points_for_lhand_box_norm = outputs_unsig_lhand_bbox_norm.sigmoid() | |
# lhand kps | |
output_lhand_kpt_query=output[effect_num_dn:].index_select( | |
0,torch.tensor(self.lhand_kpt_index,device=output.device)) # select kp center content query | |
delta_xy_lhand_kpt_unsig = self.pose_hand_embed[ | |
layer_id-self.num_hand_face_decoder_layers](output_lhand_kpt_query) # offset of kp bbox center | |
outputs_lhand_kp_unsig = \ | |
reference_before_sigmoid[effect_num_dn:].index_select( | |
0, torch.tensor(self.lhand_kpt_index, device=output.device)).clone() # select kp position query | |
delta_hw_lhand_kp_unsig = self.pose_hand_hw_embed[ | |
layer_id-self.num_hand_face_decoder_layers](output_lhand_kpt_query) | |
outputs_lhand_kp_unsig[..., :2] += delta_xy_lhand_kpt_unsig[..., :2] | |
outputs_lhand_kp_unsig[..., 2:] += delta_hw_lhand_kp_unsig | |
new_reference_points_for_lhand_keypoint = outputs_lhand_kp_unsig.sigmoid() | |
bs=new_reference_points_for_lhand_box_norm.shape[1] | |
new_reference_points_for_lhand_keypoint = \ | |
new_reference_points_for_lhand_keypoint.view(-1,self.num_hand_points,bs,4) | |
# rhand bbox | |
output_rhand_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + self.num_hand_points + 2):: | |
(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[ | |
layer_id-self.num_box_decoder_layers](output_rhand_bbox_query) | |
outputs_rhand_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + self.num_hand_points + 2):: | |
(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)].clone() | |
delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_rhand_bbox_query) | |
outputs_rhand_bbox_unsig[..., :2] +=delta_xy_rhand_bbox_unsig[..., :2] | |
outputs_rhand_bbox_unsig[..., 2:] +=delta_hw_rhand_bbox_unsig | |
new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid() | |
# output_bbox_rhand_norm = output[effect_num_dn:][ | |
# (self.num_body_points + self.num_hand_points + 2):: | |
# (self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
# reference_before_sigmoid_bbox_rhand_norm = \ | |
# reference_before_sigmoid[effect_num_dn:][ | |
# (self.num_body_points + self.num_hand_points + 2):: | |
# (self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
# delta_unsig_bbox_rhand_norm = self.bbox_hand_embed[ | |
# layer_id-self.num_box_decoder_layers](output_bbox_rhand_norm) | |
# outputs_unsig_rhand_bbox_norm = \ | |
# delta_unsig_bbox_rhand_norm + reference_before_sigmoid_bbox_rhand_norm | |
# new_reference_points_for_rhand_box_norm = outputs_unsig_rhand_bbox_norm.sigmoid() | |
# rhand kps | |
output_rhand_kpt_query=output[effect_num_dn:].index_select( | |
0,torch.tensor(self.rhand_kpt_index,device=output.device)) # select kp center content query | |
delta_xy_rhand_kpt_unsig = self.pose_hand_embed[ | |
layer_id-self.num_hand_face_decoder_layers](output_rhand_kpt_query) # offset of kp bbox center | |
outputs_rhand_kp_unsig = \ | |
reference_before_sigmoid[effect_num_dn:].index_select( | |
0, torch.tensor(self.rhand_kpt_index, device=output.device)).clone() # select kp position query | |
delta_hw_rhand_kp_unsig = self.pose_hand_hw_embed[ | |
layer_id-self.num_hand_face_decoder_layers](output_rhand_kpt_query) | |
outputs_rhand_kp_unsig[..., :2] += delta_xy_rhand_kpt_unsig[..., :2] | |
outputs_rhand_kp_unsig[..., 2:] += delta_hw_rhand_kp_unsig | |
new_reference_points_for_rhand_keypoint = outputs_rhand_kp_unsig.sigmoid() | |
bs=new_reference_points_for_rhand_box_norm.shape[1] | |
new_reference_points_for_rhand_keypoint = \ | |
new_reference_points_for_rhand_keypoint.view(-1,self.num_hand_points,bs,4) | |
# face bbox | |
output_face_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + 2 * self.num_hand_points + 3):: | |
(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
delta_xy_face_bbox_unsig = self.bbox_face_embed[ | |
layer_id-self.num_box_decoder_layers](output_face_bbox_query) | |
outputs_face_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + 2 * self.num_hand_points + 3):: | |
(self.num_body_points+2*self.num_hand_points+self.num_face_points+4)].clone() | |
delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_face_bbox_query) | |
outputs_face_bbox_unsig[..., :2] +=delta_xy_face_bbox_unsig[..., :2] | |
outputs_face_bbox_unsig[..., 2:] +=delta_hw_face_bbox_unsig | |
new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid() | |
# output_bbox_face_norm = output[effect_num_dn:][ | |
# (self.num_body_points + 2 * self.num_hand_points + 3):: | |
# (self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
# reference_before_sigmoid_bbox_face_norm = \ | |
# reference_before_sigmoid[effect_num_dn:][ | |
# (self.num_body_points + 2 * self.num_hand_points + 3):: | |
# (self.num_body_points+2*self.num_hand_points+self.num_face_points+4)] | |
# delta_unsig_bbox_face_norm = self.bbox_face_embed[ | |
# layer_id-self.num_box_decoder_layers](output_bbox_face_norm) | |
# outputs_unsig_face_bbox_norm = \ | |
# delta_unsig_bbox_face_norm + reference_before_sigmoid_bbox_face_norm | |
# new_reference_points_for_face_box_norm = outputs_unsig_face_bbox_norm.sigmoid() | |
# face keypoints
output_face_kpt_query = output[effect_num_dn:].index_select(
0, torch.tensor(self.face_kpt_index, device=output.device)) # select the face keypoint content queries
delta_xy_face_kpt_unsig = self.pose_face_embed[
layer_id - self.num_hand_face_decoder_layers](output_face_kpt_query) # xy offset w.r.t. the keypoint reference
outputs_face_kp_unsig = \
reference_before_sigmoid[effect_num_dn:].index_select(
0, torch.tensor(self.face_kpt_index, device=output.device)).clone() # select the face keypoint position queries
delta_hw_face_kp_unsig = self.pose_face_hw_embed[
layer_id - self.num_hand_face_decoder_layers](output_face_kpt_query)
outputs_face_kp_unsig[..., :2] += delta_xy_face_kpt_unsig[..., :2]
outputs_face_kp_unsig[..., 2:] += delta_hw_face_kp_unsig
new_reference_points_for_face_keypoint = outputs_face_kp_unsig.sigmoid()
bs = new_reference_points_for_face_box_norm.shape[1]
new_reference_points_for_face_keypoint = \
new_reference_points_for_face_keypoint.view(-1, self.num_face_points, bs, 4)
new_reference_points_norm = torch.cat( | |
(new_reference_points_for_body_box_norm.unsqueeze(1), | |
new_reference_points_for_body_keypoint, | |
new_reference_points_for_lhand_box_norm.unsqueeze(1), | |
new_reference_points_for_lhand_keypoint, | |
new_reference_points_for_rhand_box_norm.unsqueeze(1), | |
new_reference_points_for_rhand_keypoint, | |
new_reference_points_for_face_box_norm.unsqueeze(1), | |
new_reference_points_for_face_keypoint, | |
), dim=1).flatten(0, 1)
new_reference_points = torch.cat( | |
(new_reference_points_for_body_box_dn, new_reference_points_norm), dim=0) | |
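# Per-group layout after this concatenation:
# [body box, body kps, lhand box, lhand kps, rhand box, rhand kps, face box, face kps],
# flattened to (nq, bs, 4) with the refined dn references prepended.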
if self.rm_detach and 'dec' in self.rm_detach: | |
reference_points = new_reference_points | |
else: | |
reference_points = new_reference_points.detach() | |
ref_points.append(new_reference_points) | |
return [[itm_out.transpose(0, 1) for itm_out in intermediate], | |
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]] | |
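# Clone a layer N times; with layer_share=True the same module instance is
# reused in every position (tied weights), otherwise each layer is an
# independent deep copy.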
def _get_clones(module, N, layer_share=False):
if layer_share:
return nn.ModuleList([module for _ in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def build_transformer(args): | |
if args.modelname == 'aios_smplx_box': | |
return Transformer_Box( | |
d_model=args.hidden_dim, | |
dropout=args.dropout, | |
nhead=args.nheads, | |
num_queries=args.num_queries, | |
dim_feedforward=args.dim_feedforward, | |
num_encoder_layers=args.enc_layers, | |
num_decoder_layers=args.dec_layers, | |
normalize_before=args.pre_norm, | |
return_intermediate_dec=True, | |
query_dim=args.query_dim, | |
activation=args.transformer_activation, | |
num_patterns=args.num_patterns, | |
modulate_hw_attn=True, | |
deformable_encoder=True, | |
deformable_decoder=True, | |
num_feature_levels=args.num_feature_levels, | |
enc_n_points=args.enc_n_points, | |
dec_n_points=args.dec_n_points, | |
learnable_tgt_init=True, | |
random_refpoints_xy=args.random_refpoints_xy, | |
two_stage_type=args.two_stage_type, | |
two_stage_learn_wh=args.two_stage_learn_wh, | |
two_stage_keep_all_tokens=args.two_stage_keep_all_tokens, | |
dec_layer_number=args.dec_layer_number, | |
rm_self_attn_layers=args.rm_self_attn_layers, | |
rm_detach=args.rm_detach, | |
decoder_sa_type=args.decoder_sa_type, | |
module_seq=args.decoder_module_seq, | |
embed_init_tgt=args.embed_init_tgt, | |
num_body_points=args.num_body_points, | |
num_hand_points=args.num_hand_points, | |
num_face_points=args.num_face_points, | |
num_box_decoder_layers=args.num_box_decoder_layers, | |
num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
num_group=args.num_group) | |
elif args.modelname == 'aios_smplx': | |
return Transformer( | |
d_model=args.hidden_dim, | |
dropout=args.dropout, | |
nhead=args.nheads, | |
num_queries=args.num_queries, | |
dim_feedforward=args.dim_feedforward, | |
num_encoder_layers=args.enc_layers, | |
num_decoder_layers=args.dec_layers, | |
normalize_before=args.pre_norm, | |
return_intermediate_dec=True, | |
query_dim=args.query_dim, | |
activation=args.transformer_activation, | |
num_patterns=args.num_patterns, | |
modulate_hw_attn=True, | |
deformable_encoder=True, | |
deformable_decoder=True, | |
num_feature_levels=args.num_feature_levels, | |
enc_n_points=args.enc_n_points, | |
dec_n_points=args.dec_n_points, | |
learnable_tgt_init=True, | |
random_refpoints_xy=args.random_refpoints_xy, | |
two_stage_type=args.two_stage_type, | |
two_stage_learn_wh=args.two_stage_learn_wh, | |
two_stage_keep_all_tokens=args.two_stage_keep_all_tokens, | |
dec_layer_number=args.dec_layer_number, | |
rm_self_attn_layers=args.rm_self_attn_layers, | |
rm_detach=args.rm_detach, | |
decoder_sa_type=args.decoder_sa_type, | |
module_seq=args.decoder_module_seq, | |
embed_init_tgt=args.embed_init_tgt, | |
num_body_points=args.num_body_points, | |
num_hand_points=args.num_hand_points, | |
num_face_points=args.num_face_points, | |
num_box_decoder_layers=args.num_box_decoder_layers, | |
num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
num_group=args.num_group) | |
else:
raise ValueError('unknown model name: {}'.format(args.modelname))
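# Minimal usage sketch (hypothetical config object; the field values shown
# here are illustrative, not the actual training recipe):
#   from types import SimpleNamespace
#   args = SimpleNamespace(modelname='aios_smplx', hidden_dim=256, ...)
#   transformer = build_transformer(args)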
class TransformerDecoder_Box(nn.Module): | |
def __init__( | |
self, | |
decoder_layer, | |
num_layers, | |
norm=None, | |
return_intermediate=False, | |
d_model=256, | |
query_dim=4, | |
modulate_hw_attn=False, | |
num_feature_levels=1, | |
deformable_decoder=False, | |
dec_layer_number=None, # number of queries each layer in decoder | |
dec_layer_share=False, | |
dec_layer_dropout_prob=None, | |
num_box_decoder_layers=2, | |
num_hand_face_decoder_layers=4, | |
num_body_points=0, | |
num_hand_points=0, | |
num_face_points=0, | |
num_dn=100, | |
num_group=100): | |
super().__init__() | |
if num_layers > 0: | |
self.layers = _get_clones(decoder_layer, | |
num_layers, | |
layer_share=dec_layer_share) | |
else: | |
self.layers = [] | |
self.num_layers = num_layers | |
self.norm = norm | |
self.return_intermediate = return_intermediate # True | |
assert return_intermediate, 'support return_intermediate only' | |
self.query_dim = query_dim # 4 | |
assert query_dim in [
2, 4
], 'query_dim should be 2 or 4, but got {}'.format(query_dim)
self.num_feature_levels = num_feature_levels # 4 | |
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
2) # maps the sine position embedding to a d_model query pos
if not deformable_decoder: | |
self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) | |
else: | |
self.query_pos_sine_scale = None | |
self.num_body_points = 0 # box-only decoder: keypoint counts are fixed to zero
self.num_hand_points = 0
self.num_face_points = 0
self.query_scale = None | |
# prediction heads; these are assigned externally after construction
self.bbox_embed = None
self.class_embed = None
# hand box heads
self.bbox_hand_embed = None
self.bbox_hand_hw_embed = None
# face box heads
self.bbox_face_embed = None
self.bbox_face_hw_embed = None
self.num_box_decoder_layers = num_box_decoder_layers # 2 | |
self.num_hand_face_decoder_layers = num_hand_face_decoder_layers | |
self.d_model = d_model | |
self.modulate_hw_attn = modulate_hw_attn | |
self.deformable_decoder = deformable_decoder | |
if not deformable_decoder and modulate_hw_attn: | |
self.ref_anchor_head = MLP(d_model, d_model, 2, 2) | |
else: | |
self.ref_anchor_head = None | |
self.box_pred_damping = None | |
self.dec_layer_number = dec_layer_number | |
if dec_layer_number is not None: | |
assert isinstance(dec_layer_number, list) | |
assert len(dec_layer_number) == num_layers | |
self.dec_layer_dropout_prob = dec_layer_dropout_prob | |
if dec_layer_dropout_prob is not None:
raise NotImplementedError('dec_layer_dropout_prob is not supported')
self.num_group = num_group | |
self.rm_detach = None | |
self.num_dn = num_dn | |
# Query layout for reference (indices within one group, half-open ranges):
# [0, 100): dn boxes (prepended before the matching groups);
# [0, 1): body box;
# [1, 18): body kps;
# [18, 19): lhand box;
# [19, 29): lhand kps;
# [29, 30): rhand box;
# [30, 40): rhand kps;
# [40, 41): face box;
# [41, 51): face kps
self.lhand_box_embed = nn.Embedding(1, d_model) | |
self.rhand_box_embed = nn.Embedding(1, d_model) | |
self.face_box_embed = nn.Embedding(1, d_model) | |
self.hw_lhand_bbox = nn.Embedding(1, 2) | |
self.hw_rhand_bbox = nn.Embedding(1, 2) | |
self.hw_face_bbox = nn.Embedding(1, 2) | |
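# The three part-box embeddings above act as learned content offsets that
# specialize a selected human query into lhand/rhand/face box queries; the
# hw embeddings model each part box size as a sigmoid fraction of the body box.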
def forward( | |
self, | |
tgt, | |
memory, | |
tgt_mask: Optional[Tensor] = None, | |
tgt_mask2: Optional[Tensor] = None, | |
tgt_mask3: Optional[Tensor] = None, | |
memory_mask: Optional[Tensor] = None, | |
tgt_key_padding_mask: Optional[Tensor] = None, | |
memory_key_padding_mask: Optional[Tensor] = None, | |
pos: Optional[Tensor] = None, | |
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, query_dim
# for memory | |
level_start_index: Optional[Tensor] = None, # num_levels | |
spatial_shapes: Optional[Tensor] = None, # num_levels, 2
valid_ratios: Optional[Tensor] = None, | |
): | |
output = tgt | |
intermediate = [] | |
reference_points = refpoints_unsigmoid.sigmoid() | |
ref_points = [reference_points] | |
effect_num_dn = self.num_dn if self.training else 0 | |
inter_select_number = self.num_group | |
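# During training the first num_dn queries are denoising (dn) queries;
# at inference effect_num_dn == 0 and all queries are matching queries.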
for layer_id, layer in enumerate(self.layers): | |
if self.deformable_decoder:
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] \
* torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :]
) # convert the box position query into a sine/cosine embedding
else: | |
query_sine_embed = gen_sineembed_for_position( | |
reference_points) # nq, bs, 256*2 | |
reference_points_input = None | |
raw_query_pos = self.ref_point_head( | |
query_sine_embed) # nq, bs, 256 | |
pos_scale = self.query_scale(
output) if self.query_scale is not None else 1 # identity scale when query_scale is unused
query_pos = pos_scale * raw_query_pos | |
if not self.deformable_decoder: | |
query_sine_embed = query_sine_embed[ | |
..., :self.d_model] * self.query_pos_sine_scale(output) | |
# modulated HW attentions | |
if not self.deformable_decoder and self.modulate_hw_attn:
refHW_cond = self.ref_anchor_head(
output).sigmoid() # nq, bs, 2
query_sine_embed[..., self.d_model // 2:] *= (
refHW_cond[..., 0] /
reference_points[..., 2]).unsqueeze(-1)
query_sine_embed[..., :self.d_model // 2] *= (
refHW_cond[..., 1] /
reference_points[..., 3]).unsqueeze(-1)
dropflag = False | |
if self.dec_layer_dropout_prob is not None: | |
prob = random.random() | |
if prob < self.dec_layer_dropout_prob[layer_id]: | |
dropflag = True | |
if not dropflag: | |
output = layer( | |
tgt=output, | |
tgt_query_pos=query_pos, | |
tgt_query_sine_embed=query_sine_embed, | |
tgt_key_padding_mask=tgt_key_padding_mask, | |
tgt_reference_points=reference_points_input, | |
memory=memory, # encoder output (the encoder-side content features)
memory_key_padding_mask=memory_key_padding_mask, | |
memory_level_start_index=level_start_index, | |
memory_spatial_shapes=spatial_shapes, | |
memory_pos=pos, # positional embeddings of the encoder memory
self_attn_mask=tgt_mask, | |
cross_attn_mask=memory_mask) | |
intermediate.append(self.norm(output)) | |
# human update | |
if layer_id < self.num_box_decoder_layers:
# reference_points: (nq, bs, 4)
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](
output) # (dx, dy, dw, dh) in unsigmoided space
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid() # refined positional query
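# Iterative box refinement (as in Deformable DETR): each layer predicts an
# offset on top of the previous inverse-sigmoid reference points.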
# kp query expansion | |
if layer_id == self.num_box_decoder_layers - 1: | |
dn_output = output[:effect_num_dn] # (num_dn, bs, d_model)
dn_new_reference_points = new_reference_points[:effect_num_dn] # (num_dn, bs, 4)
class_unselected = self.class_embed[layer_id](output)[
effect_num_dn:] # (nq, bs, num_classes)
topk_proposals = torch.topk(class_unselected.max(-1)[0],
inter_select_number,
dim=0)[1] # indices of the top num_group human queries
# gather the reference boxes of the selected queries
new_reference_points_for_body_box = torch.gather(
new_reference_points[effect_num_dn:], 0,
topk_proposals.unsqueeze(-1).repeat(
1, 1, 4)) # selected position queries
# gather the features of the selected queries
new_output_for_body_box = torch.gather(
output[effect_num_dn:], 0,
topk_proposals.unsqueeze(-1).repeat(
1, 1, self.d_model)) # selected content queries
bs = new_output_for_body_box.shape[1] | |
# for lhand bbox | |
new_output_for_lhand_box = new_output_for_body_box[:, None, :, :] \ | |
+ self.lhand_box_embed.weight[None, :, None, :] | |
delta_lhand_box_xy = self.bbox_hand_embed[-1](new_output_for_lhand_box)[..., :2] | |
lhand_bbox_xy = (inverse_sigmoid(
new_reference_points_for_body_box[..., :2][:, None]) +
delta_lhand_box_xy).sigmoid() # (num_group, 1, bs, 2)
num_queries, _, bs, _ = lhand_bbox_xy.shape | |
lhand_bbox_wh_weight = self.hw_lhand_bbox.weight.unsqueeze(0).unsqueeze( | |
-2).repeat(num_queries, 1, bs, 1).sigmoid() | |
lhand_bbox_wh = lhand_bbox_wh_weight * new_reference_points_for_body_box[ | |
..., 2:][:, None] | |
new_reference_points_for_lhand_bbox = torch.cat( | |
(lhand_bbox_xy, lhand_bbox_wh), dim=-1) | |
# for rhand bbox | |
new_output_for_rhand_box = new_output_for_body_box[:, None, :, :] \ | |
+ self.rhand_box_embed.weight[None, :, None, :] | |
delta_rhand_box_xy = self.bbox_hand_embed[-1](new_output_for_rhand_box)[..., :2] | |
rhand_bbox_xy = (inverse_sigmoid(
new_reference_points_for_body_box[..., :2][:, None]) +
delta_rhand_box_xy).sigmoid() # (num_group, 1, bs, 2)
num_queries, _, bs, _ = rhand_bbox_xy.shape | |
rhand_bbox_wh_weight = self.hw_rhand_bbox.weight.unsqueeze(0).unsqueeze( | |
-2).repeat(num_queries, 1, bs, 1).sigmoid() | |
rhand_bbox_wh = rhand_bbox_wh_weight * new_reference_points_for_body_box[ | |
..., 2:][:, None] | |
new_reference_points_for_rhand_bbox = torch.cat( | |
(rhand_bbox_xy, rhand_bbox_wh), dim=-1) | |
# for face bbox | |
new_output_for_face_box = new_output_for_body_box[:, None, :, :] \ | |
+ self.face_box_embed.weight[None, :, None, :] | |
delta_face_box_xy = self.bbox_face_embed[-1](new_output_for_face_box)[..., :2] | |
face_bbox_xy = (inverse_sigmoid(
new_reference_points_for_body_box[..., :2][:, None]) +
delta_face_box_xy).sigmoid() # (num_group, 1, bs, 2)
num_queries, _, bs, _ = face_bbox_xy.shape | |
face_bbox_wh_weight = self.hw_face_bbox.weight.unsqueeze(0).unsqueeze( | |
-2).repeat(num_queries, 1, bs, 1).sigmoid() | |
face_bbox_wh = face_bbox_wh_weight * new_reference_points_for_body_box[ | |
..., 2:][:, None] | |
new_reference_points_for_face_box = torch.cat( | |
(face_bbox_xy, face_bbox_wh), dim=-1) | |
output = torch.cat( | |
(new_output_for_body_box.unsqueeze(1), | |
new_output_for_lhand_box, | |
new_output_for_rhand_box, | |
new_output_for_face_box), | |
dim=1).flatten(0, 1) | |
new_reference_points = torch.cat( | |
(new_reference_points_for_body_box.unsqueeze(1), | |
new_reference_points_for_lhand_bbox, | |
new_reference_points_for_rhand_bbox, | |
new_reference_points_for_face_box), dim=1).flatten(0, 1)
new_reference_points = torch.cat((dn_new_reference_points, new_reference_points),dim=0) | |
output = torch.cat((dn_output, output), dim=0) | |
tgt_mask = tgt_mask2 | |
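# Each group now holds 4 queries: [body box, lhand box, rhand box, face box];
# from here on the wider self-attention mask (tgt_mask2) matches this
# expanded query layout.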
# human-to-part (hand/face box) update
if layer_id >= self.num_box_decoder_layers: | |
reference_before_sigmoid = inverse_sigmoid(reference_points) | |
reference_before_sigmoid_body_bbox_dn = reference_before_sigmoid[:effect_num_dn] | |
reference_before_sigmoid_bbox_body_norm = \ | |
reference_before_sigmoid[effect_num_dn:][0::(self.num_body_points+4)] | |
output_bbox_body_dn = output[:effect_num_dn]
output_bbox_body_norm = output[effect_num_dn:][ | |
0::(self.num_body_points+4)] | |
delta_unsig_bbox_body_dn = self.bbox_embed[ | |
layer_id](output_bbox_body_dn) | |
delta_unsig_bbox_body_norm = self.bbox_embed[ | |
layer_id](output_bbox_body_norm) | |
outputs_unsig_body_bbox_dn = delta_unsig_bbox_body_dn + reference_before_sigmoid_body_bbox_dn | |
outputs_unsig_body_bbox_norm = delta_unsig_bbox_body_norm + reference_before_sigmoid_bbox_body_norm | |
new_reference_points_for_body_box_dn = outputs_unsig_body_bbox_dn.sigmoid() | |
new_reference_points_for_body_box_norm = outputs_unsig_body_bbox_norm.sigmoid() | |
# lhand box | |
output_lhand_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + 1)::(self.num_body_points+4)] | |
delta_xy_lhand_bbox_unsig = self.bbox_hand_embed[ | |
layer_id-self.num_box_decoder_layers](output_lhand_bbox_query) | |
outputs_lhand_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + 1)::(self.num_body_points+4)].clone() | |
delta_hw_lhand_bbox_unsig = self.bbox_hand_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_lhand_bbox_query) | |
outputs_lhand_bbox_unsig[..., :2] += delta_xy_lhand_bbox_unsig[..., :2]
outputs_lhand_bbox_unsig[..., 2:] += delta_hw_lhand_bbox_unsig
new_reference_points_for_lhand_box_norm = outputs_lhand_bbox_unsig.sigmoid() | |
# rhand box | |
output_rhand_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + 2)::(self.num_body_points+4)] | |
delta_xy_rhand_bbox_unsig = self.bbox_hand_embed[ | |
layer_id-self.num_box_decoder_layers](output_rhand_bbox_query) | |
outputs_rhand_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + 2)::(self.num_body_points+4)].clone() | |
delta_hw_rhand_bbox_unsig = self.bbox_hand_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_rhand_bbox_query) | |
outputs_rhand_bbox_unsig[..., :2] += delta_xy_rhand_bbox_unsig[..., :2]
outputs_rhand_bbox_unsig[..., 2:] += delta_hw_rhand_bbox_unsig
new_reference_points_for_rhand_box_norm = outputs_rhand_bbox_unsig.sigmoid() | |
# face box | |
output_face_bbox_query = output[effect_num_dn:][ | |
(self.num_body_points + 3)::(self.num_body_points+4)] | |
delta_xy_face_bbox_unsig = self.bbox_face_embed[ | |
layer_id-self.num_box_decoder_layers](output_face_bbox_query) | |
outputs_face_bbox_unsig = \ | |
reference_before_sigmoid[effect_num_dn:][ | |
(self.num_body_points + 3)::(self.num_body_points+4)].clone() | |
delta_hw_face_bbox_unsig = self.bbox_face_hw_embed[ | |
layer_id-self.num_box_decoder_layers](output_face_bbox_query) | |
outputs_face_bbox_unsig[..., :2] += delta_xy_face_bbox_unsig[..., :2]
outputs_face_bbox_unsig[..., 2:] += delta_hw_face_bbox_unsig
new_reference_points_for_face_box_norm = outputs_face_bbox_unsig.sigmoid() | |
new_reference_points_norm = torch.cat( | |
(new_reference_points_for_body_box_norm.unsqueeze(1), | |
new_reference_points_for_lhand_box_norm.unsqueeze(1), | |
new_reference_points_for_rhand_box_norm.unsqueeze(1), | |
new_reference_points_for_face_box_norm.unsqueeze(1)), dim=1).flatten(0, 1)
new_reference_points = torch.cat(( | |
new_reference_points_for_body_box_dn, | |
new_reference_points_norm), dim=0) | |
if self.rm_detach and 'dec' in self.rm_detach: | |
reference_points = new_reference_points | |
else: | |
reference_points = new_reference_points.detach() | |
ref_points.append(new_reference_points) | |
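# Return per-layer decoder features and reference points, transposed to
# (bs, nq, ...); ref_points has one extra entry for the initial references.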
return [[itm_out.transpose(0, 1) for itm_out in intermediate], | |
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]] | |
class Transformer_Box(nn.Module): | |
def __init__( | |
self, | |
d_model=256, | |
nhead=8, | |
num_queries=300, | |
num_encoder_layers=6, | |
num_decoder_layers=6, | |
dim_feedforward=2048, | |
dropout=0.0, | |
activation='relu', | |
normalize_before=False, | |
return_intermediate_dec=False, | |
query_dim=4, | |
num_patterns=0, | |
modulate_hw_attn=False, | |
# for deformable encoder | |
deformable_encoder=False, | |
deformable_decoder=False, | |
num_feature_levels=1, | |
enc_n_points=4, | |
dec_n_points=4, | |
# init query | |
learnable_tgt_init=False, | |
random_refpoints_xy=False, | |
# two stage | |
two_stage_type='no', | |
two_stage_learn_wh=False, | |
two_stage_keep_all_tokens=False, | |
# evo of #anchors | |
dec_layer_number=None, | |
rm_self_attn_layers=None, | |
# for detach | |
rm_detach=None, | |
decoder_sa_type='sa', | |
module_seq=['sa', 'ca', 'ffn'], | |
# for pose | |
embed_init_tgt=False, | |
num_body_points=0, | |
num_hand_points=0, | |
num_face_points=0, | |
num_box_decoder_layers=2, | |
num_hand_face_decoder_layers=4, | |
num_group=100): | |
super().__init__() | |
self.num_feature_levels = num_feature_levels # 4 | |
self.num_encoder_layers = num_encoder_layers # 6 | |
self.num_decoder_layers = num_decoder_layers # 6 | |
self.deformable_encoder = deformable_encoder | |
self.deformable_decoder = deformable_decoder | |
self.two_stage_keep_all_tokens = two_stage_keep_all_tokens # False | |
self.num_queries = num_queries # 900 | |
self.random_refpoints_xy = random_refpoints_xy # False | |
assert query_dim == 4 | |
if num_feature_levels > 1: | |
assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1' | |
self.decoder_sa_type = decoder_sa_type # sa | |
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content'] | |
# choose encoder layer type | |
if deformable_encoder: | |
encoder_layer = DeformableTransformerEncoderLayer( | |
d_model, dim_feedforward, dropout, activation, | |
num_feature_levels, nhead, enc_n_points) | |
else:
raise NotImplementedError('only the deformable encoder is supported')
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None | |
self.encoder = TransformerEncoder( | |
encoder_layer, | |
num_encoder_layers, | |
encoder_norm, | |
d_model=d_model, | |
num_queries=num_queries, | |
deformable_encoder=deformable_encoder, | |
two_stage_type=two_stage_type) | |
# choose decoder layer type | |
if deformable_decoder: | |
decoder_layer = DeformableTransformerDecoderLayer( | |
d_model, | |
dim_feedforward, | |
dropout, | |
activation, | |
num_feature_levels, | |
nhead, | |
dec_n_points, | |
decoder_sa_type=decoder_sa_type, | |
module_seq=module_seq) | |
else:
raise NotImplementedError('only the deformable decoder is supported')
decoder_norm = nn.LayerNorm(d_model) | |
self.decoder = TransformerDecoder_Box( | |
decoder_layer, | |
num_decoder_layers, | |
decoder_norm, | |
return_intermediate=return_intermediate_dec, | |
d_model=d_model, | |
query_dim=query_dim, | |
modulate_hw_attn=modulate_hw_attn, | |
num_feature_levels=num_feature_levels, | |
deformable_decoder=deformable_decoder, | |
dec_layer_number=dec_layer_number, | |
num_body_points=num_body_points, | |
num_hand_points=num_hand_points, | |
num_face_points=num_face_points, | |
num_box_decoder_layers=num_box_decoder_layers, | |
num_hand_face_decoder_layers=num_hand_face_decoder_layers, | |
num_group=num_group, | |
num_dn=num_group, | |
) | |
self.d_model = d_model | |
self.nhead = nhead # 8 | |
self.dec_layers = num_decoder_layers # 6 | |
self.num_queries = num_queries # useful for single stage model only | |
self.num_patterns = num_patterns # 0 | |
if not isinstance(num_patterns, int):
warnings.warn('num_patterns should be int but got {}'.format(
type(num_patterns)))
self.num_patterns = 0
if self.num_patterns > 0: | |
assert two_stage_type == 'no' | |
self.patterns = nn.Embedding(self.num_patterns, d_model) | |
if num_feature_levels > 1: | |
if self.num_encoder_layers > 0: | |
self.level_embed = nn.Parameter( | |
torch.Tensor(num_feature_levels, d_model)) | |
else: | |
self.level_embed = None | |
self.learnable_tgt_init = learnable_tgt_init # true | |
assert learnable_tgt_init, 'learnable_tgt_init must be True for this model'
self.embed_init_tgt = embed_init_tgt # false | |
if (two_stage_type != 'no' and embed_init_tgt) or two_stage_type == 'no':
self.tgt_embed = nn.Embedding(self.num_queries, d_model) | |
nn.init.normal_(self.tgt_embed.weight.data) | |
else: | |
self.tgt_embed = None | |
# for two stage | |
self.two_stage_type = two_stage_type | |
self.two_stage_learn_wh = two_stage_learn_wh | |
assert two_stage_type in [
'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
], 'unknown two_stage_type {}'.format(two_stage_type)
if two_stage_type in [ | |
'standard', 'combine', 'enceachlayer', 'enclayer1' | |
]: | |
# anchor selection at the output of encoder | |
self.enc_output = nn.Linear(d_model, d_model) | |
self.enc_output_norm = nn.LayerNorm(d_model) | |
if two_stage_learn_wh: | |
self.two_stage_wh_embedding = nn.Embedding(1, 2) | |
else: | |
self.two_stage_wh_embedding = None | |
if two_stage_type in ['early', 'combine']: | |
# anchor selection at the output of backbone | |
self.enc_output_backbone = nn.Linear(d_model, d_model) | |
self.enc_output_norm_backbone = nn.LayerNorm(d_model) | |
if two_stage_type == 'no': | |
self.init_ref_points(num_queries) # init self.refpoint_embed | |
self.enc_out_class_embed = None | |
self.enc_out_bbox_embed = None | |
self.enc_out_pose_embed = None | |
# evolution of anchors | |
self.dec_layer_number = dec_layer_number | |
if dec_layer_number is not None: | |
if self.two_stage_type != 'no' or num_patterns == 0: | |
assert dec_layer_number[ | |
0] == num_queries, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})' | |
else: | |
assert dec_layer_number[ | |
0] == num_queries * num_patterns, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})' | |
self._reset_parameters() | |
self.rm_self_attn_layers = rm_self_attn_layers | |
if rm_self_attn_layers is not None: | |
# assert len(rm_self_attn_layers) == num_decoder_layers | |
print('Removing the self-attn in {} decoder layers'.format( | |
rm_self_attn_layers)) | |
for lid, dec_layer in enumerate(self.decoder.layers): | |
if lid in rm_self_attn_layers: | |
dec_layer.rm_self_attn_modules() | |
self.rm_detach = rm_detach | |
if self.rm_detach: | |
assert isinstance(rm_detach, list) | |
assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach]) | |
self.decoder.rm_detach = rm_detach | |
def _reset_parameters(self): | |
for p in self.parameters(): | |
if p.dim() > 1: | |
nn.init.xavier_uniform_(p) | |
for m in self.modules(): | |
if isinstance(m, MSDeformAttn): | |
m._reset_parameters() | |
if self.num_feature_levels > 1 and self.level_embed is not None: | |
nn.init.normal_(self.level_embed) | |
if self.two_stage_learn_wh: | |
nn.init.constant_(self.two_stage_wh_embedding.weight, | |
math.log(0.05 / (1 - 0.05))) | |
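# math.log(0.05 / (1 - 0.05)) is inverse_sigmoid(0.05): the learned wh
# prior starts at 5% of the image size.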
def get_valid_ratio(self, mask): | |
_, H, W = mask.shape | |
valid_H = torch.sum(~mask[:, :, 0], 1) | |
valid_W = torch.sum(~mask[:, 0, :], 1) | |
valid_ratio_h = valid_H.float() / H | |
valid_ratio_w = valid_W.float() / W | |
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) | |
return valid_ratio | |
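# Example: for a (H=2, W=4) padding mask whose last column is padded,
# valid_W = 3 and valid_H = 2, giving valid_ratio = [0.75, 1.0] (w, h order).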
def init_ref_points(self, use_num_queries): | |
self.refpoint_embed = nn.Embedding(use_num_queries, 4) | |
if self.random_refpoints_xy: | |
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1) | |
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid( | |
self.refpoint_embed.weight.data[:, :2]) | |
self.refpoint_embed.weight.data[:, :2].requires_grad = False | |
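# With random_refpoints_xy, the anchor xy coordinates are frozen at random
# inverse-sigmoid positions and only the wh components stay learnable.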
# srcs: multi-scale backbone features; refpoint_embed: optional reference points (prepended to the selected ones)
def forward(self, | |
srcs, | |
masks, | |
refpoint_embed, | |
pos_embeds, | |
tgt, | |
attn_mask=None, | |
attn_mask2=None, | |
attn_mask3=None): | |
# prepare input for encoder | |
src_flatten = [] | |
mask_flatten = [] | |
lvl_pos_embed_flatten = [] | |
spatial_shapes = [] | |
for lvl, (src, mask, pos_embed) in enumerate( | |
zip(srcs, masks, pos_embeds)): # for feature level | |
bs, c, h, w = src.shape | |
spatial_shape = (h, w) | |
spatial_shapes.append(spatial_shape) | |
src = src.flatten(2).transpose(1, 2) # bs, hw, c | |
mask = mask.flatten(1) # bs, hw | |
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c | |
if self.num_feature_levels > 1 and self.level_embed is not None: | |
lvl_pos_embed = pos_embed + self.level_embed[lvl].view( | |
1, 1, -1) # level_embed[lvl]: [256] | |
else: | |
lvl_pos_embed = pos_embed | |
lvl_pos_embed_flatten.append(lvl_pos_embed) | |
src_flatten.append(src) | |
mask_flatten.append(mask) | |
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c | |
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} | |
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, | |
1) # bs, \sum{hxw}, c | |
spatial_shapes = torch.as_tensor(spatial_shapes, | |
dtype=torch.long, | |
device=src_flatten.device) | |
level_start_index = torch.cat((spatial_shapes.new_zeros( | |
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) | |
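# e.g. spatial_shapes = [[64, 64], [32, 32]] -> level_start_index = [0, 4096]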
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) | |
# two stage | |
if self.two_stage_type in ['early', 'combine']: | |
output_memory, output_proposals = gen_encoder_output_proposals( | |
src_flatten, mask_flatten, spatial_shapes) | |
output_memory = self.enc_output_norm_backbone( | |
self.enc_output_backbone(output_memory)) | |
# gather boxes | |
topk = self.num_queries | |
enc_outputs_class = self.encoder.class_embed[0](output_memory) | |
enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], | |
topk, | |
dim=1)[1] # bs, nq | |
enc_refpoint_embed = torch.gather( | |
output_proposals, 1, | |
enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) | |
src_flatten = output_memory | |
else: | |
enc_topk_proposals = enc_refpoint_embed = None | |
######################################################### | |
# Begin Encoder | |
######################################################### | |
memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder( | |
src_flatten, | |
pos=lvl_pos_embed_flatten, | |
level_start_index=level_start_index, | |
spatial_shapes=spatial_shapes, | |
valid_ratios=valid_ratios, | |
key_padding_mask=mask_flatten, | |
ref_token_index=enc_topk_proposals, # bs, nq | |
ref_token_coord=enc_refpoint_embed, # bs, nq, 4 | |
) | |
######################################################### | |
# End Encoder | |
# - memory: bs, \sum{hw}, c | |
# - mask_flatten: bs, \sum{hw} | |
# - lvl_pos_embed_flatten: bs, \sum{hw}, c | |
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c) | |
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, 4) or (nenc, bs, nq, 4)
######################################################### | |
if self.two_stage_type in [ | |
'standard', 'combine', 'enceachlayer', 'enclayer1' | |
]: | |
if self.two_stage_learn_wh: | |
input_hw = self.two_stage_wh_embedding.weight[0] | |
else: | |
input_hw = None | |
output_memory, output_proposals = gen_encoder_output_proposals( | |
memory, mask_flatten, spatial_shapes, input_hw) | |
output_memory = self.enc_output_norm( | |
self.enc_output(output_memory)) | |
enc_outputs_class_unselected = self.enc_out_class_embed(
output_memory) # (bs, \sum{hw}, num_classes)
enc_outputs_coord_unselected = self.enc_out_bbox_embed( | |
output_memory | |
) + output_proposals # (bs, \sum{hw}, 4) unsigmoid | |
topk = self.num_queries | |
topk_proposals = torch.topk(
enc_outputs_class_unselected.max(-1)[0], topk,
dim=1)[1] # (bs, nq): coarse human query selection
# gather boxes | |
refpoint_embed_undetach = torch.gather( | |
enc_outputs_coord_unselected, 1, | |
topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid | |
refpoint_embed_ = refpoint_embed_undetach.detach() | |
init_box_proposal = torch.gather( | |
output_proposals, 1, | |
topk_proposals.unsqueeze(-1).repeat(1, 1, | |
4)).sigmoid() # sigmoid | |
# gather tgt | |
tgt_undetach = torch.gather( | |
output_memory, 1, | |
topk_proposals.unsqueeze(-1).repeat( | |
1, 1, self.d_model)) # selected content query | |
if self.embed_init_tgt: | |
tgt_ = self.tgt_embed.weight[:, None, :].repeat(
1, bs, 1).transpose(0, 1) # bs, nq, d_model
else: | |
tgt_ = tgt_undetach.detach() | |
if refpoint_embed is not None: | |
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
dim=1) # (bs, n_dn + num_queries, 4)
tgt = torch.cat([tgt, tgt_], dim=1) | |
else: | |
refpoint_embed, tgt = refpoint_embed_, tgt_ | |
elif self.two_stage_type == 'early': | |
refpoint_embed_undetach = self.enc_out_bbox_embed( | |
enc_intermediate_output[-1] | |
) + enc_refpoint_embed # unsigmoid, (bs, nq, 4) | |
refpoint_embed = refpoint_embed_undetach.detach()
tgt_undetach = enc_intermediate_output[-1] # bs, nq, d_model | |
tgt = tgt_undetach.detach() | |
elif self.two_stage_type == 'no': | |
tgt_ = self.tgt_embed.weight[:, None, :].repeat(
1, bs, 1).transpose(0, 1) # bs, nq, d_model
refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
1, bs, 1).transpose(0, 1) # bs, nq, 4
if refpoint_embed is not None: | |
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], | |
dim=1) | |
tgt = torch.cat([tgt, tgt_], dim=1) | |
else: | |
refpoint_embed, tgt = refpoint_embed_, tgt_ | |
# pat embed | |
if self.num_patterns > 0: | |
tgt_embed = tgt.repeat(1, self.num_patterns, 1) | |
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1) | |
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave( | |
self.num_queries, 1) # 1, n_q*n_pat, d_model | |
tgt = tgt_embed + tgt_pat | |
init_box_proposal = refpoint_embed_.sigmoid() | |
else: | |
raise NotImplementedError('unknown two_stage_type {}'.format( | |
self.two_stage_type)) | |
######################################################### | |
# Begin Decoder | |
######################################################### | |
hs, references = self.decoder( | |
tgt=tgt.transpose(0, 1), | |
memory=memory.transpose(0, 1), | |
memory_key_padding_mask=mask_flatten, | |
pos=lvl_pos_embed_flatten.transpose(0, 1), | |
refpoints_unsigmoid=refpoint_embed.transpose(0, 1), | |
level_start_index=level_start_index, | |
spatial_shapes=spatial_shapes, | |
valid_ratios=valid_ratios, | |
tgt_mask=attn_mask, | |
tgt_mask2=attn_mask2, | |
tgt_mask3=attn_mask3) | |
######################################################### | |
# End Decoder | |
# hs: n_dec, bs, nq, d_model | |
# references: n_dec+1, bs, nq, query_dim | |
######################################################### | |
######################################################### | |
# Begin postprocess | |
######################################################### | |
if self.two_stage_type == 'standard': | |
if self.two_stage_keep_all_tokens: | |
hs_enc = output_memory.unsqueeze(0) | |
ref_enc = enc_outputs_coord_unselected.unsqueeze(0) | |
init_box_proposal = output_proposals | |
else: | |
hs_enc = tgt_undetach.unsqueeze(0) | |
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0) | |
elif self.two_stage_type in ['combine', 'early']: | |
hs_enc = enc_intermediate_output | |
hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)), | |
dim=0) # nenc+1, bs, nq, c | |
n_layer_hs_enc = hs_enc.shape[0] | |
assert n_layer_hs_enc == self.num_encoder_layers + 1 | |
ref_enc = enc_intermediate_refpoints | |
ref_enc = torch.cat( | |
(ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)), | |
dim=0) # nenc+1, bs, nq, 4 | |
elif self.two_stage_type in ['enceachlayer', 'enclayer1']: | |
hs_enc = enc_intermediate_output | |
hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)), | |
dim=0) # nenc, bs, nq, c | |
n_layer_hs_enc = hs_enc.shape[0] | |
assert n_layer_hs_enc == self.num_encoder_layers | |
ref_enc = enc_intermediate_refpoints | |
ref_enc = torch.cat( | |
(ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)), | |
dim=0) # nenc, bs, nq, 4 | |
else: | |
hs_enc = ref_enc = None | |
return hs, references, hs_enc, ref_enc, init_box_proposal | |