# Page banner captured with this file: "Spaces: Runtime error"
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import NonLocal2d
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


class BFP(BaseModule):
    """BFP (Balanced Feature Pyramids).

    BFP takes multi-level features as inputs and gathers them into a single
    one, then refines the gathered feature and scatters the refined result
    back to the multi-level features. This module is used in Libra R-CNN
    (CVPR 2019), see the paper `Libra R-CNN: Towards Balanced Learning for
    Object Detection <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of integration and refine level of BSF in
            multi-level features from bottom to top. Must satisfy
            ``0 <= refine_level < num_levels``. Defaults to 2.
        refine_type (str, optional): Type of the refine op, currently
            support [None, 'conv', 'non_local']. With ``None`` the gathered
            feature is scattered back without a refine step.
            Defaults to None.
        conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            convolution layers. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            normalization layers. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to
            ``dict(type='Xavier', layer='Conv2d', distribution='uniform')``.

    Raises:
        AssertionError: If ``refine_type`` is not one of the supported
            values or ``refine_level`` is outside ``[0, num_levels)``.
    """

    # NOTE(review): this class is not decorated with
    # ``@MODELS.register_module()`` in the visible source -- confirm whether
    # registration is done elsewhere or the decorator was dropped.

    def __init__(
        self,
        in_channels: int,
        num_levels: int,
        refine_level: int = 2,
        refine_type: Optional[str] = None,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert refine_type in [None, 'conv', 'non_local'], (
            "refine_type must be one of [None, 'conv', 'non_local'], "
            f'but got {refine_type!r}')

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels, (
            f'refine_level must be in [0, {self.num_levels}), '
            f'but got {self.refine_level}')

        # ``self.refine`` is only created when a refine op is requested;
        # ``forward`` guards on ``refine_type`` before using it.
        if self.refine_type == 'conv':
            # 3x3 convolution with padding 1, same channel count in and out.
            self.refine = ConvModule(
                self.in_channels,
                self.in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        elif self.refine_type == 'non_local':
            self.refine = NonLocal2d(
                self.in_channels,
                reduction=1,
                use_scale=False,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): One feature map per level; its length
                must equal ``num_levels``. Dimensions from index 2 onward
                are treated as the spatial size (presumably each tensor is
                (N, C, H, W) -- confirm against the caller).

        Returns:
            tuple[Tensor]: One tensor per input level, each being the input
            of that level plus the balanced feature resized to that
            level's spatial size.

        Raises:
            AssertionError: If ``len(inputs) != num_levels``.
        """
        assert len(inputs) == self.num_levels, (
            f'expected {self.num_levels} input levels, '
            f'but got {len(inputs)}')

        # step 1: gather multi-level features by resize and average
        feats = []
        gather_size = inputs[self.refine_level].size()[2:]
        for i in range(self.num_levels):
            if i < self.refine_level:
                # Levels before the refine level are brought to the gather
                # size with adaptive max pooling.
                gathered = F.adaptive_max_pool2d(
                    inputs[i], output_size=gather_size)
            else:
                # The refine level and the levels after it are resized with
                # nearest-neighbour interpolation.
                gathered = F.interpolate(
                    inputs[i], size=gather_size, mode='nearest')
            feats.append(gathered)

        bsf = sum(feats) / len(feats)

        # step 2: refine gathered features
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # step 3: scatter refined features to multi-levels by a residual path
        outs = []
        for i in range(self.num_levels):
            out_size = inputs[i].size()[2:]
            # Mirror of step 1: interpolate back for levels before the
            # refine level, adaptive max pool for the others.
            if i < self.refine_level:
                residual = F.interpolate(bsf, size=out_size, mode='nearest')
            else:
                residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
            outs.append(residual + inputs[i])

        return tuple(outs)