KyanChen's picture
init
f549064
raw
history blame
No virus
4.89 kB
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Tuple
from mmengine.model import BaseModule
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
class BaseRoIHead(BaseModule, metaclass=ABCMeta):
    """Base class for RoIHeads.

    Holds an optional shared head, an optional bbox branch and an optional
    mask branch. Concrete subclasses decide how each branch is built by
    implementing :meth:`init_bbox_head`, :meth:`init_mask_head` and
    :meth:`init_assigner_sampler`, and must implement :meth:`loss`.
    :meth:`predict` additionally calls ``self.predict_bbox`` and, when a mask
    head is present, ``self.predict_mask``; neither is defined here, so
    subclasses are expected to provide them.

    Args:
        bbox_roi_extractor (OptMultiConfig): Config of the RoI feature
            extractor for the bbox branch. Forwarded to
            :meth:`init_bbox_head` together with ``bbox_head``.
            Defaults to None.
        bbox_head (OptMultiConfig): Config of the bbox head. The bbox branch
            is only initialized when this is not None. Defaults to None.
        mask_roi_extractor (OptMultiConfig): Config of the RoI feature
            extractor for the mask branch. Forwarded to
            :meth:`init_mask_head` together with ``mask_head``.
            Defaults to None.
        mask_head (OptMultiConfig): Config of the mask head. The mask branch
            is only initialized when this is not None. Defaults to None.
        shared_head (OptConfigType): Config of a shared head, built through
            the ``MODELS`` registry when given. Defaults to None.
        train_cfg (OptConfigType): Training config, stored as
            ``self.train_cfg``. Defaults to None.
        test_cfg (OptConfigType): Testing config, stored as
            ``self.test_cfg`` and passed as ``rcnn_test_cfg`` to
            ``predict_bbox`` in :meth:`predict`. Defaults to None.
        init_cfg (OptMultiConfig): Initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 bbox_roi_extractor: OptMultiConfig = None,
                 bbox_head: OptMultiConfig = None,
                 mask_roi_extractor: OptMultiConfig = None,
                 mask_head: OptMultiConfig = None,
                 shared_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if shared_head is not None:
            self.shared_head = MODELS.build(shared_head)

        # Each branch is built only when its head config is given. The RoI
        # extractor config is handed to the subclass as-is (it may be None);
        # how a missing extractor is handled is up to the subclass.
        if bbox_head is not None:
            self.init_bbox_head(bbox_roi_extractor, bbox_head)

        if mask_head is not None:
            self.init_mask_head(mask_roi_extractor, mask_head)

        # Called unconditionally, after the heads have been initialized.
        self.init_assigner_sampler()

    @property
    def with_bbox(self) -> bool:
        """bool: whether the RoI head contains a `bbox_head`"""
        return getattr(self, 'bbox_head', None) is not None

    @property
    def with_mask(self) -> bool:
        """bool: whether the RoI head contains a `mask_head`"""
        return getattr(self, 'mask_head', None) is not None

    @property
    def with_shared_head(self) -> bool:
        """bool: whether the RoI head contains a `shared_head`"""
        return getattr(self, 'shared_head', None) is not None

    @abstractmethod
    def init_bbox_head(self, *args, **kwargs):
        """Initialize ``bbox_head``.

        Called from ``__init__`` as
        ``init_bbox_head(bbox_roi_extractor, bbox_head)`` when ``bbox_head``
        is not None.
        """
        pass

    @abstractmethod
    def init_mask_head(self, *args, **kwargs):
        """Initialize ``mask_head``.

        Called from ``__init__`` as
        ``init_mask_head(mask_roi_extractor, mask_head)`` when ``mask_head``
        is not None.
        """
        pass

    @abstractmethod
    def init_assigner_sampler(self, *args, **kwargs):
        """Initialize assigner and sampler.

        Called from ``__init__`` with no arguments, after the heads are
        initialized.
        """
        pass

    @abstractmethod
    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: SampleList):
        """Perform forward propagation and loss calculation of the roi head on
        the features of the upstream network."""

    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): list of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results to
                the original image. Defaults to False.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # TODO: nms_op in mmcv need be enhanced, the bbox result may get
        # difference when not rescale in bbox_head

        # If it has the mask branch, the bbox branch does not need
        # to be scaled to the original image scale, because the mask
        # branch will scale both bbox and mask at the same time.
        bbox_rescale = False if self.with_mask else rescale
        results_list = self.predict_bbox(
            x,
            batch_img_metas,
            rpn_results_list,
            rcnn_test_cfg=self.test_cfg,
            rescale=bbox_rescale)

        if self.with_mask:
            results_list = self.predict_mask(
                x, batch_img_metas, results_list, rescale=rescale)

        return results_list