# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DeformableDETRHead(DETRHead):
    r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for
    End-to-End Object Detection.

    Code is modified from the official Deformable DETR GitHub repository.

    More details can be found in the Deformable DETR paper.

    Args:
        share_pred_layer (bool): Whether to share parameters for all the
            prediction layers. Defaults to `False`.
        num_pred_layer (int): The number of the prediction layers.
            Defaults to 6.
        as_two_stage (bool, optional): Whether to generate the proposal
            from the outputs of encoder. Defaults to `False`.
    """

    def __init__(self,
                 *args,
                 share_pred_layer: bool = False,
                 num_pred_layer: int = 6,
                 as_two_stage: bool = False,
                 **kwargs) -> None:
        # These attributes are assigned before the parent constructor runs.
        # NOTE(review): presumably `DETRHead.__init__` calls
        # `_init_layers`, which reads them - confirm against the parent.
        self.share_pred_layer = share_pred_layer
        self.num_pred_layer = num_pred_layer
        self.as_two_stage = as_two_stage

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize classification branch and regression branch of head.

        Builds one prototype classification layer and one prototype
        regression MLP, then stores ``num_pred_layer`` entries of each in
        ``self.cls_branches`` and ``self.reg_branches``. The entries are
        the very same module when ``share_pred_layer`` is `True`, and
        independent deep copies otherwise.
        """
        # `embed_dims`, `cls_out_channels` and `num_reg_fcs` are not set in
        # this class; they are expected to be provided by the parent head.
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)

        # Regression MLP: `num_reg_fcs` x (Linear + ReLU), then a final
        # Linear producing the 4 box values.
        reg_layers = []
        for _ in range(self.num_reg_fcs):
            reg_layers.append(Linear(self.embed_dims, self.embed_dims))
            reg_layers.append(nn.ReLU())
        reg_layers.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_layers)

        if self.share_pred_layer:
            # Every slot references the same module object, so parameters
            # are shared between all prediction layers.
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(self.num_pred_layer)])
        else:
            # Each slot gets its own copy with independent parameters.
            self.cls_branches = nn.ModuleList(
                [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList([
                copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
            ])

    def init_weights(self) -> None:
        """Initialize weights of the Deformable DETR head."""
        if self.loss_cls.use_sigmoid:
            # Classification bias derived from a prior probability of 0.01.
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        # Constant-initialize the last layer of every regression branch
        # with 0 for both weight and bias.
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        # Last two bias entries (the w/h slots of the 4-value output) of the
        # first branch are set to -2.0. When prediction layers are shared,
        # this is the same module as every other branch.
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            # In two-stage mode the w/h biases of all branches are reset to
            # 0.0, which also overrides the -2.0 set just above.
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)

    def forward(self, hidden_states: Tensor,
                references: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the reference from the decoder.
                The first reference is the `init_reference` (initial) and the
                other num_decoder_layers(6) references are `inter_references`
                (intermediate). The `init_reference` has shape (bs,
                num_queries, 4) when `as_two_stage` of the detector is `True`,
                otherwise (bs, num_queries, 2). Each `inter_reference` has
                shape (bs, num_queries, 4) when `with_box_refine` of the
                detector is `True`, otherwise (bs, num_queries, 2). The
                coordinates are arranged as (cx, cy) when the last dimension is
                2, and (cx, cy, w, h) when it is 4.

        Returns:
            tuple[Tensor]: results of head containing the following tensor.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format (cx, cy, w,
              h), has shape (num_decoder_layers, bs, num_queries, 4) with the
              last dimension arranged as (cx, cy, w, h).
        """
        all_layers_outputs_classes = []
        all_layers_outputs_coords = []

        for layer_id in range(hidden_states.shape[0]):
            # Layer `i` is paired with `references[i]`.
            # NOTE The last reference will not be used.
            reference = inverse_sigmoid(references[layer_id])
            hidden_state = hidden_states[layer_id]
            outputs_class = self.cls_branches[layer_id](hidden_state)
            tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
            if reference.shape[-1] == 4:
                # When `layer` is 0 and `as_two_stage` of the detector
                # is `True`, or when `layer` is greater than 0 and
                # `with_box_refine` of the detector is `True`.
                # The branch output is an offset on all four box values.
                tmp_reg_preds += reference
            else:
                # When `layer` is 0 and `as_two_stage` of the detector
                # is `False`, or when `layer` is greater than 0 and
                # `with_box_refine` of the detector is `False`.
                # Only the centre (cx, cy) is offset by the reference.
                assert reference.shape[-1] == 2, (
                    'The last dimension of each reference must be 2 or 4, '
                    f'but got {reference.shape[-1]}')
                tmp_reg_preds[..., :2] += reference
            # The addition happens in logit space (the reference went
            # through `inverse_sigmoid`); sigmoid maps the result back.
            outputs_coord = tmp_reg_preds.sigmoid()
            all_layers_outputs_classes.append(outputs_class)
            all_layers_outputs_coords.append(outputs_coord)

        all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)

        return all_layers_outputs_classes, all_layers_outputs_coords

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Optional[Tensor],
             enc_outputs_coord: Optional[Tensor],
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the reference from the decoder.
                The first reference is the `init_reference` (initial) and the
                other num_decoder_layers(6) references are `inter_references`
                (intermediate). The `init_reference` has shape (bs,
                num_queries, 4) when `as_two_stage` of the detector is `True`,
                otherwise (bs, num_queries, 2). Each `inter_reference` has
                shape (bs, num_queries, 4) when `with_box_refine` of the
                detector is `True`, otherwise (bs, num_queries, 2). The
                coordinates are arranged as (cx, cy) when the last dimension is
                2, and (cx, cy, w, h) when it is 4.
            enc_outputs_class (Tensor): The score of each point on encode
                feature map, has shape (bs, num_feat_points, cls_out_channels).
                Only when `as_two_stage` is `True` it would be passed in,
                otherwise it would be `None`.
            enc_outputs_coord (Tensor): The proposal generated from the encode
                feature map, has shape (bs, num_feat_points, 4) with the last
                dimension arranged as (cx, cy, w, h). Only when `as_two_stage`
                is `True` it would be passed in, otherwise it would be `None`.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        # Split every data sample into its meta information and its
        # ground-truth instances.
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_gt_instances = [
            data_sample.gt_instances for data_sample in batch_data_samples
        ]

        outs = self(hidden_states, references)
        # Positional order must match the signature of `loss_by_feat`.
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        enc_cls_scores: Optional[Tensor],
        enc_bbox_preds: Optional[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers. Each is a 4D-tensor with normalized coordinate format
                (cx, cy, w, h) and has shape (num_decoder_layers, bs,
                num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            enc_cls_scores (Tensor): The score of each point on encode
                feature map, has shape (bs, num_feat_points, cls_out_channels).
                Only when `as_two_stage` is `True` it would be passed in,
                otherwise, it would be `None`.
            enc_bbox_preds (Tensor): The proposal generated from the encode
                feature map, has shape (bs, num_feat_points, 4) with the last
                dimension arranged as (cx, cy, w, h). Only when `as_two_stage`
                is `True` it would be passed in, otherwise it would be `None`.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Decoder losses are delegated to the parent head.
        loss_dict = super().loss_by_feat(all_layers_cls_scores,
                                         all_layers_bbox_preds,
                                         batch_gt_instances, batch_img_metas,
                                         batch_gt_instances_ignore)

        # loss of proposal generated from encode feature map.
        if enc_cls_scores is not None:
            # Work on a deep copy so the caller's ground truth is untouched,
            # then overwrite every label with 0: for the proposal loss all
            # ground-truth boxes share a single label.
            proposal_gt_instances = copy.deepcopy(batch_gt_instances)
            for gt_instances in proposal_gt_instances:
                gt_instances.labels = torch.zeros_like(gt_instances.labels)
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=proposal_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou
        return loss_dict

    def predict(self,
                hidden_states: Tensor,
                references: List[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head on the queries
        of the upstream network and post-process the outputs into detection
        results.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the reference from the decoder.
                The first reference is the `init_reference` (initial) and the
                other num_decoder_layers(6) references are `inter_references`
                (intermediate). The `init_reference` has shape (bs,
                num_queries, 4) when `as_two_stage` of the detector is `True`,
                otherwise (bs, num_queries, 2). Each `inter_reference` has
                shape (bs, num_queries, 4) when `with_box_refine` of the
                detector is `True`, otherwise (bs, num_queries, 2). The
                coordinates are arranged as (cx, cy) when the last dimension is
                2, and (cx, cy, w, h) when it is 4.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        outs = self(hidden_states, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        all_layers_cls_scores: Tensor,
                        all_layers_bbox_preds: Tensor,
                        batch_img_metas: List[Dict],
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all decoder
                layers. Each is a 4D-tensor with normalized coordinate format
                (cx, cy, w, h) and shape (num_decoder_layers, bs,
                num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Default `False`.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        # Only the outputs of the last decoder layer are used for inference.
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        # One result per entry of `batch_img_metas`; the image index selects
        # the matching row of the score and box tensors.
        return [
            self._predict_by_feat_single(cls_scores[img_id],
                                         bbox_preds[img_id], img_meta,
                                         rescale)
            for img_id, img_meta in enumerate(batch_img_metas)
        ]