|
import torch.nn as nn |
|
|
|
from networks.encoders import build_encoder |
|
from networks.layers.transformer import LongShortTermTransformer |
|
from networks.decoders import build_decoder |
|
from networks.layers.position import PositionEmbeddingSine |
|
|
|
|
|
class AOT(nn.Module):
    """Model that wires an encoder, an LSTT stack, a decoder and an ID bank.

    The constructor builds, in this order: the image encoder and a 1x1
    projection of its last feature map, a ``LongShortTermTransformer``
    (``self.LSTT``), the decoder, a strided convolution that embeds
    ``MODEL_MAX_OBJ_NUM + 1`` input channels (``self.patch_wise_id_bank``),
    a dropout applied to that embedding, and a sine position-embedding
    generator. All hyper-parameters are read from ``cfg``.
    """

    def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
        """Build all sub-modules from ``cfg`` and initialise weights.

        Args:
            cfg: configuration object exposing the ``MODEL_*`` / ``TRAIN_*``
                attributes read below.
            encoder: encoder name forwarded to ``build_encoder``.
            decoder: decoder name forwarded to ``build_decoder``.
        """
        super().__init__()

        self.cfg = cfg
        self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM
        self.epsilon = cfg.MODEL_EPSILON

        embed_dim = cfg.MODEL_ENCODER_EMBEDDING_DIM

        self.encoder = build_encoder(
            encoder,
            frozen_bn=cfg.MODEL_FREEZE_BN,
            freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT,
        )
        # 1x1 convolution mapping the last encoder stage to the embedding dim.
        self.encoder_projector = nn.Conv2d(
            cfg.MODEL_ENCODER_DIM[-1], embed_dim, kernel_size=1)

        self.LSTT = LongShortTermTransformer(
            cfg.MODEL_LSTT_NUM,
            embed_dim,
            cfg.MODEL_SELF_HEADS,
            cfg.MODEL_ATT_HEADS,
            emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
            droppath=cfg.TRAIN_LSTT_DROPPATH,
            lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
            st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
            droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
            droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
            intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
            return_intermediate=True,
        )

        # When intermediate LSTT outputs are decoded, the decoder input width
        # is the embedding dim times (number of LSTT layers + 1); otherwise it
        # is just the embedding dim.
        if cfg.MODEL_DECODER_INTERMEDIATE_LSTT:
            decoder_indim = embed_dim * (cfg.MODEL_LSTT_NUM + 1)
        else:
            decoder_indim = embed_dim

        self.decoder = build_decoder(
            decoder,
            in_dim=decoder_indim,
            out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
            decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
            hidden_dim=embed_dim,
            shortcut_dims=cfg.MODEL_ENCODER_DIM,
            align_corners=cfg.MODEL_ALIGN_CORNERS,
        )

        # Stride-16 convolution over the ID channels. With align-corners the
        # kernel is 17 wide with padding 8; otherwise 16 wide with no padding.
        if cfg.MODEL_ALIGN_CORNERS:
            bank_kernel, bank_padding = 17, 8
        else:
            bank_kernel, bank_padding = 16, 0
        self.patch_wise_id_bank = nn.Conv2d(
            cfg.MODEL_MAX_OBJ_NUM + 1,
            embed_dim,
            kernel_size=bank_kernel,
            stride=16,
            padding=bank_padding,
        )

        # In-place dropout applied to the ID embedding.
        self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True)

        self.pos_generator = PositionEmbeddingSine(embed_dim // 2,
                                                   normalize=True)

        self._init_weight()

    def get_pos_emb(self, x):
        """Return the output of the position-embedding generator for ``x``."""
        return self.pos_generator(x)

    def get_id_emb(self, x):
        """Embed ``x`` with the ID bank convolution, then apply ID dropout."""
        return self.id_dropout(self.patch_wise_id_bank(x))

    def encode_image(self, img):
        """Run the encoder and project its last output in place.

        Returns the sequence produced by ``self.encoder`` with its final
        element replaced by ``self.encoder_projector`` applied to it.
        """
        features = self.encoder(img)
        features[-1] = self.encoder_projector(features[-1])
        return features

    def decode_id_logits(self, lstt_emb, shortcuts):
        """Decode logits from the LSTT embeddings and encoder shortcuts.

        Each entry of ``lstt_emb`` is viewed as ``(h, w, n, c)`` and permuted
        to ``(n, c, h, w)``, where ``n, c, h, w`` is the size of
        ``shortcuts[-1]``. The decoder receives ``shortcuts[-1]`` followed by
        those reshaped embeddings, plus the full ``shortcuts`` list.
        """
        batch, channels, height, width = shortcuts[-1].size()
        decoder_inputs = [shortcuts[-1]]
        decoder_inputs.extend(
            layer_emb.view(height, width, batch, channels).permute(2, 3, 0, 1)
            for layer_emb in lstt_emb)
        return self.decoder(decoder_inputs, shortcuts)

    def LSTT_forward(self,
                     curr_embs,
                     long_term_memories,
                     short_term_memories,
                     curr_id_emb=None,
                     pos_emb=None,
                     size_2d=(30, 30)):
        """Run the LSTT on the last current embedding.

        ``curr_embs[-1]`` (size ``n, c, h, w``) is flattened to
        ``(h * w, n, c)`` before being passed to ``self.LSTT`` together with
        the remaining arguments, unchanged and in order.

        Returns:
            A 4-tuple ``(lstt_embs, curr_memories, long_memories,
            short_memories)``; the three memory entries are obtained by
            transposing the per-layer memory triples returned by the LSTT.
        """
        last_emb = curr_embs[-1]
        batch, channels, height, width = last_emb.size()
        # (n, c, h, w) -> (n, c, h*w) -> (h*w, n, c)
        tokens = last_emb.view(batch, channels, height * width)
        tokens = tokens.permute(2, 0, 1)

        lstt_embs, per_layer_memories = self.LSTT(
            tokens, long_term_memories, short_term_memories, curr_id_emb,
            pos_emb, size_2d)

        curr_mems, long_mems, short_mems = zip(*per_layer_memories)
        return lstt_embs, curr_mems, long_mems, short_mems

    def _init_weight(self):
        """Initialise the encoder projector and the ID bank weights.

        The projector uses Xavier-uniform init. The ID bank weight is viewed
        as a 2-D matrix with ``MODEL_ENCODER_EMBEDDING_DIM`` rows and
        initialised (in place, through the view) with an orthogonal matrix
        scaled by ``kernel_size ** -2``.
        """
        nn.init.xavier_uniform_(self.encoder_projector.weight)

        bank_kernel = 17 if self.cfg.MODEL_ALIGN_CORNERS else 16
        # NOTE: permute(0, 1) is the identity permutation on this 2-D view;
        # it is kept so the call sequence matches the original exactly.
        flat_bank = self.patch_wise_id_bank.weight.view(
            self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1)
        nn.init.orthogonal_(flat_bank, gain=bank_kernel**-2)
|
|