Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Tuple | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from mmdet.utils import OptConfigType, OptMultiConfig | |
class ChannelMapper(BaseModule):
    """Map every backbone feature level to a common channel width.

    One ``ConvModule`` is built per entry of ``in_channels``; each maps its
    level to ``out_channels`` using ``kernel_size`` with padding
    ``(kernel_size - 1) // 2``. When ``num_outs`` exceeds the number of
    input levels, additional 3x3 ``ConvModule`` layers with stride 2 and
    padding 1 are appended: the first one consumes the last *input* feature
    map, every following one consumes the previous extra output.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the per-scale mapping
            convolutions. Default: 3.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Default: None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Default: None.
        act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            activation layer in ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. Extra
            convs are created when it is larger than ``len(in_channels)``.
            ``None`` means ``len(in_channels)``. Default: None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Default:
            dict(type='Xavier', layer='Conv2d', distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        kernel_size: int = 3,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        act_cfg: OptConfigType = dict(type='ReLU'),
        num_outs: int = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)

        num_ins = len(in_channels)
        # Stays None unless more outputs than inputs are requested.
        self.extra_convs = None
        if num_outs is None:
            num_outs = num_ins

        # One channel-mapping conv per input level, spatial size preserved
        # for odd kernel sizes thanks to the symmetric padding.
        same_padding = (kernel_size - 1) // 2
        self.convs = nn.ModuleList([
            ConvModule(
                level_channels,
                out_channels,
                kernel_size,
                padding=same_padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for level_channels in in_channels
        ])

        if num_outs > num_ins:
            extra_layers = nn.ModuleList()
            # The first extra conv reads the last raw input level; the
            # remaining ones read an already-mapped feature map.
            src_channels = in_channels[-1]
            for _ in range(num_outs - num_ins):
                extra_layers.append(
                    ConvModule(
                        src_channels,
                        out_channels,
                        3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
                src_channels = out_channels
            self.extra_convs = extra_layers

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Apply the per-level convs, then any extra downsampling convs.

        Args:
            inputs (Tuple[Tensor]): One feature map per entry of the
                ``in_channels`` given at construction.

        Returns:
            Tuple[Tensor]: The mapped feature maps followed by the outputs
            of the extra convs (if any).
        """
        num_levels = len(inputs)
        assert num_levels == len(self.convs)

        outs = []
        for level in range(num_levels):
            outs.append(self.convs[level](inputs[level]))

        if self.extra_convs:
            # Chain the extra convs, seeded with the last raw input.
            feat = inputs[-1]
            for extra_conv in self.extra_convs:
                feat = extra_conv(feat)
                outs.append(feat)
        return tuple(outs)