# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

import torch
from mmengine.model import BaseModel
from torch import Tensor

from mmdet.structures import DetDataSample, OptSampleList, SampleList
from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig
from ..utils import samplelist_boxtype2tensor

ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample],
                       Tuple[torch.Tensor], torch.Tensor]


class BaseDetector(BaseModel, metaclass=ABCMeta):
    """Base class for detectors.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
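
    Examples:
        A minimal sketch of a concrete subclass (``ToyDetector`` and its
        single-layer backbone below are illustrative assumptions, not part
        of MMDetection)::

            >>> import torch
            >>> from torch import nn
            >>> class ToyDetector(BaseDetector):
            ...     def __init__(self, **kwargs):
            ...         super().__init__(**kwargs)
            ...         self.backbone = nn.Conv2d(3, 8, 3, padding=1)
            ...     def extract_feat(self, batch_inputs):
            ...         return self.backbone(batch_inputs)
            ...     def loss(self, batch_inputs, batch_data_samples):
            ...         feat = self.extract_feat(batch_inputs)
            ...         return dict(loss_toy=feat.mean())
            ...     def predict(self, batch_inputs, batch_data_samples):
            ...         return batch_data_samples
            ...     def _forward(self, batch_inputs, batch_data_samples=None):
            ...         return self.extract_feat(batch_inputs)
            >>> detector = ToyDetector()
            >>> detector(torch.rand(1, 3, 32, 32), mode='tensor').shape
            torch.Size([1, 8, 32, 32])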
"""
def __init__(self,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
def with_neck(self) -> bool:
"""bool: whether the detector has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    # TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self) -> bool:
"""bool: whether the detector has a shared head in the RoI Head"""
        return hasattr(self, 'roi_head') and self.roi_head.with_shared_head

    @property
def with_bbox(self) -> bool:
"""bool: whether the detector has a bbox head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
                or (hasattr(self, 'bbox_head') and self.bbox_head is not None))

    @property
def with_mask(self) -> bool:
"""bool: whether the detector has a mask head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head') and self.mask_head is not None))

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`DetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
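
        Examples:
            A minimal sketch of the three modes (``detector``,
            ``batch_inputs`` and ``batch_data_samples`` are assumed to be a
            concrete detector instance and a preprocessed batch; they are
            illustrative, not defined in this file)::

                >>> feats = detector(batch_inputs, mode='tensor')
                >>> predictions = detector(
                ...     batch_inputs, batch_data_samples, mode='predict')
                >>> losses = detector(
                ...     batch_inputs, batch_data_samples, mode='loss')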
"""
if mode == 'loss':
return self.loss(inputs, data_samples)
elif mode == 'predict':
return self.predict(inputs, data_samples)
elif mode == 'tensor':
return self._forward(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Union[dict, tuple]:
"""Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing."""
        pass

    @abstractmethod
    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None):
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
processing.
"""
        pass

    @abstractmethod
def extract_feat(self, batch_inputs: Tensor):
"""Extract features from images."""
pass
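
    # NOTE: Illustrative sketch, not part of the original file. Concrete
    # detectors in MMDetection typically implement ``extract_feat`` as
    # backbone plus optional neck, e.g.:
    #
    #     def extract_feat(self, batch_inputs: Tensor):
    #         x = self.backbone(batch_inputs)
    #         if self.with_neck:
    #             x = self.neck(x)
    #         return x
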
    def add_pred_to_datasample(self, data_samples: SampleList,
                               results_list: InstanceList) -> SampleList:
        """Add predictions to `DetDataSample`.

        Args:
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            'pred_instances', and ``pred_instances`` usually contains the
            following keys:

                - scores (Tensor): Classification scores, with shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, with shape
                  (num_instances, ).
                - bboxes (Tensor): Bounding boxes with shape
                  (num_instances, 4), the last dimension arranged as
                  (x1, y1, x2, y2).
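
        Examples:
            A minimal sketch (the random predictions below are illustrative,
            and ``detector`` is assumed to be any concrete
            :class:`BaseDetector` instance)::

                >>> import torch
                >>> from mmengine.structures import InstanceData
                >>> from mmdet.structures import DetDataSample
                >>> pred = InstanceData(
                ...     scores=torch.rand(5),
                ...     labels=torch.randint(0, 80, (5, )),
                ...     bboxes=torch.rand(5, 4))
                >>> samples = detector.add_pred_to_datasample(
                ...     [DetDataSample()], [pred])
                >>> samples[0].pred_instances.bboxes.shape
                torch.Size([5, 4])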
"""
for data_sample, pred_instances in zip(data_samples, results_list):
data_sample.pred_instances = pred_instances
samplelist_boxtype2tensor(data_samples)
return data_samples