# Source: AiOS / models/aios/transformer.py
# Last commit d7e58f0 ("update") by ttxskk; file size 120 kB.
import math, random
import copy
import os
from typing import Optional, List, Union
import warnings
from util.misc import inverse_sigmoid
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from .transformer_deformable import DeformableTransformerEncoderLayer, DeformableTransformerDecoderLayer
from .utils import gen_encoder_output_proposals, sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
from .ops.modules.ms_deform_attn import MSDeformAttn
import pdb
class Transformer(nn.Module):
    """Deformable encoder/decoder transformer used by the AiOS detector.

    ``forward`` flattens multi-scale feature maps, runs them through
    ``TransformerEncoder``, optionally selects the top-scoring encoder tokens
    as initial decoder queries and boxes (controlled by ``two_stage_type``),
    and finally runs ``TransformerDecoder``.

    Only the deformable encoder and decoder variants are implemented. The
    two-stage heads ``enc_out_class_embed`` / ``enc_out_bbox_embed`` /
    ``enc_out_pose_embed`` are initialised to ``None`` here. ``forward``
    calls the first two in the two-stage modes, so they are expected to be
    attached by the owning model.
    """

    def __init__(
            self,
            d_model=256,
            nhead=8,
            num_queries=300,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.0,
            activation='relu',
            normalize_before=False,
            return_intermediate_dec=False,
            query_dim=4,
            num_patterns=0,
            modulate_hw_attn=False,
            # for deformable encoder
            deformable_encoder=False,
            deformable_decoder=False,
            num_feature_levels=1,
            enc_n_points=4,
            dec_n_points=4,
            # init query
            learnable_tgt_init=False,
            random_refpoints_xy=False,
            # two stage
            two_stage_type='no',
            two_stage_learn_wh=False,
            two_stage_keep_all_tokens=False,
            # evo of #anchors
            dec_layer_number=None,
            rm_self_attn_layers=None,
            # for detach
            rm_detach=None,
            decoder_sa_type='sa',
            module_seq=None,
            # for pose
            embed_init_tgt=False,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_group=100):
        """Build the encoder, the decoder and the query/anchor embeddings.

        Most arguments are forwarded to the encoder/decoder layers. This
        constructor enforces the following constraints:

        * ``query_dim`` must be 4 and ``learnable_tgt_init`` must be True.
        * ``deformable_encoder`` and ``deformable_decoder`` must be True,
          otherwise ``NotImplementedError`` is raised.
          ``num_feature_levels > 1`` additionally asserts
          ``deformable_encoder``.
        * ``decoder_sa_type`` is one of ``'sa'``, ``'ca_label'``,
          ``'ca_content'``.
        * ``two_stage_type`` is one of ``'no'``, ``'standard'``, ``'early'``,
          ``'combine'``, ``'enceachlayer'``, ``'enclayer1'``.

        ``module_seq`` is passed to the decoder layer. ``None`` means
        ``['sa', 'ca', 'ffn']``, created fresh for each instance.
        """
        super().__init__()
        if module_seq is None:
            # Fresh list per instance instead of a shared mutable default.
            module_seq = ['sa', 'ca', 'ffn']

        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
        self.num_queries = num_queries
        self.random_refpoints_xy = random_refpoints_xy
        assert query_dim == 4

        if num_feature_levels > 1:
            assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'

        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']

        # ---- encoder (deformable only) ----
        if not deformable_encoder:
            raise NotImplementedError(
                'only deformable_encoder=True is supported')
        encoder_layer = DeformableTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation,
            num_feature_levels, nhead, enc_n_points)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            encoder_norm,
            d_model=d_model,
            num_queries=num_queries,
            deformable_encoder=deformable_encoder,
            two_stage_type=two_stage_type)

        # ---- decoder (deformable only) ----
        if not deformable_decoder:
            raise NotImplementedError(
                'only deformable_decoder=True is supported')
        decoder_layer = DeformableTransformerDecoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            dec_n_points,
            decoder_sa_type=decoder_sa_type,
            module_seq=module_seq)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            modulate_hw_attn=modulate_hw_attn,
            num_feature_levels=num_feature_levels,
            deformable_decoder=deformable_decoder,
            dec_layer_number=dec_layer_number,
            num_body_points=num_body_points,
            num_hand_points=num_hand_points,
            num_face_points=num_face_points,
            num_box_decoder_layers=num_box_decoder_layers,
            num_hand_face_decoder_layers=num_hand_face_decoder_layers,
            num_group=num_group,
            num_dn=num_group,
        )

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            # Emit a real warning (a bare ``Warning(...)`` expression is a
            # no-op) and fall back to no patterns.
            warnings.warn('num_patterns should be int but {}'.format(
                type(num_patterns)))
            self.num_patterns = 0

        if self.num_patterns > 0:
            assert two_stage_type == 'no'
            self.patterns = nn.Embedding(self.num_patterns, d_model)

        # Learned per-level offset added to the positional embedding.
        if num_feature_levels > 1 and self.num_encoder_layers > 0:
            self.level_embed = nn.Parameter(
                torch.Tensor(num_feature_levels, d_model))
        else:
            self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, 'why not learnable_tgt_init'
        self.embed_init_tgt = embed_init_tgt
        if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type
                                                            == 'no'):
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            # NOTE(review): _reset_parameters() below re-initialises every
            # 2-D parameter with Xavier, overwriting this normal init.
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # ---- two stage ----
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in [
            'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type in [
                'standard', 'combine', 'enceachlayer', 'enclayer1'
        ]:
            # anchor selection at the output of the encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None

        if two_stage_type in ['early', 'combine']:
            # anchor selection at the output of the backbone
            self.enc_output_backbone = nn.Linear(d_model, d_model)
            self.enc_output_norm_backbone = nn.LayerNorm(d_model)

        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # creates self.refpoint_embed

        # Heads attached externally; see the class docstring.
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_pose_embed = None

        # ---- evolution of anchors ----
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[
                    0] == num_queries, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
            else:
                assert dec_layer_number[
                    0] == num_queries * num_patterns, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'

        self._reset_parameters()

        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            print('Removing the self-attn in {} decoder layers'.format(
                rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        """Initialise parameters.

        Every parameter with more than one dimension gets Xavier-uniform
        init. ``MSDeformAttn`` modules then re-run their own reset,
        ``level_embed`` is drawn from a normal distribution, and the learned
        two-stage width/height embedding is set to logit(0.05).

        NOTE(review): this runs after ``tgt_embed`` and ``refpoint_embed``
        are created, so their custom inits (normal / random xy) are
        overwritten by Xavier here.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            nn.init.constant_(self.two_stage_wh_embedding.weight,
                              math.log(0.05 / (1 - 0.05)))

    def get_valid_ratio(self, mask):
        """Return the un-padded fraction of a ``(bs, H, W)`` padding mask.

        ``True`` entries are padding. Valid height is counted down the first
        column and valid width along the first row. The result has shape
        ``(bs, 2)`` and is ordered ``(w_ratio, h_ratio)``.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        return torch.stack([valid_ratio_w, valid_ratio_h], -1)

    def init_ref_points(self, use_num_queries):
        """Create ``self.refpoint_embed`` with ``use_num_queries`` 4-d anchors.

        With ``random_refpoints_xy`` the xy columns are drawn uniformly in
        (0, 1) and stored in inverse-sigmoid space.
        """
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)
        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2])
            # NOTE(review): this sets the flag on a temporary view of
            # ``.data``; it does not stop gradients to the xy columns.
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def forward(self,
                srcs,
                masks,
                refpoint_embed,
                pos_embeds,
                tgt,
                attn_mask=None,
                attn_mask2=None,
                attn_mask3=None):
        """Run the encoder, query selection and the decoder.

        Args:
            srcs: per-level feature maps, each ``(bs, c, h, w)``.
            masks: per-level boolean padding masks, each ``(bs, h, w)``;
                ``True`` marks padding.
            refpoint_embed: optional extra reference boxes (unsigmoided,
                ``(bs, n, 4)``) concatenated in front of the selected or
                learned ones, or ``None``.
            pos_embeds: per-level positional embeddings, each
                ``(bs, c, h, w)``.
            tgt: optional extra content queries ``(bs, n, d_model)``
                concatenated alongside ``refpoint_embed``, or ``None``.
            attn_mask, attn_mask2, attn_mask3: forwarded to the decoder as
                ``tgt_mask``, ``tgt_mask2``, ``tgt_mask3``.

        Returns:
            ``(hs, references, hs_enc, ref_enc, init_box_proposal)``.
            ``hs`` and ``references`` come from the decoder. ``hs_enc`` and
            ``ref_enc`` are the encoder-side outputs of the two-stage modes
            (``None`` for ``'no'``). ``init_box_proposal`` holds the initial
            proposals after sigmoid, except with
            ``two_stage_keep_all_tokens`` (see NOTE below).

        Raises:
            NotImplementedError: for an unknown ``two_stage_type``.
        """
        # ---- flatten the multi-scale inputs for the encoder ----
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shapes.append((h, w))
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                # add the learned embedding of this level
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
                    1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, sum(hw), c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, sum(hw)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
                                          1)  # bs, sum(hw), c
        spatial_shapes = torch.as_tensor(spatial_shapes,
                                         dtype=torch.long,
                                         device=src_flatten.device)
        # offset of each level's first token in the flattened sequence
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # ---- optional query selection on backbone tokens ----
        if self.two_stage_type in ['early', 'combine']:
            output_memory, output_proposals = gen_encoder_output_proposals(
                src_flatten, mask_flatten, spatial_shapes)
            output_memory = self.enc_output_norm_backbone(
                self.enc_output_backbone(output_memory))
            # gather boxes
            topk = self.num_queries
            enc_outputs_class = self.encoder.class_embed[0](output_memory)
            enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0],
                                            topk,
                                            dim=1)[1]  # bs, nq
            enc_refpoint_embed = torch.gather(
                output_proposals, 1,
                enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            src_flatten = output_memory
        else:
            enc_topk_proposals = enc_refpoint_embed = None

        # ---- encoder ----
        memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        # memory: bs, sum(hw), c
        # enc_intermediate_output / enc_intermediate_refpoints:
        #     None, or stacked per-layer features / coords of tracked tokens

        # ---- build decoder queries ----
        if self.two_stage_type in [
                'standard', 'combine', 'enceachlayer', 'enclayer1'
        ]:
            if self.two_stage_learn_wh:
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, input_hw)
            output_memory = self.enc_output_norm(
                self.enc_output(output_memory))
            enc_outputs_class_unselected = self.enc_out_class_embed(
                output_memory)
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory
            ) + output_proposals  # (bs, sum(hw), 4) unsigmoid
            topk = self.num_queries
            topk_proposals = torch.topk(
                enc_outputs_class_unselected.max(-1)[0], topk,
                dim=1)[1]  # bs, nq
            # gather boxes
            box_index = topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1, box_index)  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(output_proposals, 1,
                                             box_index).sigmoid()
            # gather content queries
            tgt_undetach = torch.gather(
                output_memory, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # bs, nq, d_model
            else:
                tgt_ = tgt_undetach.detach()
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_

        elif self.two_stage_type == 'early':
            refpoint_embed_undetach = self.enc_out_bbox_embed(
                enc_intermediate_output[-1]
            ) + enc_refpoint_embed  # unsigmoid, (bs, nq, 4)
            refpoint_embed = refpoint_embed_undetach.detach()
            tgt_undetach = enc_intermediate_output[-1]  # bs, nq, d_model
            tgt = tgt_undetach.detach()
            # Same definition as the 'standard' branch: the gathered raw
            # proposals after sigmoid. Without this the return below
            # raised UnboundLocalError.
            init_box_proposal = enc_refpoint_embed.sigmoid()

        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, 4
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
            # pattern embeddings
            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat
            init_box_proposal = refpoint_embed_.sigmoid()

        else:
            raise NotImplementedError('unknown two_stage_type {}'.format(
                self.two_stage_type))

        # ---- decoder (expects sequence-first tensors) ----
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            tgt_mask3=attn_mask3)

        # ---- post-process encoder-side outputs ----
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                # NOTE(review): unlike the other branches, these two are not
                # passed through sigmoid; confirm this is what callers expect.
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        elif self.two_stage_type in ['combine', 'early']:
            hs_enc = torch.cat(
                (enc_intermediate_output, tgt_undetach.unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, c
            assert hs_enc.shape[0] == self.num_encoder_layers + 1
            ref_enc = torch.cat(
                (enc_intermediate_refpoints,
                 refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, 4
        elif self.two_stage_type in ['enceachlayer', 'enclayer1']:
            hs_enc = torch.cat(
                (enc_intermediate_output, tgt_undetach.unsqueeze(0)),
                dim=0)  # nenc, bs, nq, c
            assert hs_enc.shape[0] == self.num_encoder_layers
            ref_enc = torch.cat(
                (enc_intermediate_refpoints,
                 refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc, bs, nq, 4
        else:
            hs_enc = ref_enc = None

        return hs, references, hs_enc, ref_enc, init_box_proposal
class TransformerEncoder(nn.Module):
    """Stack of encoder layers applied to flattened multi-scale tokens.

    Besides running the layers, the encoder can track a subset of token
    positions (``ref_token_index``) and collect their features after the
    input and after every layer except the last (used for auxiliary outputs
    by the caller). For ``two_stage_type`` ``'enceachlayer'`` /
    ``'enclayer1'`` the subset is re-selected inside ``forward`` from
    ``self.class_embed`` scores. ``class_embed`` is not created in
    ``__init__`` and must be attached externally before those modes run.
    """

    def __init__(
            self,
            encoder_layer,
            num_layers,
            norm=None,
            d_model=256,
            num_queries=300,
            deformable_encoder=False,
            enc_layer_share=False,
            enc_layer_dropout_prob=None,
            two_stage_type='no',
    ):
        """
        Args:
            encoder_layer: layer cloned ``num_layers`` times via
                ``_get_clones`` (discarded when ``num_layers`` is 0).
            num_layers: number of encoder layers.
            norm: optional module applied to the final output.
            d_model: token feature size.
            num_queries: number of tokens kept (top-k) when ``forward``
                selects ``ref_token_index`` itself.
            deformable_encoder: call layers with the deformable-attention
                arguments (reference points, spatial shapes, ...).
            enc_layer_share: forwarded to ``_get_clones`` as ``layer_share``.
            enc_layer_dropout_prob: optional list with one skip probability
                in [0, 1] per layer. Layers are only skipped in training
                mode.
            two_stage_type: ``'enceachlayer'`` / ``'enclayer1'`` create the
                projection and norm layers used for in-encoder selection.
        """
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(encoder_layer,
                                      num_layers,
                                      layer_share=enc_layer_share)
        else:
            self.layers = []
            del encoder_layer

        self.query_scale = None
        self.num_queries = num_queries
        self.deformable_encoder = deformable_encoder
        self.num_layers = num_layers
        self.norm = norm
        self.d_model = d_model

        self.enc_layer_dropout_prob = enc_layer_dropout_prob
        if enc_layer_dropout_prob is not None:
            assert isinstance(enc_layer_dropout_prob, list)
            assert len(enc_layer_dropout_prob) == num_layers
            for i in enc_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.two_stage_type = two_stage_type
        if two_stage_type in ['enceachlayer', 'enclayer1']:
            _proj_layer = nn.Linear(d_model, d_model)
            _norm_layer = nn.LayerNorm(d_model)
            if two_stage_type == 'enclayer1':
                # selection only after layer 0
                self.enc_norm = nn.ModuleList([_norm_layer])
                self.enc_proj = nn.ModuleList([_proj_layer])
            else:
                # selection after every layer except the last
                self.enc_norm = nn.ModuleList([
                    copy.deepcopy(_norm_layer) for i in range(num_layers - 1)
                ])
                self.enc_proj = nn.ModuleList([
                    copy.deepcopy(_proj_layer) for i in range(num_layers - 1)
                ])

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Pixel-centre reference points for every token of every level.

        Args:
            spatial_shapes: iterable of ``(H, W)`` per level.
            valid_ratios: ``[bs, num_level, 2]`` (w, h) valid fractions.
            device: device for the generated coordinates.

        Returns:
            ``[bs, sum(H*W), num_level, 2]`` (x, y) coordinates. Each point
            is normalised by its own level's valid extent and then rescaled
            by every level's valid ratio.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5,
                               H_ - 0.5,
                               H_,
                               dtype=torch.float32,
                               device=device),
                torch.linspace(0.5,
                               W_ - 0.5,
                               W_,
                               dtype=torch.float32,
                               device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] *
                                               H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] *
                                               W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self,
                src: Tensor,
                pos: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                key_padding_mask: Tensor,
                ref_token_index: Optional[Tensor] = None,
                ref_token_coord: Optional[Tensor] = None):
        """Run the encoder layers.

        Args:
            src: flattened tokens ``[bs, sum(hi*wi), d_model]``.
            pos: positional embedding for ``src``, same shape.
            spatial_shapes: ``(h, w)`` of each level, ``[num_level, 2]``.
            level_start_index: start offset of each level in the flattened
                sequence, ``[num_level]``.
            valid_ratios: ``[bs, num_level, 2]``.
            key_padding_mask: ``[bs, sum(hi*wi)]``.
            ref_token_index: optional ``[bs, nq]`` indices of tracked tokens.
                Must be ``None`` for two_stage_type ``'no'``, ``'standard'``,
                ``'enceachlayer'`` and ``'enclayer1'``.
            ref_token_coord: ``[bs, nq, 4]`` coordinates matching
                ``ref_token_index``.

        Returns:
            ``(output, intermediate_output, intermediate_ref)``. ``output``
            is the encoded tokens ``[bs, sum(hi*wi), d_model]``. When tokens
            are tracked, the other two are the stacked tracked-token
            features ``[n, bs, nq, d_model]`` and coordinates
            ``[n, bs, nq, 4]``; otherwise both are ``None``.
        """
        if self.two_stage_type in [
                'no', 'standard', 'enceachlayer', 'enclayer1'
        ]:
            assert ref_token_index is None

        output = src
        if self.num_layers > 0:
            if self.deformable_encoder:
                reference_points = self.get_reference_points(
                    spatial_shapes, valid_ratios, device=src.device)

        intermediate_output = []
        intermediate_ref = []
        if ref_token_index is not None:
            out_i = torch.gather(
                output, 1,
                ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
            intermediate_output.append(out_i)
            intermediate_ref.append(ref_token_coord)

        for layer_id, layer in enumerate(self.layers):
            # Layer dropout is a training-time regulariser only: never skip
            # layers in eval mode, so inference stays deterministic.
            dropflag = False
            if self.training and self.enc_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.enc_layer_dropout_prob[layer_id]:
                    dropflag = True

            if not dropflag:
                if self.deformable_encoder:
                    output = layer(src=output,
                                   pos=pos,
                                   reference_points=reference_points,
                                   spatial_shapes=spatial_shapes,
                                   level_start_index=level_start_index,
                                   key_padding_mask=key_padding_mask)
                else:
                    # non-deformable layers take sequence-first tensors
                    output = layer(
                        src=output.transpose(0, 1),
                        pos=pos.transpose(0, 1),
                        key_padding_mask=key_padding_mask).transpose(0, 1)

            # In-encoder token (re-)selection: after layer 0 for both
            # in-encoder modes, after every layer for 'enceachlayer', and
            # never after the last layer.
            if ((layer_id == 0 and self.two_stage_type in ['enceachlayer', 'enclayer1'])
                    or (self.two_stage_type == 'enceachlayer')) \
                    and (layer_id != self.num_layers - 1):
                output_memory, output_proposals = gen_encoder_output_proposals(
                    output, key_padding_mask, spatial_shapes)
                output_memory = self.enc_norm[layer_id](
                    self.enc_proj[layer_id](output_memory))
                # gather boxes
                topk = self.num_queries
                enc_outputs_class = self.class_embed[layer_id](output_memory)
                ref_token_index = torch.topk(enc_outputs_class.max(-1)[0],
                                             topk,
                                             dim=1)[1]  # bs, nq
                ref_token_coord = torch.gather(
                    output_proposals, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
                output = output_memory

            # features of the tracked tokens for auxiliary outputs
            if (layer_id !=
                    self.num_layers - 1) and ref_token_index is not None:
                out_i = torch.gather(
                    output, 1,
                    ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
                intermediate_output.append(out_i)
                intermediate_ref.append(ref_token_coord)

        if self.norm is not None:
            output = self.norm(output)

        if ref_token_index is not None:
            intermediate_output = torch.stack(
                intermediate_output)  # n, bs, nq, d_model
            intermediate_ref = torch.stack(intermediate_ref)
        else:
            intermediate_output = intermediate_ref = None

        return output, intermediate_output, intermediate_ref
class TransformerDecoder(nn.Module):
def __init__(
        self,
        decoder_layer,
        num_layers,
        norm=None,
        return_intermediate=False,
        d_model=256,
        query_dim=4,
        modulate_hw_attn=False,
        num_feature_levels=1,
        deformable_decoder=False,
        dec_layer_number=None,  # number of queries each layer in decoder
        dec_layer_share=False,
        dec_layer_dropout_prob=None,
        num_box_decoder_layers=2,
        num_hand_face_decoder_layers=4,
        num_body_points=17,
        num_hand_points=10,
        num_face_points=10,
        num_dn=100,
        num_group=100):
    """Build the decoder layer stack, embeddings and query index tables.

    Args:
        decoder_layer: layer cloned ``num_layers`` times via ``_get_clones``.
        num_layers: number of decoder layers (0 gives an empty list).
        norm: optional module stored as ``self.norm``.
        return_intermediate: must be True; only that mode is supported.
        d_model: hidden size of queries and embeddings.
        query_dim: 2 or 4 coordinates per reference point; sizes the
            ``ref_point_head`` input as ``query_dim // 2 * d_model``.
        modulate_hw_attn: builds ``ref_anchor_head`` for the
            non-deformable decoder.
        num_feature_levels: number of memory feature levels.
        deformable_decoder: whether the layers are deformable. The
            non-deformable case gets a ``query_pos_sine_scale`` MLP.
        dec_layer_number: optional list with one entry per layer.
        dec_layer_share: forwarded to ``_get_clones`` as ``layer_share``.
        dec_layer_dropout_prob: must be None; layer dropout is not
            implemented.
        num_box_decoder_layers, num_hand_face_decoder_layers: stored for
            ``forward``.
        num_body_points, num_hand_points, num_face_points: keypoint counts
            that size the embeddings and the index tables.
        num_dn: stored as ``self.num_dn``.
        num_group: number of query groups covered by the index tables.

    Raises:
        AssertionError: on unsupported ``return_intermediate`` /
            ``query_dim`` or a malformed ``dec_layer_number``.
        NotImplementedError: if ``dec_layer_dropout_prob`` is given.
    """
    super().__init__()
    if num_layers > 0:
        self.layers = _get_clones(decoder_layer,
                                  num_layers,
                                  layer_share=dec_layer_share)
    else:
        self.layers = []
    self.num_layers = num_layers
    self.norm = norm
    self.return_intermediate = return_intermediate
    assert return_intermediate, 'support return_intermediate only'
    self.query_dim = query_dim
    assert query_dim in [
        2, 4
    ], 'query_dim should be 2/4 but {}'.format(query_dim)
    self.num_feature_levels = num_feature_levels
    # Input width query_dim // 2 * d_model matches the sine position
    # embedding that forward feeds into this head.
    self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
    if not deformable_decoder:
        self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
    else:
        self.query_pos_sine_scale = None
    self.num_body_points = num_body_points
    self.num_hand_points = num_hand_points
    self.num_face_points = num_face_points
    self.query_scale = None

    # Prediction heads: created as None here; presumably attached by the
    # owning model before forward (TODO confirm).
    # body
    self.bbox_embed = None
    self.class_embed = None
    self.pose_embed = None
    self.pose_hw_embed = None
    # hands
    self.bbox_hand_embed = None
    self.bbox_hand_hw_embed = None
    self.pose_hand_embed = None
    self.pose_hand_hw_embed = None
    # face
    self.bbox_face_embed = None
    self.bbox_face_hw_embed = None
    self.pose_face_embed = None
    self.pose_face_hw_embed = None

    self.num_box_decoder_layers = num_box_decoder_layers
    self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
    self.d_model = d_model
    self.modulate_hw_attn = modulate_hw_attn
    self.deformable_decoder = deformable_decoder
    if not deformable_decoder and modulate_hw_attn:
        self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
    else:
        self.ref_anchor_head = None
    self.box_pred_damping = None

    self.dec_layer_number = dec_layer_number
    if dec_layer_number is not None:
        assert isinstance(dec_layer_number, list)
        assert len(dec_layer_number) == num_layers

    self.dec_layer_dropout_prob = dec_layer_dropout_prob
    if dec_layer_dropout_prob is not None:
        raise NotImplementedError('dec_layer_dropout_prob is not supported')

    self.num_group = num_group
    self.rm_detach = None
    self.num_dn = num_dn
    self.hw = nn.Embedding(self.num_body_points, 2)
    self.keypoint_embed = nn.Embedding(self.num_body_points, d_model)

    def _slot_indices(stride, start, stop):
        # Flat query indices over num_group groups of `stride` queries whose
        # position inside their group lies in [start, stop).
        return [
            x for x in range(self.num_group * stride)
            if start <= x % stride < stop
        ]

    nb = self.num_body_points
    nh = self.num_hand_points
    nf = self.num_face_points

    # Layout with nb + 4 queries per group: slot 0 and the last three slots
    # are excluded, leaving the nb body-keypoint slots [1, nb + 1).
    self.body_kpt_index_1 = _slot_indices(nb + 4, 1, nb + 1)

    self.whole_body_points = nb + nh * 2 + nf
    wb_stride = self.whole_body_points + 4
    # Whole-body layout, wb_stride queries per group (values in brackets are
    # for the defaults 17 / 10 / 10, i.e. 51 queries per group):
    #   0                              body box
    #   [1, nb + 1)                    body keypoints   [1, 18)
    #   nb + 1                         lhand box        18
    #   [nb + 2, nb + nh + 2)          lhand keypoints  [19, 29)
    #   nb + nh + 2                    rhand box        29
    #   [nb + nh + 3, nb + 2nh + 3)    rhand keypoints  [30, 40)
    #   nb + 2nh + 3                   face box         40
    #   [nb + 2nh + 4, ... + nf)       face keypoints   [41, 51)
    self.body_kpt_index_2 = _slot_indices(wb_stride, 1, nb + 1)
    self.lhand_kpt_index = _slot_indices(wb_stride, nb + 2, nb + nh + 2)
    self.rhand_kpt_index = _slot_indices(wb_stride, nb + nh + 3,
                                         nb + nh * 2 + 3)
    self.face_kpt_index = _slot_indices(wb_stride, nb + nh * 2 + 4,
                                        nb + nh * 2 + nf + 4)

    # learned content embeddings for the hand/face box queries
    self.lhand_box_embed = nn.Embedding(1, d_model)
    self.rhand_box_embed = nn.Embedding(1, d_model)
    self.face_box_embed = nn.Embedding(1, d_model)
    # learned 2-d (presumably width/height) values for hand/face boxes and
    # keypoints
    self.hw_lhand_bbox = nn.Embedding(1, 2)
    self.hw_rhand_bbox = nn.Embedding(1, 2)
    self.hw_face_bbox = nn.Embedding(1, 2)
    self.hw_lhand_kps = nn.Embedding(self.num_hand_points, 2)
    self.hw_rhand_kps = nn.Embedding(self.num_hand_points, 2)
    self.hw_face_kps = nn.Embedding(self.num_face_points, 2)
    # learned content embeddings for hand/face keypoint queries
    self.lhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
    self.rhand_keypoint_embed = nn.Embedding(self.num_hand_points, d_model)
    self.face_keypoint_embed = nn.Embedding(self.num_face_points, d_model)
def forward(
    self,
    tgt,
    memory,
    tgt_mask: Optional[Tensor] = None,
    tgt_mask2: Optional[Tensor] = None,
    tgt_mask3: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
    refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
    # for memory
    level_start_index: Optional[Tensor] = None,  # num_levels
    spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
    valid_ratios: Optional[Tensor] = None,
):
    """Run the staged whole-body decoder, refining reference boxes per layer.

    The decoder works in three stages selected by ``layer_id``:

    1. ``layer_id < num_box_decoder_layers``: every query is refined as a
       human box with ``bbox_embed``. On the last layer of this stage the
       top-``num_group`` queries (ranked by the max ``class_embed`` score)
       are kept and each one is expanded into a group of
       ``num_body_points + 4`` queries laid out as
       ``[body box, body kps..., lhand box, rhand box, face box]``.
       ``tgt_mask2`` is used as the self-attention mask from then on.
    2. ``num_box_decoder_layers <= layer_id < num_hand_face_decoder_layers``:
       the body box, body keypoints and the three part boxes are refined.
       On the last layer of this stage every hand/face box spawns its
       keypoint queries, giving groups of
       ``num_body_points + 2 * num_hand_points + num_face_points + 4``
       queries laid out as ``[body box, body kps, lhand box, lhand kps,
       rhand box, rhand kps, face box, face kps]``. ``tgt_mask3`` is used
       from then on.
    3. ``layer_id >= num_hand_face_decoder_layers``: every box and keypoint
       query of the whole-body layout is refined.

    While training, the first ``num_dn`` queries are denoising queries: only
    their body box is refined and they are never expanded.

    Args:
        tgt: initial content queries.
        memory: encoder output, passed to every decoder layer.
        tgt_mask, tgt_mask2, tgt_mask3: self-attention masks used in stage
            1, after the body expansion, and after the hand/face expansion.
        refpoints_unsigmoid: initial reference boxes before ``sigmoid``.
        valid_ratios: scales the reference points for the deformable decoder.
        Remaining arguments are forwarded unchanged to each decoder layer.

    Returns:
        ``[intermediate, ref_points]``. ``intermediate`` holds ``self.norm``
        of every layer's output. ``ref_points`` holds the initial reference
        points followed by the refined reference points of every layer.
        Every tensor is returned with ``transpose(0, 1)`` applied.
    """
    output = tgt
    intermediate = []
    reference_points = refpoints_unsigmoid.sigmoid()
    ref_points = [reference_points]
    effect_num_dn = self.num_dn if self.training else 0
    inter_select_number = self.num_group

    # Queries per human before / after the hand-face keypoint expansion.
    part_stride = self.num_body_points + 4
    whole_stride = (self.num_body_points + 2 * self.num_hand_points
                    + self.num_face_points + 4)
    # Offsets of the part-box queries inside one human group.
    lhand_off = self.num_body_points + 1  # same in both layouts
    rhand_off_part = self.num_body_points + 2
    face_off_part = self.num_body_points + 3
    rhand_off_whole = self.num_body_points + self.num_hand_points + 2
    face_off_whole = self.num_body_points + 2 * self.num_hand_points + 3

    for layer_id, layer in enumerate(self.layers):
        # ---- positional query from the current reference points ----
        if self.deformable_decoder:
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] \
                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios[None, :]
            # convert the position query from bbox to sine/cosine embedding
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :])
        else:
            query_sine_embed = gen_sineembed_for_position(
                reference_points)  # nq, bs, 256*2
            reference_points_input = None
        raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
        pos_scale = self.query_scale(
            output) if self.query_scale is not None else 1
        query_pos = pos_scale * raw_query_pos
        if not self.deformable_decoder:
            query_sine_embed = query_sine_embed[
                ..., :self.d_model] * self.query_pos_sine_scale(output)
        # modulated HW attentions
        if not self.deformable_decoder and self.modulate_hw_attn:
            refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
            query_sine_embed[..., self.d_model // 2:] *= (
                refHW_cond[..., 0] / reference_points[..., 2]).unsqueeze(-1)
            query_sine_embed[..., :self.d_model // 2] *= (
                refHW_cond[..., 1] / reference_points[..., 3]).unsqueeze(-1)

        # ---- optional random layer drop ----
        dropflag = (self.dec_layer_dropout_prob is not None
                    and random.random() < self.dec_layer_dropout_prob[layer_id])
        if not dropflag:
            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory=memory,  # encoder output (content of the encoder)
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,  # positional embedding of the encoder
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask)
        intermediate.append(self.norm(output))

        # ---- stage 1: human boxes only ----
        if layer_id < self.num_box_decoder_layers:
            reference_before_sigmoid = inverse_sigmoid(reference_points)
            delta_unsig = self.bbox_embed[layer_id](output)  # dx, dy, dw, dh
            new_reference_points = (
                delta_unsig + reference_before_sigmoid).sigmoid()

            # keep the top-scoring humans and expand them into body/part queries
            if layer_id == self.num_box_decoder_layers - 1:
                dn_output = output[:effect_num_dn]
                dn_new_reference_points = new_reference_points[:effect_num_dn]
                class_unselected = self.class_embed[layer_id](output)[
                    effect_num_dn:]
                topk_proposals = torch.topk(class_unselected.max(-1)[0],
                                            inter_select_number,
                                            dim=0)[1]
                new_reference_points_for_body_box = torch.gather(
                    new_reference_points[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
                new_output_for_body_box = torch.gather(
                    output[effect_num_dn:], 0,
                    topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))

                body_kpt_head = (self.pose_embed[-1]
                                 if self.num_body_points == 17
                                 else self.pose_embed[0])
                body_kpt_output, body_kpt_ref = self._expand_child_queries(
                    new_output_for_body_box, new_reference_points_for_body_box,
                    self.keypoint_embed, body_kpt_head, self.hw)
                lhand_box_output, lhand_box_ref = self._expand_child_queries(
                    new_output_for_body_box, new_reference_points_for_body_box,
                    self.lhand_box_embed, self.bbox_hand_embed[-1],
                    self.hw_lhand_bbox)
                rhand_box_output, rhand_box_ref = self._expand_child_queries(
                    new_output_for_body_box, new_reference_points_for_body_box,
                    self.rhand_box_embed, self.bbox_hand_embed[-1],
                    self.hw_rhand_bbox)
                face_box_output, face_box_ref = self._expand_child_queries(
                    new_output_for_body_box, new_reference_points_for_body_box,
                    self.face_box_embed, self.bbox_face_embed[-1],
                    self.hw_face_bbox)

                output = torch.cat(
                    (new_output_for_body_box.unsqueeze(1),
                     body_kpt_output,
                     lhand_box_output,
                     rhand_box_output,
                     face_box_output),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (new_reference_points_for_body_box.unsqueeze(1),
                     body_kpt_ref,
                     lhand_box_ref,
                     rhand_box_ref,
                     face_box_ref),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_new_reference_points, new_reference_points), dim=0)
                output = torch.cat((dn_output, output), dim=0)
                tgt_mask = tgt_mask2

        # ---- stage 2: body box, body kps and hand/face boxes ----
        # Bounded by num_hand_face_decoder_layers (previously hard-coded as
        # num_box_decoder_layers + 2) so it always ends where the nested
        # expansion below fires and stage 3 begins.
        if (self.num_box_decoder_layers <= layer_id
                < self.num_hand_face_decoder_layers):
            stage_id = layer_id - self.num_box_decoder_layers
            reference_before_sigmoid = inverse_sigmoid(reference_points)
            ref_norm = reference_before_sigmoid[effect_num_dn:]
            output_norm = output[effect_num_dn:]

            new_reference_points_for_body_box_dn = (
                self.bbox_embed[layer_id](output[:effect_num_dn])
                + reference_before_sigmoid[:effect_num_dn]).sigmoid()
            new_reference_points_for_body_box_norm = (
                self.bbox_embed[layer_id](output_norm[0::part_stride])
                + ref_norm[0::part_stride]).sigmoid()
            bs = new_reference_points_for_body_box_norm.shape[1]

            body_kpt_idx = torch.tensor(self.body_kpt_index_1,
                                        device=output.device)
            new_reference_points_for_body_keypoint = self._refine_query_group(
                output_norm.index_select(0, body_kpt_idx),
                ref_norm.index_select(0, body_kpt_idx),
                self.pose_embed[stage_id], self.pose_hw_embed[stage_id])
            new_reference_points_for_lhand_box_norm = self._refine_query_group(
                output_norm[lhand_off::part_stride],
                ref_norm[lhand_off::part_stride],
                self.bbox_hand_embed[stage_id],
                self.bbox_hand_hw_embed[stage_id])
            new_reference_points_for_rhand_box_norm = self._refine_query_group(
                output_norm[rhand_off_part::part_stride],
                ref_norm[rhand_off_part::part_stride],
                self.bbox_hand_embed[stage_id],
                self.bbox_hand_hw_embed[stage_id])
            new_reference_points_for_face_box_norm = self._refine_query_group(
                output_norm[face_off_part::part_stride],
                ref_norm[face_off_part::part_stride],
                self.bbox_face_embed[stage_id],
                self.bbox_face_hw_embed[stage_id])

            new_reference_points_norm = torch.cat(
                (new_reference_points_for_body_box_norm.unsqueeze(1),
                 new_reference_points_for_body_keypoint.view(
                     -1, self.num_body_points, bs, 4),
                 new_reference_points_for_lhand_box_norm.unsqueeze(1),
                 new_reference_points_for_rhand_box_norm.unsqueeze(1),
                 new_reference_points_for_face_box_norm.unsqueeze(1)),
                dim=1).flatten(0, 1)
            new_reference_points = torch.cat(
                (new_reference_points_for_body_box_dn,
                 new_reference_points_norm), dim=0)

            # hand / face keypoint query expansion
            if layer_id == self.num_hand_face_decoder_layers - 1:
                dn_body_output = output[:effect_num_dn]
                dn_reference_points_body = new_reference_points[:effect_num_dn]
                new_ref_norm = new_reference_points[effect_num_dn:]

                # body box and body keypoints are carried over as-is
                body_box_ref = new_ref_norm[0::part_stride]
                body_box_output = output_norm[0::part_stride]
                body_kpt_output = output_norm.index_select(
                    0, body_kpt_idx).view(
                        self.num_group, self.num_body_points, bs, self.d_model)
                body_kpt_ref = new_ref_norm.index_select(
                    0, body_kpt_idx).view(
                        self.num_group, self.num_body_points, bs, 4)

                # each part box spawns its keypoint queries
                lhand_box_ref = new_ref_norm[lhand_off::part_stride]
                lhand_box_output = output_norm[lhand_off::part_stride]
                lhand_kpt_output, lhand_kpt_ref = self._expand_child_queries(
                    lhand_box_output, lhand_box_ref,
                    self.lhand_keypoint_embed, self.pose_hand_embed[-1],
                    self.hw_lhand_kps)

                rhand_box_ref = new_ref_norm[rhand_off_part::part_stride]
                rhand_box_output = output_norm[rhand_off_part::part_stride]
                rhand_kpt_output, rhand_kpt_ref = self._expand_child_queries(
                    rhand_box_output, rhand_box_ref,
                    self.rhand_keypoint_embed, self.pose_hand_embed[-1],
                    self.hw_rhand_kps)

                face_box_ref = new_ref_norm[face_off_part::part_stride]
                face_box_output = output_norm[face_off_part::part_stride]
                face_kpt_output, face_kpt_ref = self._expand_child_queries(
                    face_box_output, face_box_ref,
                    self.face_keypoint_embed, self.pose_face_embed[-1],
                    self.hw_face_kps)

                new_reference_points = torch.cat(
                    (body_box_ref.unsqueeze(1), body_kpt_ref,
                     lhand_box_ref.unsqueeze(1), lhand_kpt_ref,
                     rhand_box_ref.unsqueeze(1), rhand_kpt_ref,
                     face_box_ref.unsqueeze(1), face_kpt_ref),
                    dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (dn_reference_points_body, new_reference_points), dim=0)
                output = torch.cat(
                    (body_box_output.unsqueeze(1), body_kpt_output,
                     lhand_box_output.unsqueeze(1), lhand_kpt_output,
                     rhand_box_output.unsqueeze(1), rhand_kpt_output,
                     face_box_output.unsqueeze(1), face_kpt_output),
                    dim=1).flatten(0, 1)
                output = torch.cat((dn_body_output, output), dim=0)
                tgt_mask = tgt_mask3

        # ---- stage 3: every box and keypoint of the whole-body layout ----
        if layer_id >= self.num_hand_face_decoder_layers:
            # box/body-kp heads are indexed from the end of stage 1,
            # hand/face keypoint heads from the end of stage 2
            box_stage_id = layer_id - self.num_box_decoder_layers
            kpt_stage_id = layer_id - self.num_hand_face_decoder_layers
            reference_before_sigmoid = inverse_sigmoid(reference_points)
            ref_norm = reference_before_sigmoid[effect_num_dn:]
            output_norm = output[effect_num_dn:]

            new_reference_points_for_body_box_dn = (
                self.bbox_embed[layer_id](output[:effect_num_dn])
                + reference_before_sigmoid[:effect_num_dn]).sigmoid()
            new_reference_points_for_body_box_norm = (
                self.bbox_embed[layer_id](output_norm[0::whole_stride])
                + ref_norm[0::whole_stride]).sigmoid()
            bs = new_reference_points_for_body_box_norm.shape[1]

            body_kpt_idx = torch.tensor(self.body_kpt_index_2,
                                        device=output.device)
            new_reference_points_for_body_keypoint = self._refine_query_group(
                output_norm.index_select(0, body_kpt_idx),
                ref_norm.index_select(0, body_kpt_idx),
                self.pose_embed[box_stage_id],
                self.pose_hw_embed[box_stage_id],
            ).view(-1, self.num_body_points, bs, 4)

            new_reference_points_for_lhand_box_norm = self._refine_query_group(
                output_norm[lhand_off::whole_stride],
                ref_norm[lhand_off::whole_stride],
                self.bbox_hand_embed[box_stage_id],
                self.bbox_hand_hw_embed[box_stage_id])
            lhand_kpt_idx = torch.tensor(self.lhand_kpt_index,
                                         device=output.device)
            new_reference_points_for_lhand_keypoint = self._refine_query_group(
                output_norm.index_select(0, lhand_kpt_idx),
                ref_norm.index_select(0, lhand_kpt_idx),
                self.pose_hand_embed[kpt_stage_id],
                self.pose_hand_hw_embed[kpt_stage_id],
            ).view(-1, self.num_hand_points, bs, 4)

            new_reference_points_for_rhand_box_norm = self._refine_query_group(
                output_norm[rhand_off_whole::whole_stride],
                ref_norm[rhand_off_whole::whole_stride],
                self.bbox_hand_embed[box_stage_id],
                self.bbox_hand_hw_embed[box_stage_id])
            rhand_kpt_idx = torch.tensor(self.rhand_kpt_index,
                                         device=output.device)
            new_reference_points_for_rhand_keypoint = self._refine_query_group(
                output_norm.index_select(0, rhand_kpt_idx),
                ref_norm.index_select(0, rhand_kpt_idx),
                self.pose_hand_embed[kpt_stage_id],
                self.pose_hand_hw_embed[kpt_stage_id],
            ).view(-1, self.num_hand_points, bs, 4)

            new_reference_points_for_face_box_norm = self._refine_query_group(
                output_norm[face_off_whole::whole_stride],
                ref_norm[face_off_whole::whole_stride],
                self.bbox_face_embed[box_stage_id],
                self.bbox_face_hw_embed[box_stage_id])
            face_kpt_idx = torch.tensor(self.face_kpt_index,
                                        device=output.device)
            new_reference_points_for_face_keypoint = self._refine_query_group(
                output_norm.index_select(0, face_kpt_idx),
                ref_norm.index_select(0, face_kpt_idx),
                self.pose_face_embed[kpt_stage_id],
                self.pose_face_hw_embed[kpt_stage_id],
            ).view(-1, self.num_face_points, bs, 4)

            new_reference_points_norm = torch.cat(
                (new_reference_points_for_body_box_norm.unsqueeze(1),
                 new_reference_points_for_body_keypoint,
                 new_reference_points_for_lhand_box_norm.unsqueeze(1),
                 new_reference_points_for_lhand_keypoint,
                 new_reference_points_for_rhand_box_norm.unsqueeze(1),
                 new_reference_points_for_rhand_keypoint,
                 new_reference_points_for_face_box_norm.unsqueeze(1),
                 new_reference_points_for_face_keypoint),
                dim=1).flatten(0, 1)
            new_reference_points = torch.cat(
                (new_reference_points_for_body_box_dn,
                 new_reference_points_norm), dim=0)

        # gradients flow through the reference points only if requested
        if self.rm_detach and 'dec' in self.rm_detach:
            reference_points = new_reference_points
        else:
            reference_points = new_reference_points.detach()
        ref_points.append(new_reference_points)

    return [[itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]]

@staticmethod
def _refine_query_group(query, ref_unsig, xy_head, hw_head):
    """Refine a group of (x, y, w, h) reference boxes from their content queries.

    Args:
        query: content queries of the group.
        ref_unsig: matching reference boxes before ``sigmoid``; the first two
            channels of the last dim are the centre, the rest the size.
        xy_head: module predicting the centre offset (only its first two
            output channels are used).
        hw_head: module predicting the size offset.

    Returns:
        The refined reference boxes after ``sigmoid``.
    """
    delta_xy = xy_head(query)
    delta_hw = hw_head(query)
    # clone: ref_unsig may be a view into the shared reference tensor
    refined = ref_unsig.clone()
    refined[..., :2] += delta_xy[..., :2]
    refined[..., 2:] += delta_hw
    return refined.sigmoid()

@staticmethod
def _expand_child_queries(parent_output, parent_ref, child_embed, xy_head,
                          hw_embed):
    """Spawn child queries (keypoints or part boxes) from parent box queries.

    Every parent content query is added to each row of ``child_embed.weight``.
    The child centre is the parent centre shifted (in logit space) by the
    first two channels of ``xy_head``. The child size is
    ``sigmoid(hw_embed.weight)`` times the parent width/height.

    Args:
        parent_output: parent content queries, ``(num_parents, bs, C)``.
        parent_ref: parent reference boxes after ``sigmoid``,
            ``(num_parents, bs, 4)``.
        child_embed: embedding with one row per child.
        xy_head: module predicting the child centre offset.
        hw_embed: embedding with one learned (w, h) ratio per child.

    Returns:
        ``(child_output, child_ref)`` with shapes
        ``(num_parents, num_children, bs, C)`` and
        ``(num_parents, num_children, bs, 4)``.
    """
    child_output = parent_output[:, None, :, :] \
        + child_embed.weight[None, :, None, :]
    delta_xy = xy_head(child_output)[..., :2]
    child_xy = (inverse_sigmoid(parent_ref[..., :2][:, None])
                + delta_xy).sigmoid()
    num_parents, _, bs, _ = child_xy.shape
    child_wh_weight = hw_embed.weight.unsqueeze(0).unsqueeze(-2).repeat(
        num_parents, 1, bs, 1).sigmoid()
    child_wh = child_wh_weight * parent_ref[..., 2:][:, None]
    return child_output, torch.cat((child_xy, child_wh), dim=-1)
def _get_clones(module, N, layer_share=False):
if layer_share:
return nn.ModuleList([module for i in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(args):
    """Build the whole-body transformer named by ``args.modelname``.

    ``'aios_smplx_box'`` selects ``Transformer_Box`` and ``'aios_smplx'``
    selects ``Transformer``. Both are built with the same keyword arguments,
    read from ``args`` plus a few hard-coded settings (intermediate decoder
    outputs, modulated HW attention, deformable encoder/decoder and learnable
    target init are always enabled).

    Raises:
        ValueError: if ``args.modelname`` is neither of the names above.
    """
    if args.modelname == 'aios_smplx_box':
        transformer_cls = Transformer_Box
    elif args.modelname == 'aios_smplx':
        transformer_cls = Transformer
    else:
        raise ValueError('Wrong Transformer type')

    return transformer_cls(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        num_queries=args.num_queries,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
        query_dim=args.query_dim,
        activation=args.transformer_activation,
        num_patterns=args.num_patterns,
        modulate_hw_attn=True,
        deformable_encoder=True,
        deformable_decoder=True,
        num_feature_levels=args.num_feature_levels,
        enc_n_points=args.enc_n_points,
        dec_n_points=args.dec_n_points,
        learnable_tgt_init=True,
        random_refpoints_xy=args.random_refpoints_xy,
        two_stage_type=args.two_stage_type,
        two_stage_learn_wh=args.two_stage_learn_wh,
        two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
        dec_layer_number=args.dec_layer_number,
        rm_self_attn_layers=args.rm_self_attn_layers,
        rm_detach=args.rm_detach,
        decoder_sa_type=args.decoder_sa_type,
        module_seq=args.decoder_module_seq,
        embed_init_tgt=args.embed_init_tgt,
        num_body_points=args.num_body_points,
        num_hand_points=args.num_hand_points,
        num_face_points=args.num_face_points,
        num_box_decoder_layers=args.num_box_decoder_layers,
        num_hand_face_decoder_layers=args.num_hand_face_decoder_layers,
        num_group=args.num_group)
class TransformerDecoder_Box(nn.Module):
    """Decoder that refines body boxes, then expands each kept body query
    into body / left-hand / right-hand / face box queries.

    Query layout handled by ``forward`` (dim 0, sequence-first):

    * ``[0, effect_num_dn)``: denoising body-box queries. These are present
      only in training (``effect_num_dn = num_dn if self.training else 0``).
    * the remaining queries: matching queries.

    Layers ``[0, num_box_decoder_layers)`` refine every query as a body box.
    After the last of those layers, the top ``num_group`` matching queries
    (ranked by ``class_embed``) are kept. Each kept query is expanded into 4
    consecutive queries ``[body, lhand, rhand, face]``, which gives a
    per-group stride of ``num_body_points + 4``. Later layers refine each of
    the 4 boxes with its own head.

    The prediction heads (``bbox_embed``, ``class_embed``,
    ``bbox_hand_embed``, ``bbox_hand_hw_embed``, ``bbox_face_embed``,
    ``bbox_face_hw_embed``) are initialised to ``None`` here. ``forward``
    indexes into them, so they must be assigned (presumably by the owning
    model) before it runs.
    """

    def __init__(
            self,
            decoder_layer,
            num_layers,
            norm=None,
            return_intermediate=False,
            d_model=256,
            query_dim=4,
            modulate_hw_attn=False,
            num_feature_levels=1,
            deformable_decoder=False,
            dec_layer_number=None,  # number of queries each layer in decoder
            dec_layer_share=False,
            dec_layer_dropout_prob=None,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_body_points=0,
            num_hand_points=0,
            num_face_points=0,
            num_dn=100,
            num_group=100):
        super().__init__()
        # NOTE: submodules are created in the same order as before, so that
        # parameter registration / random init order is unchanged.
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer,
                                      num_layers,
                                      layer_share=dec_layer_share)
        else:
            self.layers = []
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        assert return_intermediate, 'support return_intermediate only'
        self.query_dim = query_dim
        assert query_dim in [
            2, 4
        ], 'query_dim should be 2/4 but {}'.format(query_dim)
        self.num_feature_levels = num_feature_levels
        # Projects the sine embedding of the reference box to d_model.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
                                  2)
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None
        # This box decoder carries no keypoint queries. Every group is exactly
        # [body, lhand, rhand, face], so the stride used in forward()
        # (num_body_points + 4) must be 4. The num_*_points constructor
        # arguments are therefore intentionally ignored.
        self.num_body_points = 0
        self.num_hand_points = 0
        self.num_face_points = 0
        self.query_scale = None
        # Prediction heads; assigned externally before forward() is called.
        self.bbox_embed = None
        self.class_embed = None
        self.bbox_hand_embed = None
        self.bbox_hand_hw_embed = None
        self.bbox_face_embed = None
        self.bbox_face_hw_embed = None
        self.num_box_decoder_layers = num_box_decoder_layers
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder
        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None
        self.box_pred_damping = None
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers
        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            # Layer dropout is not supported by this decoder.
            raise NotImplementedError
        self.num_group = num_group
        self.rm_detach = None
        self.num_dn = num_dn
        # Learnable content offsets added to the body query to seed each part
        # query, and learnable (pre-sigmoid) width/height ratios of each part
        # box relative to its body box.
        self.lhand_box_embed = nn.Embedding(1, d_model)
        self.rhand_box_embed = nn.Embedding(1, d_model)
        self.face_box_embed = nn.Embedding(1, d_model)
        self.hw_lhand_bbox = nn.Embedding(1, 2)
        self.hw_rhand_bbox = nn.Embedding(1, 2)
        self.hw_face_bbox = nn.Embedding(1, 2)

    @staticmethod
    def _init_part_box(body_output, body_reference_points, content_embed,
                       hw_embed, xy_head):
        """Seed one part (hand/face) query and box from each body query.

        Args:
            body_output: selected body content queries ``(nq, bs, d_model)``.
            body_reference_points: selected body boxes ``(nq, bs, 4)``,
                sigmoid space.
            content_embed: ``nn.Embedding(1, d_model)`` added to the body
                query.
            hw_embed: ``nn.Embedding(1, 2)``. Its sigmoid scales the body w/h.
            xy_head: head whose first two outputs offset the body centre
                (pre-sigmoid).

        Returns:
            ``(part_output, part_reference_points)`` of shapes
            ``(nq, 1, bs, d_model)`` and ``(nq, 1, bs, 4)``.
        """
        part_output = body_output[:, None, :, :] \
            + content_embed.weight[None, :, None, :]
        delta_xy = xy_head(part_output)[..., :2]
        part_xy = (inverse_sigmoid(body_reference_points[..., :2][:, None]) +
                   delta_xy).sigmoid()
        num_queries, _, bs, _ = part_xy.shape
        part_wh_weight = hw_embed.weight.unsqueeze(0).unsqueeze(-2).repeat(
            num_queries, 1, bs, 1).sigmoid()
        part_wh = part_wh_weight * body_reference_points[..., 2:][:, None]
        return part_output, torch.cat((part_xy, part_wh), dim=-1)

    def _expand_body_queries(self, output, new_reference_points, layer_id,
                             effect_num_dn):
        """Keep the top ``num_group`` matching queries and expand each into
        ``[body, lhand, rhand, face]``. Denoising queries are kept unchanged
        in front.

        Returns the new ``(output, new_reference_points)``, both
        sequence-first.
        """
        dn_output = output[:effect_num_dn]
        dn_new_reference_points = new_reference_points[:effect_num_dn]
        class_unselected = self.class_embed[layer_id](output)[effect_num_dn:]
        topk_proposals = torch.topk(class_unselected.max(-1)[0],
                                    self.num_group,
                                    dim=0)[1]
        # selected position queries
        body_reference_points = torch.gather(
            new_reference_points[effect_num_dn:], 0,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
        # selected content queries
        body_output = torch.gather(
            output[effect_num_dn:], 0,
            topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))

        lhand_output, lhand_reference_points = self._init_part_box(
            body_output, body_reference_points, self.lhand_box_embed,
            self.hw_lhand_bbox, self.bbox_hand_embed[-1])
        rhand_output, rhand_reference_points = self._init_part_box(
            body_output, body_reference_points, self.rhand_box_embed,
            self.hw_rhand_bbox, self.bbox_hand_embed[-1])
        face_output, face_reference_points = self._init_part_box(
            body_output, body_reference_points, self.face_box_embed,
            self.hw_face_bbox, self.bbox_face_embed[-1])

        # Interleave per group: [body, lhand, rhand, face] for every group.
        output = torch.cat((body_output.unsqueeze(1), lhand_output,
                            rhand_output, face_output),
                           dim=1).flatten(0, 1)
        new_reference_points = torch.cat(
            (body_reference_points.unsqueeze(1), lhand_reference_points,
             rhand_reference_points, face_reference_points),
            dim=1).flatten(0, 1)
        new_reference_points = torch.cat(
            (dn_new_reference_points, new_reference_points), dim=0)
        output = torch.cat((dn_output, output), dim=0)
        return output, new_reference_points

    def _refine_part_box(self, output, reference_before_sigmoid, offset,
                         xy_head, hw_head):
        """Refine the part box at ``offset`` within every group.

        ``output`` and ``reference_before_sigmoid`` are the matching
        (non-dn) queries and pre-sigmoid boxes. The xy head's first two
        outputs shift the centre, and ``hw_head`` shifts width/height.
        Returns sigmoid-space boxes.
        """
        stride = self.num_body_points + 4
        part_query = output[offset::stride]
        delta_xy = xy_head(part_query)
        part_unsig = reference_before_sigmoid[offset::stride].clone()
        delta_hw = hw_head(part_query)
        part_unsig[..., :2] += delta_xy[..., :2]
        part_unsig[..., 2:] += delta_hw
        return part_unsig.sigmoid()

    def forward(
            self,
            tgt,
            memory,
            tgt_mask: Optional[Tensor] = None,
            tgt_mask2: Optional[Tensor] = None,
            tgt_mask3: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
            # for memory
            level_start_index: Optional[Tensor] = None,  # num_levels
            spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
            valid_ratios: Optional[Tensor] = None,
    ):
        """Run all decoder layers.

        Args:
            tgt: content queries, sequence-first ``(nq, bs, d_model)``.
            memory: encoder output (sequence-first).
            tgt_mask: self-attention mask for the initial query set.
            tgt_mask2: self-attention mask used after the body queries have
                been expanded into [body, lhand, rhand, face] groups.
            tgt_mask3: accepted for interface compatibility; unused here.
            refpoints_unsigmoid: pre-sigmoid reference boxes, sequence-first.
            valid_ratios: per-level valid ratios ``(bs, nlevel, 2)``.

        Returns:
            ``[intermediate, ref_points]``. Each is a list of batch-first
            tensors. ``intermediate[i]`` is the (normed) output of layer
            ``i``. ``ref_points[i]`` is the sigmoid reference box fed to
            layer ``i``, so ``ref_points`` has one extra entry at the end.
        """
        output = tgt
        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]
        # Denoising queries are only prepended in training mode.
        effect_num_dn = self.num_dn if self.training else 0
        for layer_id, layer in enumerate(self.layers):
            if self.deformable_decoder:
                if reference_points.shape[-1] == 4:
                    reference_points_input = reference_points[:, :, None] \
                        * torch.cat([valid_ratios, valid_ratios], -1)[None, :]  # nq, bs, nlevel, 4
                else:
                    assert reference_points.shape[-1] == 2
                    reference_points_input = reference_points[:, :, None] \
                        * valid_ratios[None, :]
                # convert the position query from box to sine/cosine embedding
                query_sine_embed = gen_sineembed_for_position(
                    reference_points_input[:, :, 0, :])
            else:
                query_sine_embed = gen_sineembed_for_position(
                    reference_points)  # nq, bs, 256*2
                reference_points_input = None
            raw_query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, 256
            pos_scale = self.query_scale(
                output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos
            if not self.deformable_decoder:
                query_sine_embed = query_sine_embed[
                    ..., :self.d_model] * self.query_pos_sine_scale(output)
            # modulated HW attentions
            if not self.deformable_decoder and self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()  # nq, bs, 2
                query_sine_embed[..., self.d_model // 2:] *= (
                    refHW_cond[..., 0] / reference_points[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (
                    refHW_cond[..., 1] / reference_points[..., 3]).unsqueeze(-1)

            dropflag = False
            if self.dec_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.dec_layer_dropout_prob[layer_id]:
                    dropflag = True
            if not dropflag:
                output = layer(
                    tgt=output,
                    tgt_query_pos=query_pos,
                    tgt_query_sine_embed=query_sine_embed,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    tgt_reference_points=reference_points_input,
                    memory=memory,  # encoder output (content of memory)
                    memory_key_padding_mask=memory_key_padding_mask,
                    memory_level_start_index=level_start_index,
                    memory_spatial_shapes=spatial_shapes,
                    memory_pos=pos,  # positional embedding of the encoder
                    self_attn_mask=tgt_mask,
                    cross_attn_mask=memory_mask)
            # norm defaults to None in the constructor; fall back to identity
            intermediate.append(
                self.norm(output) if self.norm is not None else output)

            # body-box stage: refine every query as a body box
            if layer_id < self.num_box_decoder_layers:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                # after the last body-box layer: select + expand into parts
                if layer_id == self.num_box_decoder_layers - 1:
                    output, new_reference_points = self._expand_body_queries(
                        output, new_reference_points, layer_id,
                        effect_num_dn)
                    tgt_mask = tgt_mask2

            # part stage: refine body, lhand, rhand and face boxes separately
            if layer_id >= self.num_box_decoder_layers:
                stride = self.num_body_points + 4
                part_layer_id = layer_id - self.num_box_decoder_layers
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                norm_output = output[effect_num_dn:]
                norm_reference_unsig = reference_before_sigmoid[effect_num_dn:]

                delta_body_dn = self.bbox_embed[layer_id](
                    output[:effect_num_dn])
                delta_body_norm = self.bbox_embed[layer_id](
                    norm_output[0::stride])
                body_box_dn = (delta_body_dn +
                               reference_before_sigmoid[:effect_num_dn]
                               ).sigmoid()
                body_box_norm = (delta_body_norm +
                                 norm_reference_unsig[0::stride]).sigmoid()

                lhand_box_norm = self._refine_part_box(
                    norm_output, norm_reference_unsig,
                    self.num_body_points + 1,
                    self.bbox_hand_embed[part_layer_id],
                    self.bbox_hand_hw_embed[part_layer_id])
                rhand_box_norm = self._refine_part_box(
                    norm_output, norm_reference_unsig,
                    self.num_body_points + 2,
                    self.bbox_hand_embed[part_layer_id],
                    self.bbox_hand_hw_embed[part_layer_id])
                face_box_norm = self._refine_part_box(
                    norm_output, norm_reference_unsig,
                    self.num_body_points + 3,
                    self.bbox_face_embed[part_layer_id],
                    self.bbox_face_hw_embed[part_layer_id])

                new_reference_points_norm = torch.cat(
                    (body_box_norm.unsqueeze(1),
                     lhand_box_norm.unsqueeze(1),
                     rhand_box_norm.unsqueeze(1),
                     face_box_norm.unsqueeze(1)), dim=1).flatten(0, 1)
                new_reference_points = torch.cat(
                    (body_box_dn, new_reference_points_norm), dim=0)

            if self.rm_detach and 'dec' in self.rm_detach:
                reference_points = new_reference_points
            else:
                reference_points = new_reference_points.detach()
            ref_points.append(new_reference_points)
        return [[itm_out.transpose(0, 1) for itm_out in intermediate],
                [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]]
class Transformer_Box(nn.Module):
    """Deformable encoder/decoder transformer for body/hand/face boxes.

    It flattens the multi-scale features and runs the deformable encoder.
    It then builds the initial decoder queries according to
    ``two_stage_type`` and runs ``TransformerDecoder_Box``.
    """

    def __init__(
            self,
            d_model=256,
            nhead=8,
            num_queries=300,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048,
            dropout=0.0,
            activation='relu',
            normalize_before=False,
            return_intermediate_dec=False,
            query_dim=4,
            num_patterns=0,
            modulate_hw_attn=False,
            # for deformable encoder
            deformable_encoder=False,
            deformable_decoder=False,
            num_feature_levels=1,
            enc_n_points=4,
            dec_n_points=4,
            # init query
            learnable_tgt_init=False,
            random_refpoints_xy=False,
            # two stage
            two_stage_type='no',
            two_stage_learn_wh=False,
            two_stage_keep_all_tokens=False,
            # evo of #anchors
            dec_layer_number=None,
            rm_self_attn_layers=None,
            # for detach
            rm_detach=None,
            decoder_sa_type='sa',
            module_seq=None,
            # for pose
            embed_init_tgt=False,
            num_body_points=0,
            num_hand_points=0,
            num_face_points=0,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            num_group=100):
        super().__init__()
        if module_seq is None:
            # Same default as the former ['sa', 'ca', 'ffn'] literal. A fresh
            # list per instance avoids sharing a mutable default argument.
            module_seq = ['sa', 'ca', 'ffn']
        # NOTE: submodules are created in the same order as before, so that
        # parameter registration / random init order is unchanged.
        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.deformable_encoder = deformable_encoder
        self.deformable_decoder = deformable_decoder
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
        self.num_queries = num_queries  # useful for single stage model only
        self.random_refpoints_xy = random_refpoints_xy
        assert query_dim == 4
        if num_feature_levels > 1:
            assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']

        # encoder: only the deformable variant is supported
        if not deformable_encoder:
            raise NotImplementedError
        encoder_layer = DeformableTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation,
            num_feature_levels, nhead, enc_n_points)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer,
            num_encoder_layers,
            encoder_norm,
            d_model=d_model,
            num_queries=num_queries,
            deformable_encoder=deformable_encoder,
            two_stage_type=two_stage_type)

        # decoder: only the deformable variant is supported
        if not deformable_decoder:
            raise NotImplementedError
        decoder_layer = DeformableTransformerDecoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            dec_n_points,
            decoder_sa_type=decoder_sa_type,
            module_seq=module_seq)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder_Box(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            d_model=d_model,
            query_dim=query_dim,
            modulate_hw_attn=modulate_hw_attn,
            num_feature_levels=num_feature_levels,
            deformable_decoder=deformable_decoder,
            dec_layer_number=dec_layer_number,
            num_body_points=num_body_points,
            num_hand_points=num_hand_points,
            num_face_points=num_face_points,
            num_box_decoder_layers=num_box_decoder_layers,
            num_hand_face_decoder_layers=num_hand_face_decoder_layers,
            num_group=num_group,
            num_dn=num_group,
        )
        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            # Previously `Warning(...)` only built an exception object and
            # discarded it, so nothing was ever reported.
            warnings.warn('num_patterns should be int but {}'.format(
                type(num_patterns)))
            self.num_patterns = 0
        if self.num_patterns > 0:
            assert two_stage_type == 'no'
            self.patterns = nn.Embedding(self.num_patterns, d_model)

        # Always define level_embed so later attribute reads cannot fail.
        if num_feature_levels > 1 and self.num_encoder_layers > 0:
            self.level_embed = nn.Parameter(
                torch.Tensor(num_feature_levels, d_model))
        else:
            self.level_embed = None

        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, 'why not learnable_tgt_init'
        self.embed_init_tgt = embed_init_tgt
        # learnable content queries: always for 'no', optional otherwise
        if two_stage_type == 'no' or embed_init_tgt:
            self.tgt_embed = nn.Embedding(self.num_queries, d_model)
            nn.init.normal_(self.tgt_embed.weight.data)
        else:
            self.tgt_embed = None

        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        assert two_stage_type in [
            'no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type in [
                'standard', 'combine', 'enceachlayer', 'enclayer1'
        ]:
            # anchor selection at the output of encoder
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        if two_stage_type in ['early', 'combine']:
            # anchor selection at the output of backbone
            self.enc_output_backbone = nn.Linear(d_model, d_model)
            self.enc_output_norm_backbone = nn.LayerNorm(d_model)
        if two_stage_type == 'no':
            self.init_ref_points(num_queries)  # init self.refpoint_embed
        # encoder-output heads; assigned externally before forward()
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_pose_embed = None

        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            if self.two_stage_type != 'no' or num_patterns == 0:
                assert dec_layer_number[
                    0] == num_queries, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
            else:
                assert dec_layer_number[
                    0] == num_queries * num_patterns, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'
        self._reset_parameters()
        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            print('Removing the self-attn in {} decoder layers'.format(
                rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()
        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any(i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach)
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        """Xavier-init all matrices, then re-init deformable attention,
        level embeddings and the learned two-stage width/height."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            # logit of 0.05: initial proposals cover 5% of width/height
            nn.init.constant_(self.two_stage_wh_embedding.weight,
                              math.log(0.05 / (1 - 0.05)))

    def get_valid_ratio(self, mask):
        """Return the ``(bs, 2)`` fraction ``[w, h]`` of unpadded pixels.

        ``mask`` is ``(bs, H, W)`` and is True on padding.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def init_ref_points(self, use_num_queries):
        """Create ``self.refpoint_embed`` with ``use_num_queries`` learned
        pre-sigmoid anchor boxes (optionally with random x/y)."""
        self.refpoint_embed = nn.Embedding(use_num_queries, 4)
        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2])
            # NOTE(review): this sets requires_grad on a temporary view of
            # `.data`, not on the Parameter, so x/y stay trainable.
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def forward(self,
                srcs,
                masks,
                refpoint_embed,
                pos_embeds,
                tgt,
                attn_mask=None,
                attn_mask2=None,
                attn_mask3=None):
        """Encode multi-scale features and decode body/hand/face boxes.

        Args:
            srcs: per-level features ``(bs, c, h, w)``.
            masks: per-level padding masks ``(bs, h, w)``.
            refpoint_embed: optional batch-first pre-sigmoid boxes (e.g.
                denoising queries). They are prepended to the selected
                queries along dim 1.
            pos_embeds: per-level positional embeddings ``(bs, c, h, w)``.
            tgt: optional batch-first content queries matching
                ``refpoint_embed``.
            attn_mask, attn_mask2, attn_mask3: forwarded to the decoder as
                ``tgt_mask``, ``tgt_mask2`` and ``tgt_mask3``.

        Returns:
            ``(hs, references, hs_enc, ref_enc, init_box_proposal)``.
        """
        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):  # for feature level
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)  # bs, hw
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
                    1, 1, -1)  # level_embed[lvl]: [256]
            else:
                lvl_pos_embed = pos_embed
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
                                          1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes,
                                         dtype=torch.long,
                                         device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # two stage: anchor selection at the backbone output
        if self.two_stage_type in ['early', 'combine']:
            output_memory, output_proposals = gen_encoder_output_proposals(
                src_flatten, mask_flatten, spatial_shapes)
            output_memory = self.enc_output_norm_backbone(
                self.enc_output_backbone(output_memory))
            # gather boxes
            topk = self.num_queries
            enc_outputs_class = self.encoder.class_embed[0](output_memory)
            enc_topk_proposals = torch.topk(enc_outputs_class.max(-1)[0],
                                            topk,
                                            dim=1)[1]  # bs, nq
            enc_refpoint_embed = torch.gather(
                output_proposals, 1,
                enc_topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            src_flatten = output_memory
        else:
            enc_topk_proposals = enc_refpoint_embed = None

        #########################################################
        # Begin Encoder
        #########################################################
        memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c
        # - mask_flatten: bs, \sum{hw}
        # - lvl_pos_embed_flatten: bs, \sum{hw}, c
        # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        #########################################################

        if self.two_stage_type in [
                'standard', 'combine', 'enceachlayer', 'enclayer1'
        ]:
            if self.two_stage_learn_wh:
                input_hw = self.two_stage_wh_embedding.weight[0]
            else:
                input_hw = None
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, input_hw)
            output_memory = self.enc_output_norm(
                self.enc_output(output_memory))
            enc_outputs_class_unselected = self.enc_out_class_embed(
                output_memory)
            enc_outputs_coord_unselected = self.enc_out_bbox_embed(
                output_memory
            ) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid
            topk = self.num_queries
            # coarse human query selection
            topk_proposals = torch.topk(
                enc_outputs_class_unselected.max(-1)[0], topk,
                dim=1)[1]  # bs, nq
            # gather boxes
            refpoint_embed_undetach = torch.gather(
                enc_outputs_coord_unselected, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
            refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1,
                                                    4)).sigmoid()  # sigmoid
            # gather tgt (selected content queries)
            tgt_undetach = torch.gather(
                output_memory, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # bs, nq, d_model
            else:
                tgt_ = tgt_undetach.detach()
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
        elif self.two_stage_type == 'early':
            refpoint_embed_undetach = self.enc_out_bbox_embed(
                enc_intermediate_output[-1]
            ) + enc_refpoint_embed  # unsigmoid, (bs, nq, 4)
            refpoint_embed = refpoint_embed_undetach.detach()
            tgt_undetach = enc_intermediate_output[-1]  # bs, nq, d_model
            tgt = tgt_undetach.detach()
            # Previously unset on this path, which made the final `return`
            # raise UnboundLocalError. Mirror the 'standard' path: the
            # sigmoid of the gathered (pre-sigmoid) encoder proposals.
            init_box_proposal = enc_refpoint_embed.sigmoid()
        elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, 4
            if refpoint_embed is not None:
                refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
                                           dim=1)
                tgt = torch.cat([tgt, tgt_], dim=1)
            else:
                refpoint_embed, tgt = refpoint_embed_, tgt_
            # pat embed
            if self.num_patterns > 0:
                tgt_embed = tgt.repeat(1, self.num_patterns, 1)
                refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
                tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
                    self.num_queries, 1)  # 1, n_q*n_pat, d_model
                tgt = tgt_embed + tgt_pat
            init_box_proposal = refpoint_embed_.sigmoid()
        else:
            raise NotImplementedError('unknown two_stage_type {}'.format(
                self.two_stage_type))

        #########################################################
        # Begin Decoder (decoder is sequence-first)
        #########################################################
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=attn_mask,
            tgt_mask2=attn_mask2,
            tgt_mask3=attn_mask3)
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################

        #########################################################
        # Begin postprocess
        #########################################################
        if self.two_stage_type == 'standard':
            if self.two_stage_keep_all_tokens:
                hs_enc = output_memory.unsqueeze(0)
                ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
                init_box_proposal = output_proposals
            else:
                hs_enc = tgt_undetach.unsqueeze(0)
                ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
        elif self.two_stage_type in ['combine', 'early']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc+1, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers + 1
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc+1, bs, nq, 4
        elif self.two_stage_type in ['enceachlayer', 'enclayer1']:
            hs_enc = enc_intermediate_output
            hs_enc = torch.cat((hs_enc, tgt_undetach.unsqueeze(0)),
                               dim=0)  # nenc, bs, nq, c
            n_layer_hs_enc = hs_enc.shape[0]
            assert n_layer_hs_enc == self.num_encoder_layers
            ref_enc = enc_intermediate_refpoints
            ref_enc = torch.cat(
                (ref_enc, refpoint_embed_undetach.sigmoid().unsqueeze(0)),
                dim=0)  # nenc, bs, nq, 4
        else:
            hs_enc = ref_enc = None
        return hs, references, hs_enc, ref_enc, init_box_proposal