from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential
from torch import nn as nn


class ResLayer(Sequential):
    """ResLayer to build ResNet style backbone.

    Args:
        block (nn.Module): block used to build ResLayer. Its ``expansion``
            class attribute is read to compute the output channel count.
        inplanes (int): inplanes of block.
        planes (int): planes of block.
        num_blocks (int): number of blocks.
        stride (int): stride of the first block. Default: 1
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
        **kwargs: Extra keyword arguments forwarded unchanged to every
            ``block(...)`` call.
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 **kwargs):
        # Stored before the parent constructor runs, as in the original code.
        self.block = block

        # A projection shortcut is only built when the residual branch changes
        # the spatial size (stride != 1) or the channel count.
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = []
            conv_stride = stride
            if avg_down:
                # The pooling layer performs the spatial reduction, so the
                # 1x1 conv that follows keeps stride 1.
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_norm_layer(norm_cfg, planes * block.expansion)[1]
            ])
            downsample = nn.Sequential(*downsample)

        layers = []
        if downsample_first:
            # First block changes stride/channels; the rest keep the shape.
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
            inplanes = planes * block.expansion
            for _ in range(1, num_blocks):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=planes,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))
        else:  # downsample_first=False is for HourglassModule
            # Leading blocks keep the input shape; the last block changes
            # stride/channels.
            for _ in range(num_blocks - 1):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=inplanes,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super().__init__(*layers)


class SimplifiedBasicBlock(BaseModule):
    """Simplified version of original basic residual block.

    This is used in SCNet.

    - Norm layer is now optional
    - Last ReLU in forward function is removed

    Args:
        inplanes (int): Input channels of the first convolution.
        planes (int): Output channels of both convolutions.
        stride (int): Stride of the first convolution. Default: 1
        dilation (int): Dilation (and padding) of the first convolution.
            Default: 1
        downsample (nn.Module | None): Optional module applied to the block
            input to produce the identity branch. Default: None
        style (str): Accepted for signature compatibility; not used by this
            block. Default: 'pytorch'
        with_cp (bool): Must be False; anything else fails an assertion.
            Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict | None): dictionary to construct and config norm
            layer. When None, no norm layers are built and the convolutions
            use a bias instead. Default: dict(type='BN')
        dcn (None): Must be None; anything else fails an assertion.
        plugins (None): Must be None; anything else fails an assertion.
        init_fg (dict | None): Historical misspelling of ``init_cfg``, kept
            so existing callers continue to work. Only used when
            ``init_cfg`` is None. Default: None
        init_cfg (dict | None): Initialization config passed to the parent
            ``BaseModule``. Takes precedence over ``init_fg``. Default: None
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_fg=None,
                 init_cfg=None):
        # ``init_fg`` was a typo for ``init_cfg``; accept both so that callers
        # using the conventional keyword no longer raise TypeError while old
        # callers (keyword or positional) keep working.
        if init_cfg is None:
            init_cfg = init_fg
        super().__init__(init_cfg)
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert not with_cp, 'Not implemented yet.'

        self.with_norm = norm_cfg is not None
        # Without a norm layer the convolutions carry their own bias.
        with_bias = norm_cfg is None
        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=with_bias)
        if self.with_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, planes, postfix=1)
            self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
        if self.with_norm:
            self.norm2_name, norm2 = build_norm_layer(
                norm_cfg, planes, postfix=2)
            self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name) if self.with_norm else None

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name) if self.with_norm else None

    def forward(self, x):
        """Forward function.

        Applies conv1 -> (norm1) -> ReLU -> conv2 -> (norm2), then adds the
        identity branch (``downsample(x)`` when a downsample module is set,
        otherwise ``x``). No ReLU is applied after the addition.

        Args:
            x: Input passed to the first convolution.

        Returns:
            Sum of the residual branch output and the identity branch.
        """
        identity = x

        out = self.conv1(x)
        if self.with_norm:
            out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.with_norm:
            out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out