# Copyright (c) OpenMMLab. All rights reserved. | |
from abc import abstractmethod | |
from typing import Any, List, Sequence, Tuple, Union | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from numpy import ndarray | |
from torch import Tensor | |
from mmdet.registry import MODELS, TASK_UTILS | |
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, | |
OptInstanceList) | |
from ..task_modules.prior_generators import MlvlPointGenerator | |
from ..utils import multi_apply | |
from .base_dense_head import BaseDenseHead | |
# Per-level downsample factors: plain ints or explicit (h, w) stride tuples.
StrideType = Union[Sequence[int], Sequence[Tuple[int, int]]]
class AnchorFreeHead(BaseDenseHead):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (Sequence[int] or Sequence[Tuple[int, int]]): Downsample
            factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Defaults to False.
        conv_bias (bool or str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
            'DistancePointBBoxCoder'.
        conv_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for
            normalization layer. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of
            anchor-free head.
        test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of
            anchor-free head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """  # noqa: W605

    # Checkpoint schema version. ``_load_from_state_dict`` uses the absence
    # of this version in ``local_metadata`` to detect old checkpoints and
    # remap their predictor keys (e.g. 'fcos_cls' -> 'conv_cls').
    _version = 1
    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: StrideType = (4, 8, 16, 32, 64),
        dcn_on_last_conv: bool = False,
        conv_bias: Union[bool, str] = 'auto',
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0),
        bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))
    ) -> None:
        """Initialize the head; see the class docstring for argument docs."""
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # Sigmoid losses predict one logit per foreground class; softmax
        # losses need one extra output channel for the background class.
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        # 'auto' defers the bias decision to ConvModule (bias iff no norm).
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        # Losses and the bbox coder are built from config via the registries.
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.prior_generator = MlvlPointGenerator(strides)
        # In order to keep a more general interface and be consistent with
        # anchor_head. We can think of point like one anchor
        self.num_base_priors = self.prior_generator.num_base_priors[0]
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # NOTE(review): nothing in this file reads this flag — presumably kept
        # for backward compatibility with old fp16 hooks; confirm before
        # removing.
        self.fp16_enabled = False
        self._init_layers()
def _init_layers(self) -> None: | |
"""Initialize layers of the head.""" | |
self._init_cls_convs() | |
self._init_reg_convs() | |
self._init_predictor() | |
def _init_cls_convs(self) -> None: | |
"""Initialize classification conv layers of the head.""" | |
self.cls_convs = nn.ModuleList() | |
for i in range(self.stacked_convs): | |
chn = self.in_channels if i == 0 else self.feat_channels | |
if self.dcn_on_last_conv and i == self.stacked_convs - 1: | |
conv_cfg = dict(type='DCNv2') | |
else: | |
conv_cfg = self.conv_cfg | |
self.cls_convs.append( | |
ConvModule( | |
chn, | |
self.feat_channels, | |
3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=self.norm_cfg, | |
bias=self.conv_bias)) | |
def _init_reg_convs(self) -> None: | |
"""Initialize bbox regression conv layers of the head.""" | |
self.reg_convs = nn.ModuleList() | |
for i in range(self.stacked_convs): | |
chn = self.in_channels if i == 0 else self.feat_channels | |
if self.dcn_on_last_conv and i == self.stacked_convs - 1: | |
conv_cfg = dict(type='DCNv2') | |
else: | |
conv_cfg = self.conv_cfg | |
self.reg_convs.append( | |
ConvModule( | |
chn, | |
self.feat_channels, | |
3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=self.norm_cfg, | |
bias=self.conv_bias)) | |
def _init_predictor(self) -> None: | |
"""Initialize predictor layers of the head.""" | |
self.conv_cls = nn.Conv2d( | |
self.feat_channels, self.cls_out_channels, 3, padding=1) | |
self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) | |
    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version."""
        version = local_metadata.get('version', None)
        if version is None:
            # the key is different in early versions
            # for example, 'fcos_cls' become 'conv_cls' now
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. 'fcos_cls' or 'fcos_reg'
            for key in bbox_head_keys:
                ori_predictor_keys.append(key)
                # Split into dotted components; key[1] is assumed to be the
                # layer name directly under the head prefix — TODO confirm
                # against actual checkpoint key layout.
                key = key.split('.')
                if len(key) < 2:
                    conv_name = None
                elif key[1].endswith('cls'):
                    conv_name = 'conv_cls'
                elif key[1].endswith('reg'):
                    conv_name = 'conv_reg'
                elif key[1].endswith('centerness'):
                    conv_name = 'conv_centerness'
                else:
                    conv_name = None
                if conv_name is not None:
                    key[1] = conv_name
                    new_predictor_keys.append('.'.join(key))
                else:
                    # Not a predictor key after all: drop it again so the
                    # ori/new lists stay index-aligned.
                    ori_predictor_keys.pop(-1)
            # Rename in place: remove the old key, insert the new one.
            for i in range(len(new_predictor_keys)):
                state_dict[new_predictor_keys[i]] = state_dict.pop(
                    ori_predictor_keys[i])
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
            each is a 4D-tensor, the channel number is \
            num_points * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each scale \
            level, each is a 4D-tensor, the channel number is num_points * 4.
        """
        # forward_single also returns the tower features; only the first two
        # result lists (scores and bbox preds) are exposed here.
        return multi_apply(self.forward_single, x)[:2]
def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]: | |
"""Forward features of a single scale level. | |
Args: | |
x (Tensor): FPN feature maps of the specified stride. | |
Returns: | |
tuple: Scores for each class, bbox predictions, features | |
after classification and regression conv layers, some | |
models needs these features like FCOS. | |
""" | |
cls_feat = x | |
reg_feat = x | |
for cls_layer in self.cls_convs: | |
cls_feat = cls_layer(cls_feat) | |
cls_score = self.conv_cls(cls_feat) | |
for reg_layer in self.reg_convs: | |
reg_feat = reg_layer(reg_feat) | |
bbox_pred = self.conv_reg(reg_feat) | |
return cls_score, bbox_pred, cls_feat, reg_feat | |
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Raises:
            NotImplementedError: Must be overridden by concrete subclasses.
        """
        raise NotImplementedError
    def get_targets(self, points: List[Tensor],
                    batch_gt_instances: InstanceList) -> Any:
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Raises:
            NotImplementedError: Must be overridden by concrete subclasses.
        """
        raise NotImplementedError
    # TODO refactor aug_test
    def aug_test(self,
                 aug_batch_feats: List[Tensor],
                 aug_batch_img_metas: List[List[Tensor]],
                 rescale: bool = False) -> List[ndarray]:
        """Test function with test time augmentation.

        Args:
            aug_batch_feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            aug_batch_img_metas (list[list[dict]]): the outer list indicates
                test-time augs (multiscale, flip, etc.) and the inner list
                indicates images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        # NOTE(review): ``aug_test_bboxes`` is not defined in this file —
        # presumably provided by a base class or mixin; verify before reuse.
        return self.aug_test_bboxes(
            aug_batch_feats, aug_batch_img_metas, rescale=rescale)