KyanChen's picture
init
f549064
raw
history blame
No virus
4.55 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from ..layers import MLP, inverse_sigmoid
from .conditional_detr_head import ConditionalDETRHead
@MODELS.register_module()
class DABDETRHead(ConditionalDETRHead):
    """Head of DAB-DETR. DAB-DETR: Dynamic Anchor Boxes are Better Queries for
    DETR.

    More details can be found in the `paper
    <https://arxiv.org/abs/2201.12329>`_ .
    """

    def _init_layers(self) -> None:
        """Initialize layers of the transformer head."""
        # cls branch: a single linear projection from decoder features to
        # class logits.
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        # reg branch: MLP mapping decoder features to 4 box values; these
        # are used as offsets on top of the (inverse-sigmoid) references in
        # ``forward``.
        self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)

    def init_weights(self) -> None:
        """Initialize weights of the classification and regression branches.

        If the classification loss uses sigmoid, the classifier bias is set
        via ``bias_init_with_prob(0.01)`` (a prior-probability style bias
        initialization). The last layer of the regression MLP is initialized
        to zero (weights and bias). The effect is that the initial predicted
        offsets are zero, and ``forward`` then returns (approximately) the
        reference coordinates themselves.
        """
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
        # Zero offsets at init -> predictions start at the reference anchors.
        constant_init(self.fc_reg.layers[-1], 0., bias=0.)

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from transformer decoder. If
                `return_intermediate_dec` is True output has shape
                (num_decoder_layers, bs, num_queries, dim), else has shape (1,
                bs, num_queries, dim) which only contains the last layer
                outputs.
            references (Tensor): References from transformer decoder. If
                `return_intermediate_dec` is True output has shape
                (num_decoder_layers, bs, num_queries, 2/4), else has shape (1,
                bs, num_queries, 2/4) which only contains the last layer
                reference.

        Returns:
            tuple[Tensor]: results of head containing the following tensor.

            - layers_cls_scores (Tensor): Outputs from the classification head,
              shape (num_decoder_layers, bs, num_queries, cls_out_channels).
              Note cls_out_channels should include background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has shape
              (num_decoder_layers, bs, num_queries, 4).
        """
        layers_cls_scores = self.fc_cls(hidden_states)
        # Move the references to logit space so the regression output can be
        # added as an offset before the final sigmoid. ``eps`` presumably
        # clamps values away from 0/1 to keep the logit finite.
        references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
        tmp_reg_preds = self.fc_reg(hidden_states)
        # Only the leading ``references.size(-1)`` coordinates (2 for point
        # references, 4 for box references) are refined relative to the
        # reference; any remaining coordinates are predicted directly.
        num_ref_coords = references_before_sigmoid.size(-1)
        tmp_reg_preds[..., :num_ref_coords] += references_before_sigmoid
        layers_bbox_preds = tmp_reg_preds.sigmoid()
        return layers_cls_scores, layers_bbox_preds

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Over-write
        because img_metas are needed as inputs for bbox_head.

        Only the outputs of the last decoder layer are used for prediction.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): references from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, 2/4).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        # Keep a leading layer axis of size 1 so ``forward`` receives the
        # same rank as during training.
        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        last_layer_reference = references[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, last_layer_reference)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions