AiOS / models /aios /transformer_deformable.py
ttxskk
update
d7e58f0
raw
history blame
5.98 kB
import copy
import os
from typing import Optional, List
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from util.misc import inverse_sigmoid
from .ops.modules import MSDeformAttn
from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
import pdb
class DeformableTransformerEncoderLayer(nn.Module):
    """Encoder layer: multi-scale deformable self-attention followed by an FFN.

    Both sub-blocks use a residual connection, dropout on the branch output,
    and a LayerNorm applied after the residual sum.
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation='relu',
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        """Build the attention and feed-forward sub-modules.

        Args:
            d_model: Feature width used by the attention module, the
                LayerNorms and the FFN input/output.
            d_ffn: Hidden width of the feed-forward network.
            dropout: Dropout probability shared by every dropout module.
            activation: Activation identifier forwarded to the project helper
                ``_get_activation_fn``.
            n_levels: Forwarded to ``MSDeformAttn`` (second positional arg).
            n_heads: Forwarded to ``MSDeformAttn`` (third positional arg).
            n_points: Forwarded to ``MSDeformAttn`` (fourth positional arg).
        """
        super().__init__()

        # Self-attention branch (defaults: d_model=256, n_levels=4,
        # n_heads=8, n_points=4).
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(p=dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward branch: d_model -> d_ffn -> d_model.
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ffn)
        self.activation = _get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(in_features=d_ffn, out_features=d_model)
        self.dropout3 = nn.Dropout(p=dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src):
        """Apply the feed-forward branch with residual connection and norm."""
        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout2(hidden))
        residual_sum = src + self.dropout3(ffn_out)
        return self.norm2(residual_sum)

    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                key_padding_mask=None):
        """Run self-attention and then the FFN on ``src``.

        Args:
            src: Input features; also passed unmodified as the third
                argument of the attention call.
            pos: Positional embedding added to ``src`` to form the attention
                query, or None to use ``src`` as the query.
            reference_points: Passed through to ``self.self_attn``.
            spatial_shapes: Passed through to ``self.self_attn``.
            level_start_index: Passed through to ``self.self_attn``.
            key_padding_mask: Optional mask passed through to
                ``self.self_attn``.

        Returns:
            The transformed features (same variable flow as ``src``).
        """
        # Only the query carries the positional embedding; the third
        # argument of the attention call is the raw ``src``.
        query = self.with_pos_embed(src, pos)
        attn_out = self.self_attn(
            query,
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            key_padding_mask,
        )
        src = self.norm1(src + self.dropout1(attn_out))
        return self.forward_ffn(src)
class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, deformable cross-attention, then an FFN.

    Each sub-block uses a residual connection, dropout on the branch output,
    and a LayerNorm applied after the residual sum. The self-attention
    branch can be removed with :meth:`rm_self_attn_modules`, in which case
    ``forward`` skips it.
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation='relu',
        n_levels=4,
        n_heads=8,
        n_points=4,
        decoder_sa_type='sa',
        module_seq=('sa', 'ca', 'ffn'),
    ):
        """Build the attention and feed-forward sub-modules.

        Args:
            d_model: Feature width used by both attention modules, the
                LayerNorms and the FFN input/output.
            d_ffn: Hidden width of the feed-forward network.
            dropout: Dropout probability shared by every dropout module and
                by ``nn.MultiheadAttention``.
            activation: Activation identifier forwarded to the project helper
                ``_get_activation_fn``.
            n_levels: Forwarded to ``MSDeformAttn`` (second positional arg).
            n_heads: Number of heads for both attention modules.
            n_points: Forwarded to ``MSDeformAttn`` (fourth positional arg).
            decoder_sa_type: Self-attention variant; only ``'sa'`` is
                accepted. The default is ``'sa'`` because the previous
                default (``'ca'``) was rejected by the check below, which
                made default construction impossible.
            module_seq: Any ordering of ``'sa'``, ``'ca'`` and ``'ffn'``.
                It is validated and stored, but ``forward`` in this class
                always runs sa -> ca -> ffn.

        Raises:
            AssertionError: If ``module_seq`` is not a permutation of
                ``['sa', 'ca', 'ffn']`` or ``decoder_sa_type`` is not
                ``'sa'``.
        """
        super().__init__()
        # Keep a private list copy so instances never share (or mutate) the
        # caller's sequence or the default value.
        self.module_seq = list(module_seq)
        assert sorted(self.module_seq) == ['ca', 'ffn', 'sa'], (
            "module_seq must be a permutation of ['sa', 'ca', 'ffn'], "
            f"got {module_seq!r}")

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model,
                                               n_heads,
                                               dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation,
                                             d_model=d_ffn,
                                             batch_dim=1)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_proj = None
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa'], (
            f"decoder_sa_type must be 'sa', got {decoder_sa_type!r}")

    def rm_self_attn_modules(self):
        """Drop the self-attention branch; ``forward`` will then skip it."""
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Apply the feed-forward branch with residual connection and norm."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[
            Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[
            Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[
            Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[
            Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[
            Tensor] = None,  # mask used for cross-attention
    ):
        """Run self-attention (if present), cross-attention and the FFN.

        Args:
            tgt: Query features (annotated in this file as nq, bs, d_model).
            tgt_query_pos: Positional embedding added to ``tgt`` to form the
                attention queries, or None.
            tgt_query_sine_embed: Accepted for interface compatibility; not
                used by this method.
            tgt_key_padding_mask: Accepted for interface compatibility; not
                used by this method.
            tgt_reference_points: Reference points for cross-attention
                (annotated as nq, bs, 4). Required despite the None default,
                because it is transposed unconditionally.
            memory: Encoder features attended to by cross-attention
                (annotated as hw, bs, d_model). Required despite the None
                default.
            memory_key_padding_mask: Passed through to ``self.cross_attn``.
            memory_level_start_index: Passed through to ``self.cross_attn``.
            memory_spatial_shapes: Passed through to ``self.cross_attn``.
            memory_pos: Accepted for interface compatibility; not used by
                this method.
            self_attn_mask: ``attn_mask`` for the self-attention call.
            cross_attn_mask: Must be None; cross-attention masking is not
                supported here.

        Returns:
            The updated query features, in the same layout as ``tgt``.

        Raises:
            AssertionError: If ``cross_attn_mask`` is not None.
        """
        assert cross_attn_mask is None, (
            'cross_attn_mask is not supported by this decoder layer')

        # Self-attention branch; skipped after rm_self_attn_modules().
        if self.self_attn is not None:
            # Queries and keys carry the positional embedding; values do not.
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            # nn.MultiheadAttention returns (output, weights); keep output.
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        # Cross-attention branch: dims 0 and 1 of the query, reference points
        # and memory are swapped for the deformable attention call, and the
        # result is swapped back to the layout of ``tgt``.
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            tgt_reference_points.transpose(0, 1).contiguous(),
            memory.transpose(0, 1), memory_spatial_shapes,
            memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)
        return tgt
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        # deepcopy gives every clone its own parameters and buffers.
        clones.append(copy.deepcopy(module))
    return clones