Spaces:
Running
on
L40S
Running
on
L40S
from mmcv.cnn import build_conv_layer, build_norm_layer | |
from mmcv.runner import BaseModule, Sequential | |
from torch import nn as nn | |
class ResLayer(Sequential):
    """Sequential stack of residual blocks forming one ResNet-style stage.

    Exactly one block in the stack receives ``stride`` and the (optional)
    shortcut projection; the remaining ``num_blocks - 1`` blocks use stride 1
    and no projection.

    Args:
        block (nn.Module): Block class used to build the layer. It must
            expose an ``expansion`` attribute.
        inplanes (int): Input channels of the layer.
        planes (int): Base channels of each block; the layer outputs
            ``planes * block.expansion`` channels.
        num_blocks (int): Number of blocks in the layer.
        stride (int): Stride of the downsampling block. Default: 1
        avg_down (bool): Use AvgPool instead of a strided conv in the
            shortcut projection. Default: False
        conv_cfg (dict): Config used to build conv layers. Default: None
        norm_cfg (dict): Config used to build norm layers.
            Default: dict(type='BN')
        downsample_first (bool): Put the downsampling block first (True,
            ResNet) or last (False, Hourglass). Default: True
        **kwargs: Extra keyword arguments forwarded to every block.
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 **kwargs):
        self.block = block
        out_channels = planes * block.expansion

        # A projection on the identity path is only needed when the block
        # changes the spatial resolution or the number of channels.
        shortcut = None
        if not (stride == 1 and inplanes == out_channels):
            shortcut_ops = []
            if avg_down:
                # Pooling performs the spatial reduction, so the 1x1 conv
                # that follows keeps stride 1.
                shortcut_ops.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            shortcut_ops.append(
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=1 if avg_down else stride,
                    bias=False))
            shortcut_ops.append(build_norm_layer(norm_cfg, out_channels)[1])
            shortcut = nn.Sequential(*shortcut_ops)

        # Keyword arguments common to every block in the stack.
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs)
        num_plain = num_blocks - 1

        if downsample_first:
            # ResNet layout: the strided block comes first, the rest operate
            # on the already expanded channel count.
            layers = [
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=shortcut,
                    **shared)
            ]
            layers.extend(
                block(
                    inplanes=out_channels,
                    planes=planes,
                    stride=1,
                    **shared) for _ in range(num_plain))
        else:
            # Hourglass layout: channel-preserving blocks first, the strided
            # block closes the stack.
            layers = [
                block(
                    inplanes=inplanes,
                    planes=inplanes,
                    stride=1,
                    **shared) for _ in range(num_plain)
            ]
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=shortcut,
                    **shared))

        super(ResLayer, self).__init__(*layers)
class SimplifiedBasicBlock(BaseModule):
    """Simplified version of original basic residual block. This is used in
    `SCNet <https://arxiv.org/abs/2012.10150>`_.

    - Norm layer is now optional
    - Last ReLU in forward function is removed

    Args:
        inplanes (int): Input channels of the first convolution.
        planes (int): Output channels of both convolutions.
        stride (int): Stride of the first convolution. Default: 1
        dilation (int): Dilation (and padding) of the first convolution.
            Default: 1
        downsample (nn.Module | None): Module applied to the input to
            produce the identity branch; ``None`` uses the input as is.
            Default: None
        style (str): Accepted for signature compatibility; not used by
            this block. Default: 'pytorch'
        with_cp (bool): Must be False (not implemented). Default: False
        conv_cfg (dict | None): Config passed to ``build_conv_layer``.
            Default: None
        norm_cfg (dict | None): Config passed to ``build_norm_layer``.
            ``None`` disables normalization and enables conv bias.
            Default: dict(type='BN')
        dcn (None): Must be None (not implemented).
        plugins (None): Must be None (not implemented).
        init_fg (dict | None): Forwarded as the first positional argument
            of ``BaseModule.__init__``. The (misspelled) name is kept so
            existing keyword callers keep working. Default: None
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_fg=None):
        super(SimplifiedBasicBlock, self).__init__(init_fg)
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert not with_cp, 'Not implemented yet.'

        self.with_norm = norm_cfg is not None
        # Convolutions only carry a bias when no norm layer follows them.
        with_bias = norm_cfg is None

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=with_bias)
        if self.with_norm:
            # The norm layer is registered under the name returned by
            # build_norm_layer; the ``norm1`` property looks it up.
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, planes, postfix=1)
            self.add_module(self.norm1_name, norm1)

        # The second convolution always uses padding 1 and default
        # stride/dilation.
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
        if self.with_norm:
            self.norm2_name, norm2 = build_norm_layer(
                norm_cfg, planes, postfix=2)
            self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    # ``forward`` calls ``self.norm1(out)`` / ``self.norm2(out)``, so these
    # accessors must be properties returning the registered module; as plain
    # methods the extra positional argument raises TypeError.
    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name) if self.with_norm else None

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name) if self.with_norm else None

    def forward(self, x):
        """Forward function.

        Runs conv1 -> (norm1) -> ReLU -> conv2 -> (norm2), adds the identity
        branch (``downsample(x)`` when a downsample module is set, else
        ``x``) and returns the sum without a final ReLU.
        """
        identity = x

        out = self.conv1(x)
        if self.with_norm:
            out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.with_norm:
            out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out