import copy
import os
import math
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from util.misc import inverse_sigmoid
from .ops.modules import MSDeformAttn
from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position


class DeformableTransformerEncoderLayer(nn.Module):
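    """Encoder layer: multi-scale deformable self-attention over the flattened feature maps,
    followed by a feed-forward block, each with a residual connection and post-LayerNorm."""
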
    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation='relu',
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # self attention (multi-scale deformable)
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads,
                                      n_points)  # defaults: 256, 4, 8, 4
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                key_padding_mask=None):
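        """
        Shapes (assuming MSDeformAttn's batch-first convention):
            - src / pos: bs, sum(H_l * W_l), d_model
            - reference_points: bs, sum(H_l * W_l), n_levels, 2
            - spatial_shapes: n_levels, 2 with the (H_l, W_l) of each level
            - level_start_index: n_levels, start offset of each level in src
            - key_padding_mask: bs, sum(H_l * W_l)
        """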
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points,
                              src, spatial_shapes, level_start_index,
                              key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src = self.forward_ffn(src)
        return src


class DeformableTransformerDecoderLayer(nn.Module):
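    """Decoder layer: self-attention among the queries ('sa'), multi-scale deformable
    cross-attention to the encoder memory ('ca'), and a feed-forward block ('ffn'),
    each followed by a residual connection and LayerNorm."""
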
    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation='relu',
        n_levels=4,
        n_heads=8,
        n_points=4,
        decoder_sa_type='sa',
        module_seq=['sa', 'ca', 'ffn'],
    ):
        super().__init__()
        self.module_seq = module_seq
        assert sorted(module_seq) == ['ca', 'ffn', 'sa']

        # cross attention (multi-scale deformable)
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model,
                                               n_heads,
                                               dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation,
                                             d_model=d_ffn,
                                             batch_dim=1)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_proj = None
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa']

    def rm_self_attn_modules(self):
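        """Remove the self-attention sub-modules; forward() then skips the 'sa' step."""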
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],                              # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,             # pos for query: MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,      # pos for query: Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,      # nq, bs, 4
        # for memory
        memory: Optional[Tensor] = None,                    # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,     # num_levels, 2
        memory_pos: Optional[Tensor] = None,                # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,            # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,           # mask used for cross-attention
    ):
""" | |
Input: | |
- tgt/tgt_query_pos: nq, bs, d_model | |
- | |
""" | |
        assert cross_attn_mask is None

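        # self attention among the queries (skipped if rm_self_attn_modules() was called)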
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

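        # deformable cross attention: queries attend to the multi-scale memory (batch-first inside MSDeformAttn)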
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            tgt_reference_points.transpose(0, 1).contiguous(),
            memory.transpose(0, 1), memory_spatial_shapes,
            memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)
        return tgt


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
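

# Minimal construction sketch (not part of the original module): it only shows how these layers
# are typically instantiated with the default hyper-parameters above and stacked with _get_clones,
# which deep-copies the layer so the stacked copies do not share weights.
#
#   encoder_layer = DeformableTransformerEncoderLayer(d_model=256, d_ffn=1024, n_levels=4,
#                                                     n_heads=8, n_points=4)
#   encoder_layers = _get_clones(encoder_layer, 6)
#
#   decoder_layer = DeformableTransformerDecoderLayer(d_model=256, d_ffn=1024, n_levels=4,
#                                                     n_heads=8, n_points=4)
#   decoder_layers = _get_clones(decoder_layer, 6)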