ai-photo-gallery / mmdet /models /dense_heads /embedding_rpn_head.py
KyanChen's picture
init
f549064
raw
history blame
No virus
5.7 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from mmdet.structures.det_data_sample import SampleList
from mmdet.utils import InstanceList, OptConfigType
@MODELS.register_module()
class EmbeddingRPNHead(BaseModule):
"""RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .
Unlike traditional RPNHead, this module does not need FPN input, but just
decode `init_proposal_bboxes` and expand the first dimension of
`init_proposal_bboxes` and `init_proposal_features` to the batch_size.
Args:
num_proposals (int): Number of init_proposals. Defaults to 100.
proposal_feature_channel (int): Channel number of
init_proposal_feature. Defaults to 256.
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
dict]): Initialization config dict. Defaults to None.
"""
def __init__(self,
num_proposals: int = 100,
proposal_feature_channel: int = 256,
init_cfg: OptConfigType = None,
**kwargs) -> None:
# `**kwargs` is necessary to avoid some potential error.
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super().__init__(init_cfg=init_cfg)
self.num_proposals = num_proposals
self.proposal_feature_channel = proposal_feature_channel
self._init_layers()
def _init_layers(self) -> None:
"""Initialize a sparse set of proposal boxes and proposal features."""
self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
self.init_proposal_features = nn.Embedding(
self.num_proposals, self.proposal_feature_channel)
def init_weights(self) -> None:
"""Initialize the init_proposal_bboxes as normalized.
[c_x, c_y, w, h], and we initialize it to the size of the entire
image.
"""
super().init_weights()
nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)
def _decode_init_proposals(self, x: List[Tensor],
batch_data_samples: SampleList) -> InstanceList:
"""Decode init_proposal_bboxes according to the size of images and
expand dimension of init_proposal_features to batch_size.
Args:
x (list[Tensor]): List of FPN features.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
List[:obj:`InstanceData`:] Detection results of each image.
Each item usually contains following keys.
- proposals: Decoded proposal bboxes,
has shape (num_proposals, 4).
- features: init_proposal_features, expanded proposal
features, has shape
(num_proposals, proposal_feature_channel).
- imgs_whwh: Tensor with shape
(num_proposals, 4), the dimension means
[img_width, img_height, img_width, img_height].
"""
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
proposals = self.init_proposal_bboxes.weight.clone()
proposals = bbox_cxcywh_to_xyxy(proposals)
imgs_whwh = []
for meta in batch_img_metas:
h, w = meta['img_shape'][:2]
imgs_whwh.append(x[0].new_tensor([[w, h, w, h]]))
imgs_whwh = torch.cat(imgs_whwh, dim=0)
imgs_whwh = imgs_whwh[:, None, :]
proposals = proposals * imgs_whwh
rpn_results_list = []
for idx in range(len(batch_img_metas)):
rpn_results = InstanceData()
rpn_results.bboxes = proposals[idx]
rpn_results.imgs_whwh = imgs_whwh[idx].repeat(
self.num_proposals, 1)
rpn_results.features = self.init_proposal_features.weight.clone()
rpn_results_list.append(rpn_results)
return rpn_results_list
def loss(self, *args, **kwargs):
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network."""
raise NotImplementedError(
'EmbeddingRPNHead does not have `loss`, please use '
'`predict` or `loss_and_predict` instead.')
def predict(self, x: List[Tensor], batch_data_samples: SampleList,
**kwargs) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network."""
# `**kwargs` is necessary to avoid some potential error.
return self._decode_init_proposals(
x=x, batch_data_samples=batch_data_samples)
def loss_and_predict(self, x: List[Tensor], batch_data_samples: SampleList,
**kwargs) -> tuple:
"""Perform forward propagation of the head, then calculate loss and
predictions from the features and data samples."""
# `**kwargs` is necessary to avoid some potential error.
predictions = self._decode_init_proposals(
x=x, batch_data_samples=batch_data_samples)
return dict(), predictions