# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import NonLocal2d
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType, OptMultiConfig


@MODELS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramids (BFP).

    The module works in three stages: every input level is resized to the
    resolution of one chosen level and the results are averaged into a single
    map; that map is optionally refined; finally the map is resized back to
    each level's resolution and added to the corresponding input.

    This module is used in Libra R-CNN (CVPR 2019); see the paper
    "Libra R-CNN: Towards Balanced Learning for Object Detection" for
    details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of the level, counted from bottom to top,
            whose resolution is used for integration and refinement.
            Must satisfy ``0 <= refine_level < num_levels``. Defaults to 2.
        refine_type (str): Type of the refine op, one of
            [None, 'conv', 'non_local']. ``None`` skips refinement.
            Defaults to None.
        conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            convolution layers. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): The config dict for
            normalization layers. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to
            Xavier-uniform initialization of ``Conv2d`` layers.
    """

    def __init__(
        self,
        in_channels: int,
        num_levels: int,
        refine_level: int = 2,
        refine_type: str = None,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert refine_type in [None, 'conv', 'non_local']

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels

        # ``self.refine`` only exists when a refine op was requested;
        # ``forward`` checks ``refine_type`` before touching it.
        if refine_type == 'non_local':
            self.refine = NonLocal2d(
                in_channels,
                reduction=1,
                use_scale=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
        elif refine_type == 'conv':
            # 3x3 convolution with padding 1, channel count unchanged.
            self.refine = ConvModule(
                in_channels,
                in_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): One feature map per level; its length
                must equal ``num_levels``. Dimensions from index 2 onward
                are treated as the spatial size.

        Returns:
            tuple[Tensor]: One tensor per level, each being the input of
            that level plus the resized balanced feature.
        """
        assert len(inputs) == self.num_levels

        # Stage 1: bring every level to the refine level's spatial size.
        # Levels below the refine level are max-pooled; the refine level
        # itself and the levels above it go through nearest interpolation.
        target_size = inputs[self.refine_level].size()[2:]
        resized = [
            F.adaptive_max_pool2d(inputs[level], output_size=target_size)
            if level < self.refine_level else F.interpolate(
                inputs[level], size=target_size, mode='nearest')
            for level in range(self.num_levels)
        ]
        balanced = sum(resized) / len(resized)

        # Stage 2: optional refinement of the averaged feature.
        if self.refine_type is not None:
            balanced = self.refine(balanced)

        # Stage 3: resize the balanced feature back to each level (the
        # inverse of stage 1) and add it to that level's input.
        outputs = []
        for level in range(self.num_levels):
            feat = inputs[level]
            level_size = feat.size()[2:]
            if level >= self.refine_level:
                delta = F.adaptive_max_pool2d(balanced, output_size=level_size)
            else:
                delta = F.interpolate(balanced, size=level_size, mode='nearest')
            outputs.append(delta + feat)

        return tuple(outputs)