# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmengine.model import BaseModule

from mmcls.registry import MODELS
from ..backbones.resnet import Bottleneck, ResLayer


@MODELS.register_module()
class HRFuseScales(BaseModule):
    """Fuse feature map of multiple scales in HRNet.

    Each input scale is first widened by a single-block ``ResLayer``
    (to 128, 256, 512, 1024 channels for scales 0..3). Starting from the
    first (highest-resolution) scale, the running feature map is repeatedly
    passed through a stride-2 3x3 ``ConvModule`` and summed with the next
    widened scale. A final 1x1 ``ConvModule`` projects the result to
    ``out_channels``.

    Args:
        in_channels (list[int]): The input channels of all scales. Between
            one and four scales are supported.
        out_channels (int): The channels of fused feature map.
            Defaults to 2048.
        norm_cfg (dict): Dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.

    Raises:
        ValueError: If ``in_channels`` has fewer than one or more than four
            entries.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
        super(HRFuseScales, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        # Per-scale widened channels. Kept under a separate name so the
        # ``out_channels`` argument is not shadowed.
        mid_channels = [128, 256, 512, 1024]
        num_scales = len(in_channels)
        if not 1 <= num_scales <= len(mid_channels):
            raise ValueError(
                f'HRFuseScales supports 1 to {len(mid_channels)} input '
                f'scales, but got {num_scales}.')

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        self.increase_layers = nn.ModuleList([
            ResLayer(
                block_type,
                in_channels=in_ch,
                out_channels=mid_ch,
                num_blocks=1,
                stride=1,
            ) for in_ch, mid_ch in zip(in_channels, mid_channels)
        ])

        # Downsample the running feature map from scale i to scale i + 1.
        self.downsample_layers = nn.ModuleList([
            ConvModule(
                in_channels=mid_channels[i],
                out_channels=mid_channels[i + 1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                bias=False,
            ) for i in range(num_scales - 1)
        ])

        # The final conv block before final classifier linear layer. Its
        # input width is that of the last scale in use, not always 1024.
        self.final_layer = ConvModule(
            in_channels=mid_channels[num_scales - 1],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def forward(self, x):
        """Fuse the multi-scale features into a single feature map.

        Args:
            x (tuple[Tensor]): One feature map per scale, ordered from the
                highest resolution to the lowest. Its length must equal
                ``len(self.in_channels)``.

        Returns:
            tuple[Tensor]: A one-element tuple holding the fused feature map
            with ``self.out_channels`` channels.
        """
        assert isinstance(x, tuple) and len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for downsample, increase, x_i in zip(self.downsample_layers,
                                             self.increase_layers[1:], x[1:]):
            # The stride-2 conv shrinks ``feat``. For the sum to work, the
            # next scale's widened map must have the resulting spatial size.
            feat = downsample(feat) + increase(x_i)
        return (self.final_layer(feat), )