# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DeformableDETRHead(DETRHead):
    r"""Head of Deformable DETR: Deformable Transformers for End-to-End
    Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_.

    Args:
        share_pred_layer (bool): Whether to share parameters for all the
            prediction layers. Defaults to `False`.
        num_pred_layer (int): The number of the prediction layers.
            Defaults to 6.
        as_two_stage (bool, optional): Whether to generate proposals
            from the encoder outputs. Defaults to `False`.
    """

    def __init__(self,
                 *args,
                 share_pred_layer: bool = False,
                 num_pred_layer: int = 6,
                 as_two_stage: bool = False,
                 **kwargs) -> None:
        self.share_pred_layer = share_pred_layer
        self.num_pred_layer = num_pred_layer
        self.as_two_stage = as_two_stage

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        if self.share_pred_layer:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(self.num_pred_layer)])
        else:
            self.cls_branches = nn.ModuleList(
                [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList([
                copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
            ])
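
    # Sharing semantics, illustrated (a sketch, not part of the original
    # file): with `share_pred_layer=True` every list entry is the *same*
    # module object, so all decoder layers train one set of prediction
    # weights, while `copy.deepcopy` gives each layer independent parameters:
    #
    #     head_shared = DeformableDETRHead(num_classes=80,
    #                                      share_pred_layer=True)
    #     assert all(m is head_shared.cls_branches[0]
    #                for m in head_shared.cls_branches)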

    def init_weights(self) -> None:
        """Initialize weights of the Deformable DETR head."""
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)
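
    # `bias_init_with_prob(p)` implements the focal-loss prior trick: choose
    # the bias b so that sigmoid(b) = p, i.e. b = -log((1 - p) / p). A worked
    # value (computed here for illustration, not taken from the source):
    #
    #     >>> import math
    #     >>> -math.log((1 - 0.01) / 0.01)
    #     -4.59511985013459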

    def forward(self, hidden_states: Tensor,
                references: List[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers (6) references
                are `inter_references` (intermediate). The `init_reference`
                has shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.

        Returns:
            tuple[Tensor]: Results of the head containing the following
            tensors.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format
              (cx, cy, w, h), has shape (num_decoder_layers, bs,
              num_queries, 4).
        """
        all_layers_outputs_classes = []
        all_layers_outputs_coords = []

        for layer_id in range(hidden_states.shape[0]):
            reference = inverse_sigmoid(references[layer_id])
            # NOTE The last reference will not be used.
            hidden_state = hidden_states[layer_id]
            outputs_class = self.cls_branches[layer_id](hidden_state)
            tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
            if reference.shape[-1] == 4:
                # When `layer` is 0 and `as_two_stage` of the detector
                # is `True`, or when `layer` is greater than 0 and
                # `with_box_refine` of the detector is `True`.
                tmp_reg_preds += reference
            else:
                # When `layer` is 0 and `as_two_stage` of the detector
                # is `False`, or when `layer` is greater than 0 and
                # `with_box_refine` of the detector is `False`.
                assert reference.shape[-1] == 2
                tmp_reg_preds[..., :2] += reference
            outputs_coord = tmp_reg_preds.sigmoid()
            all_layers_outputs_classes.append(outputs_class)
            all_layers_outputs_coords.append(outputs_coord)

        all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)

        return all_layers_outputs_classes, all_layers_outputs_coords
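
    # A minimal shape sketch (hypothetical sizes; single-stage, no box
    # refinement, so every reference carries 2 coordinates):
    #
    #     num_layers, bs, num_queries, dims = 6, 2, 300, 256
    #     hidden_states = torch.rand(num_layers, bs, num_queries, dims)
    #     references = [torch.rand(bs, num_queries, 2)
    #                   for _ in range(num_layers + 1)]
    #     cls_out, coord_out = head(hidden_states, references)
    #     # cls_out:   (6, 2, 300, cls_out_channels)
    #     # coord_out: (6, 2, 300, 4), values in (0, 1) after sigmoid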

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers (6) references
                are `inter_references` (intermediate). The `init_reference`
                has shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.
            enc_outputs_class (Tensor): The score of each point on the
                encoder feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`, otherwise it is `None`.
            enc_outputs_coord (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). It is
                only passed in when `as_two_stage` is `True`, otherwise it
                is `None`.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses
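
    # A hypothetical single-stage training call (encoder outputs are `None`
    # because `as_two_stage` is `False`; `batch_data_samples` come from the
    # data loader):
    #
    #     losses = head.loss(hidden_states, references,
    #                        enc_outputs_class=None, enc_outputs_coord=None,
    #                        batch_data_samples=batch_data_samples)
    #     # e.g. losses['loss_cls'], losses['loss_bbox'], losses['loss_iou'],
    #     # plus per-layer auxiliary terms such as 'd0.loss_cls', ...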

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        enc_cls_scores: Tensor,
        enc_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers, a 4D tensor with normalized coordinate
                format (cx, cy, w, h) and shape (num_decoder_layers, bs,
                num_queries, 4).
            enc_cls_scores (Tensor): The score of each point on the encoder
                feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`, otherwise it is `None`.
            enc_bbox_preds (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). It is
                only passed in when `as_two_stage` is `True`, otherwise it
                is `None`.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_dict = super().loss_by_feat(all_layers_cls_scores,
                                         all_layers_bbox_preds,
                                         batch_gt_instances, batch_img_metas,
                                         batch_gt_instances_ignore)

        # Loss of the proposals generated from the encoder feature map.
        if enc_cls_scores is not None:
            # The encoder scores proposals class-agnostically, so all gt
            # labels are set to zero (a single foreground class).
            proposal_gt_instances = copy.deepcopy(batch_gt_instances)
            for i in range(len(proposal_gt_instances)):
                proposal_gt_instances[i].labels = torch.zeros_like(
                    proposal_gt_instances[i].labels)
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=proposal_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou
        return loss_dict
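
    # Illustrative effect of the relabeling above (hypothetical data): after
    # `torch.zeros_like`, every ground-truth box is treated as class 0, so
    # `enc_loss_cls` measures binary objectness rather than per-category
    # classification:
    #
    #     gt = proposal_gt_instances[0]
    #     assert (gt.labels == 0).all()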

    def predict(self,
                hidden_states: Tensor,
                references: List[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers (6) references
                are `inter_references` (intermediate). The `init_reference`
                has shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        outs = self(hidden_states, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        all_layers_cls_scores: Tensor,
                        all_layers_bbox_preds: Tensor,
                        batch_img_metas: List[Dict],
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers, a 4D tensor with normalized coordinate
                format (cx, cy, w, h) and shape (num_decoder_layers, bs,
                num_queries, 4).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `False`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(cls_score, bbox_pred,
                                                   img_meta, rescale)
            result_list.append(results)
        return result_list
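
    # Only the last decoder layer's outputs are kept at inference time (the
    # `[-1]` indexing above). A hypothetical end-to-end call:
    #
    #     results = head.predict(hidden_states, references,
    #                            batch_data_samples, rescale=True)
    #     # results[i].bboxes / results[i].scores / results[i].labels
    #     # hold the post-processed detections for image i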