import torch.nn as nn

from networks.layers.transformer import DualBranchGPM
from networks.models.aot import AOT
from networks.decoders import build_decoder


class DeAOT(AOT):
    """AOT variant that uses ``DualBranchGPM`` as its propagation module.

    After running ``AOT.__init__`` this subclass replaces ``self.LSTT`` and
    ``self.decoder``, adds a LayerNorm (``self.id_norm``) that is applied to
    the ID embeddings, and then calls ``self._init_weight()``.
    """

    def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
        """Build the parent AOT model, then override its LSTT and decoder.

        Args:
            cfg: configuration object; the ``MODEL_*`` / ``TRAIN_LSTT_*``
                attributes read below must be present.
            encoder: encoder name forwarded to ``AOT.__init__``.
            decoder: decoder name forwarded to ``AOT.__init__`` and to
                ``build_decoder``.
        """
        super().__init__(cfg, encoder, decoder)

        embed_dim = cfg.MODEL_ENCODER_EMBEDDING_DIM
        use_intermediate = cfg.MODEL_DECODER_INTERMEDIATE_LSTT

        # Construction order (LSTT -> decoder -> id_norm -> _init_weight) is
        # kept as-is: building modules may draw from the global RNG, so
        # reordering could change the initial weights.
        self.LSTT = DualBranchGPM(
            cfg.MODEL_LSTT_NUM,
            embed_dim,
            cfg.MODEL_SELF_HEADS,
            cfg.MODEL_ATT_HEADS,
            emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
            droppath=cfg.TRAIN_LSTT_DROPPATH,
            lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
            st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
            droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
            droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
            intermediate_norm=use_intermediate,
            return_intermediate=True,
        )

        # Decoder input width: embed_dim * (2 * MODEL_LSTT_NUM + 1) when the
        # intermediate LSTT outputs are decoded, otherwise embed_dim * 2.
        if use_intermediate:
            decoder_in_dim = embed_dim * (cfg.MODEL_LSTT_NUM * 2 + 1)
        else:
            decoder_in_dim = embed_dim * 2

        self.decoder = build_decoder(
            decoder,
            in_dim=decoder_in_dim,
            out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
            decode_intermediate_input=use_intermediate,
            hidden_dim=embed_dim,
            shortcut_dims=cfg.MODEL_ENCODER_DIM,
            align_corners=cfg.MODEL_ALIGN_CORNERS,
        )

        self.id_norm = nn.LayerNorm(embed_dim)

        self._init_weight()

    def decode_id_logits(self, lstt_emb, shortcuts):
        """Decode the deepest shortcut together with reshaped LSTT outputs.

        Args:
            lstt_emb: iterable of LSTT output tensors. Each is viewed as
                ``(h, w, n, -1)``, so it must hold ``h * w * n * C`` elements
                and be contiguous; presumably laid out as ``(h*w, n, C)`` —
                TODO confirm against ``DualBranchGPM``.
            shortcuts: encoder feature maps; the last one is 4-D and
                supplies ``n, h, w``.

        Returns:
            The output of ``self.decoder(decoder_inputs, shortcuts)``.
        """
        deepest = shortcuts[-1]
        batch, _, height, width = deepest.size()
        # (h, w, n, C) -> (n, C, h, w) to match the layout of `deepest`.
        reshaped = [
            emb.view(height, width, batch, -1).permute(2, 3, 0, 1)
            for emb in lstt_emb
        ]
        return self.decoder([deepest] + reshaped, shortcuts)

    def get_id_emb(self, x):
        """Embed ``x`` with the ID bank, LayerNorm its channels, then dropout.

        ``patch_wise_id_bank`` and ``id_dropout`` are not defined in this
        class (presumably set up by ``AOT``). The ID bank output must be 4-D
        with ``MODEL_ENCODER_EMBEDDING_DIM`` elements in dim 1, since that
        dim is what ``id_norm`` normalises.
        """
        # permute(2, 3, 0, 1) is its own inverse: the first call moves dim 1
        # to the end for LayerNorm, the second restores the original order.
        channels_last = self.patch_wise_id_bank(x).permute(2, 3, 0, 1)
        normed = self.id_norm(channels_last).permute(2, 3, 0, 1)
        return self.id_dropout(normed)