# Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Optional, Tuple import torch from torch import Tensor, nn from torch.nn.init import normal_ from mmdet.registry import MODELS from mmdet.structures import OptSampleList from mmdet.utils import OptConfigType from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder, DinoTransformerDecoder, SinePositionalEncoding) from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention @MODELS.register_module() class DINO(DeformableDETR): r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection `_ Code is modified from the `official github repo `_. Args: dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising query generator. Defaults to `None`. """ def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None: super().__init__(*args, **kwargs) assert self.as_two_stage, 'as_two_stage must be True for DINO' assert self.with_box_refine, 'with_box_refine must be True for DINO' if dn_cfg is not None: assert 'num_classes' not in dn_cfg and \ 'num_queries' not in dn_cfg and \ 'hidden_dim' not in dn_cfg, \ 'The three keyword args `num_classes`, `embed_dims`, and ' \ '`num_matching_queries` are set in `detector.__init__()`, ' \ 'users should not set them in `dn_cfg` config.' dn_cfg['num_classes'] = self.bbox_head.num_classes dn_cfg['embed_dims'] = self.embed_dims dn_cfg['num_matching_queries'] = self.num_queries self.dn_query_generator = CdnQueryGenerator(**dn_cfg) def _init_layers(self) -> None: """Initialize layers except for backbone, neck and bbox_head.""" self.positional_encoding = SinePositionalEncoding( **self.positional_encoding) self.encoder = DeformableDetrTransformerEncoder(**self.encoder) self.decoder = DinoTransformerDecoder(**self.decoder) self.embed_dims = self.encoder.embed_dims self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) # NOTE In DINO, the query_embedding only contains content # queries, while in Deformable DETR, the query_embedding # contains both content and spatial queries, and in DETR, # it only contains spatial queries. num_feats = self.positional_encoding.num_feats assert num_feats * 2 == self.embed_dims, \ f'embed_dims should be exactly 2 times of num_feats. ' \ f'Found {self.embed_dims} and {num_feats}.' self.level_embed = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) self.memory_trans_norm = nn.LayerNorm(self.embed_dims) def init_weights(self) -> None: """Initialize weights for Transformer and other components.""" super(DeformableDETR, self).init_weights() for coder in self.encoder, self.decoder: for p in coder.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MultiScaleDeformableAttention): m.init_weights() nn.init.xavier_uniform_(self.memory_trans_fc.weight) nn.init.xavier_uniform_(self.query_embedding.weight) normal_(self.level_embed) def forward_transformer( self, img_feats: Tuple[Tensor], batch_data_samples: OptSampleList = None, ) -> Dict: """Forward process of Transformer. The forward procedure of the transformer is defined as: 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' More details can be found at `TransformerDetector.forward_transformer` in `mmdet/detector/base_detr.py`. The difference is that the ground truth in `batch_data_samples` is required for the `pre_decoder` to prepare the query of DINO. Additionally, DINO inherits the `pre_transformer` method and the `forward_encoder` method of DeformableDETR. More details about the two methods can be found in `mmdet/detector/deformable_detr.py`. Args: img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each feature map has shape (bs, dim, H, W). batch_data_samples (list[:obj:`DetDataSample`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Defaults to None. Returns: dict: The dictionary of bbox_head function inputs, which always includes the `hidden_states` of the decoder output and may contain `references` including the initial and intermediate references. """ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( img_feats, batch_data_samples) encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict) tmp_dec_in, head_inputs_dict = self.pre_decoder( **encoder_outputs_dict, batch_data_samples=batch_data_samples) decoder_inputs_dict.update(tmp_dec_in) decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) head_inputs_dict.update(decoder_outputs_dict) return head_inputs_dict def pre_decoder( self, memory: Tensor, memory_mask: Tensor, spatial_shapes: Tensor, batch_data_samples: OptSampleList = None, ) -> Tuple[Dict]: """Prepare intermediate variables before entering Transformer decoder, such as `query`, `query_pos`, and `reference_points`. Args: memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). Will only be used when `as_two_stage` is `True`. spatial_shapes (Tensor): Spatial shapes of features in all levels. With shape (num_levels, 2), last dimension represents (h, w). Will only be used when `as_two_stage` is `True`. batch_data_samples (list[:obj:`DetDataSample`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Defaults to None. Returns: tuple[dict]: The decoder_inputs_dict and head_inputs_dict. - decoder_inputs_dict (dict): The keyword dictionary args of `self.forward_decoder()`, which includes 'query', 'memory', `reference_points`, and `dn_mask`. The reference points of decoder input here are 4D boxes, although it has `points` in its name. - head_inputs_dict (dict): The keyword dictionary args of the bbox_head functions, which includes `topk_score`, `topk_coords`, and `dn_meta` when `self.training` is `True`, else is empty. """ bs, _, c = memory.shape cls_out_features = self.bbox_head.cls_branches[ self.decoder.num_layers].out_features output_memory, output_proposals = self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) enc_outputs_class = self.bbox_head.cls_branches[ self.decoder.num_layers]( output_memory) enc_outputs_coord_unact = self.bbox_head.reg_branches[ self.decoder.num_layers](output_memory) + output_proposals # NOTE The DINO selects top-k proposals according to scores of # multi-class classification, while DeformDETR, where the input # is `enc_outputs_class[..., 0]` selects according to scores of # binary classification. topk_indices = torch.topk( enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1] topk_score = torch.gather( enc_outputs_class, 1, topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)) topk_coords = topk_coords_unact.sigmoid() topk_coords_unact = topk_coords_unact.detach() query = self.query_embedding.weight[:, None, :] query = query.repeat(1, bs, 1).transpose(0, 1) if self.training: dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ self.dn_query_generator(batch_data_samples) query = torch.cat([dn_label_query, query], dim=1) reference_points = torch.cat([dn_bbox_query, topk_coords_unact], dim=1) else: reference_points = topk_coords_unact dn_mask, dn_meta = None, None reference_points = reference_points.sigmoid() decoder_inputs_dict = dict( query=query, memory=memory, reference_points=reference_points, dn_mask=dn_mask) # NOTE DINO calculates encoder losses on scores and coordinates # of selected top-k encoder queries, while DeformDETR is of all # encoder queries. head_inputs_dict = dict( enc_outputs_class=topk_score, enc_outputs_coord=topk_coords, dn_meta=dn_meta) if self.training else dict() return decoder_inputs_dict, head_inputs_dict def forward_decoder(self, query: Tensor, memory: Tensor, memory_mask: Tensor, reference_points: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, dn_mask: Optional[Tensor] = None) -> Dict: """Forward with Transformer decoder. The forward procedure of the transformer is defined as: 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' More details can be found at `TransformerDetector.forward_transformer` in `mmdet/detector/base_detr.py`. Args: query (Tensor): The queries of decoder inputs, has shape (bs, num_queries_total, dim), where `num_queries_total` is the sum of `num_denoising_queries` and `num_matching_queries` when `self.training` is `True`, else `num_matching_queries`. memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). reference_points (Tensor): The initial reference, has shape (bs, num_queries_total, 4) with the last dimension arranged as (cx, cy, w, h). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). dn_mask (Tensor, optional): The attention mask to prevent information leakage from different denoising groups and matching parts, will be used as `self_attn_mask` of the `self.decoder`, has shape (num_queries_total, num_queries_total). It is `None` when `self.training` is `False`. Returns: dict: The dictionary of decoder outputs, which includes the `hidden_states` of the decoder output and `references` including the initial and intermediate reference_points. """ inter_states, references = self.decoder( query=query, value=memory, key_padding_mask=memory_mask, self_attn_mask=dn_mask, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=self.bbox_head.reg_branches) if len(query) == self.num_queries: # NOTE: This is to make sure label_embeding can be involved to # produce loss even if there is no denoising query (no ground truth # target in this GPU), otherwise, this will raise runtime error in # distributed training. inter_states[0] += \ self.dn_query_generator.label_embedding.weight[0, 0] * 0.0 decoder_outputs_dict = dict( hidden_states=inter_states, references=list(references)) return decoder_outputs_dict