# ttxskk
# update
# d7e58f0
# raw
# history blame
# 31.7 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from .resnet import BasicBlock, Bottleneck
class HRModule(BaseModule):
    """High-Resolution Module for HRNet.

    The module runs ``num_branches`` parallel branches, where branch ``i`` is
    a stack of ``num_blocks[i]`` blocks, and then fuses (exchanges) the
    branch outputs with each other.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (type): Block class used to build every branch. It is called
            as ``blocks(in_channels, channels, ...)`` and must expose an
            ``expansion`` attribute.
        num_blocks (Sequence[int]): Number of blocks in each branch. The
            length must be equal to ``num_branches``.
        in_channels (list[int]): Input channels of each branch. The length
            must be equal to ``num_branches``. The list is stored by
            reference and entry ``i`` is overwritten with
            ``num_channels[i] * blocks.expansion`` while the branches are
            built.
        num_channels (Sequence[int]): Channels passed to the blocks of each
            branch. The length must be equal to ``num_branches``.
        multiscale_output (bool): If True, one fused output is produced per
            branch; otherwise only the output of branch 0 is produced.
            Default: True.
        with_cp (bool): Forwarded to every block as ``with_cp``.
            Default: False.
        conv_cfg (dict, optional): Config passed to ``build_conv_layer``.
            Default: None.
        norm_cfg (dict): Config passed to ``build_norm_layer``.
            Default: dict(type='BN').
        block_init_cfg (dict, optional): Forwarded to every block as
            ``init_cfg``. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config
            passed to ``BaseModule``. Default: None.

    Raises:
        ValueError: If ``num_blocks``, ``num_channels`` or ``in_channels``
            does not have exactly ``num_branches`` entries.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 block_init_cfg=None,
                 init_cfg=None):
        super(HRModule, self).__init__(init_cfg)
        self.block_init_cfg = block_init_cfg
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        # The branches must be built before the fuse layers because
        # building a branch updates ``self.in_channels``.
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Validate that every per-branch sequence has ``num_branches``
        entries, raising ``ValueError`` on the first mismatch."""
        per_branch_settings = (
            ('NUM_BLOCKS', num_blocks),
            ('NUM_CHANNELS', num_channels),
            ('NUM_INCHANNELS', in_channels),
        )
        for label, values in per_branch_settings:
            if num_branches != len(values):
                error_msg = f'NUM_BRANCHES({num_branches}) ' \
                            f'!= {label}({len(values)})'
                raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build the stack of blocks for the branch ``branch_index``.

        As a side effect ``self.in_channels[branch_index]`` is set to the
        output channels of the branch.
        """
        out_channels = num_channels[branch_index] * block.expansion

        # A projection shortcut is only needed when the first block changes
        # the resolution or the channel count.
        downsample = None
        if stride != 1 or self.in_channels[branch_index] != out_channels:
            downsample = nn.Sequential(
                build_conv_layer(self.conv_cfg,
                                 self.in_channels[branch_index],
                                 out_channels,
                                 kernel_size=1,
                                 stride=stride,
                                 bias=False),
                build_norm_layer(self.norm_cfg, out_channels)[1])

        layers = [
            block(self.in_channels[branch_index],
                  num_channels[branch_index],
                  stride,
                  downsample=downsample,
                  with_cp=self.with_cp,
                  norm_cfg=self.norm_cfg,
                  conv_cfg=self.conv_cfg,
                  init_cfg=self.block_init_cfg)
        ]
        self.in_channels[branch_index] = out_channels
        # The remaining blocks keep stride 1 and need no shortcut.
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.in_channels[branch_index],
                      num_channels[branch_index],
                      with_cp=self.with_cp,
                      norm_cfg=self.norm_cfg,
                      conv_cfg=self.conv_cfg,
                      init_cfg=self.block_init_cfg))

        return Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches and return them as a ``ModuleList``."""
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the layers that map branch ``j`` onto output branch ``i``.

        Returns:
            nn.ModuleList | None: ``fuse_layers[i][j]`` converts the output
            of branch ``j`` for output ``i``; the entry is None for
            ``i == j``. None is returned when there is a single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1

        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Higher-index branch: 1x1 conv to match the channels,
                    # then nearest upsampling by 2**(j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             in_channels[j],
                                             in_channels[i],
                                             kernel_size=1,
                                             stride=1,
                                             padding=0,
                                             bias=False),
                            build_norm_layer(self.norm_cfg,
                                             in_channels[i])[1],
                            nn.Upsample(scale_factor=2**(j - i),
                                        mode='nearest')))
                elif j == i:
                    # Same branch: added as is in ``forward``.
                    fuse_layer.append(None)
                else:
                    # Lower-index branch: (i - j) stride-2 3x3 convs. Only
                    # the last one changes the channel count, and it has no
                    # ReLU because the fused sum is activated in ``forward``.
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(self.conv_cfg,
                                                     in_channels[j],
                                                     in_channels[i],
                                                     kernel_size=3,
                                                     stride=2,
                                                     padding=1,
                                                     bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(self.conv_cfg,
                                                     in_channels[j],
                                                     in_channels[j],
                                                     kernel_size=3,
                                                     stride=2,
                                                     padding=1,
                                                     bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One input per branch. The list is not
                modified.

        Returns:
            list[Tensor]: The fused outputs, one per fuse layer (a single
            element when there is only one branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Collect the branch outputs in a new list instead of writing them
        # back into ``x`` so that the caller's list stays untouched.
        branch_outs = [
            self.branches[i](x[i]) for i in range(self.num_branches)
        ]

        x_fuse = []
        for i, fuse_row in enumerate(self.fuse_layers):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += branch_outs[j]
                else:
                    y += fuse_row[j](branch_outs[j])
            x_fuse.append(self.relu(y))
        return x_fuse
class PoseHighResolutionNet(BaseModule):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        extra (dict): Detailed configuration of the network.
            It must contain the keys ``stage1`` to ``stage4``; the
            configuration for each stage must have 5 keys:

                - num_modules(int): The number of HRModule in this stage.
                - num_branches(int): The number of branches in the HRModule.
                - block(str): The type of convolution block, one of the
                  keys of ``blocks_dict``.
                - num_blocks(tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels(tuple): The number of channels in each branch.
                    The length must be equal to num_branches.

            The following keys are read as well:

                - final_conv_kernel(int): Kernel size of ``final_layer``;
                  its padding is 1 when the kernel size is 3, otherwise 0.
                - return_list(bool): If true, ``forward`` returns the list
                  of stage 4 outputs.
                - downsample(bool): Used when ``return_list`` is false. If
                  true, the first three outputs are brought to the size of
                  the fourth one; otherwise the last three outputs are
                  brought to the size of the first one. The four tensors
                  are then concatenated along dim 1.
                - use_conv(bool): If true, the resizing is done by the conv
                  layers built in the constructor, otherwise by bilinear
                  interpolation.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: False.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Default: True.
        num_joints(int): the number of output channels of the final layer.
            Default: 24.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 num_joints=24,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super(PoseHighResolutionNet, self).__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        # Assert configurations of 4 stages are in extra
        assert 'stage1' in extra and 'stage2' in extra \
               and 'stage3' in extra and 'stage4' in extra
        # Assert whether the length of `num_blocks` and `num_channels` are
        # equal to `num_branches`
        for i in range(4):
            cfg = extra[f'stage{i + 1}']
            assert len(cfg['num_blocks']) == cfg['num_branches'] and \
                   len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net: two stride-2 3x3 convs with 64 output channels each
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(self.conv_cfg,
                                      in_channels,
                                      64,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1,
                                      bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(self.conv_cfg,
                                      64,
                                      64,
                                      kernel_size=3,
                                      stride=2,
                                      padding=1,
                                      bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1: a single branch
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self._expanded_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self._expanded_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self._expanded_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multiscale_output=multiscale_output)

        self.final_layer = build_conv_layer(
            cfg=self.conv_cfg,
            in_channels=pre_stage_channels[0],
            out_channels=num_joints,
            kernel_size=extra['final_conv_kernel'],
            stride=1,
            padding=1 if extra['final_conv_kernel'] == 3 else 0)

        # Conv-based resizing heads; they are only created (and only used in
        # ``forward``) when ``use_conv`` is set.
        if extra['downsample'] and extra['use_conv']:
            self.downsample_stage_1 = self._make_downsample_layer(
                3, num_channel=self.stage2_cfg['num_channels'][0])
            self.downsample_stage_2 = self._make_downsample_layer(
                2, num_channel=self.stage2_cfg['num_channels'][-1])
            self.downsample_stage_3 = self._make_downsample_layer(
                1, num_channel=self.stage3_cfg['num_channels'][-1])
        elif not extra['downsample'] and extra['use_conv']:
            self.upsample_stage_2 = self._make_upsample_layer(
                1, num_channel=self.stage2_cfg['num_channels'][-1])
            self.upsample_stage_3 = self._make_upsample_layer(
                2, num_channel=self.stage3_cfg['num_channels'][-1])
            self.upsample_stage_4 = self._make_upsample_layer(
                3, num_channel=self.stage4_cfg['num_channels'][-1])

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _expanded_channels(self, stage_cfg):
        """Return the ``num_channels`` of a stage config, each multiplied by
        the expansion of the stage's block type."""
        block = self.blocks_dict[stage_cfg['block']]
        return [
            channel * block.expansion for channel in stage_cfg['num_channels']
        ]

    def _get_block_init_cfg(self, block):
        """Return the zero-init ``init_cfg`` for ``block`` or None.

        A config that sets the last norm layer of the block to zero is
        returned only when no pretrained path is given, the module has no
        ``init_cfg`` attribute and ``zero_init_residual`` is set.
        """
        # NOTE(review): the constructor assigns ``self.init_cfg`` in its
        # default path, and BaseModule presumably defines the attribute as
        # well. If so, the ``hasattr`` test always succeeds and
        # ``zero_init_residual`` never takes effect -- confirm before
        # relying on it. The condition is kept as it was.
        if (self.pretrained is not None or hasattr(self, 'init_cfg')
                or not self.zero_init_residual):
            return None
        if block is BasicBlock:
            return dict(type='Constant', val=0, override=dict(name='norm2'))
        if block is Bottleneck:
            return dict(type='Constant', val=0, override=dict(name='norm3'))
        return None

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build the layers that adapt the branches of the previous stage to
        the branches of the current stage.

        Returns:
            nn.ModuleList: One entry per current branch. The entry is None
            when an existing branch keeps its channel count.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                # Existing branch: only adapt the channels when they differ.
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             num_channels_pre_layer[i],
                                             num_channels_cur_layer[i],
                                             kernel_size=3,
                                             stride=1,
                                             padding=1,
                                             bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                # New branch: derived from the last previous branch with
                # stride-2 convs; only the last conv changes the channels.
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(self.conv_cfg,
                                             in_channels,
                                             out_channels,
                                             kernel_size=3,
                                             stride=2,
                                             padding=1,
                                             bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Build a ``Sequential`` of ``blocks`` blocks of type ``block``.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + norm shortcut.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(self.conv_cfg,
                                 inplanes,
                                 planes * block.expansion,
                                 kernel_size=1,
                                 stride=stride,
                                 bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        block_init_cfg = self._get_block_init_cfg(block)
        layers = [
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                init_cfg=block_init_cfg,
            )
        ]
        inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(inplanes,
                      planes,
                      with_cp=self.with_cp,
                      norm_cfg=self.norm_cfg,
                      conv_cfg=self.conv_cfg,
                      init_cfg=block_init_cfg))

        return Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Build one stage as a ``Sequential`` of ``HRModule``.

        Returns:
            tuple: The stage and the ``in_channels`` list that was handed to
            the modules.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        block_init_cfg = self._get_block_init_cfg(block)
        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            is_last_module = i == num_modules - 1
            reset_multiscale_output = multiscale_output or not is_last_module

            hr_modules.append(
                HRModule(num_branches,
                         block,
                         num_blocks,
                         in_channels,
                         num_channels,
                         reset_multiscale_output,
                         with_cp=self.with_cp,
                         norm_cfg=self.norm_cfg,
                         conv_cfg=self.conv_cfg,
                         block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules), in_channels

    def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
        """Build ``num_layers`` units of bilinear x2 upsampling followed by
        conv, norm and ReLU, all with ``num_channel`` channels."""
        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.Upsample(scale_factor=2,
                            mode='bilinear',
                            align_corners=True))
            layers.append(
                build_conv_layer(
                    cfg=self.conv_cfg,
                    in_channels=num_channel,
                    out_channels=num_channel,
                    kernel_size=kernel_size,
                    stride=1,
                    # Equals 1 for the default kernel size of 3 and keeps
                    # the spatial size for every other odd kernel size.
                    padding=kernel_size // 2,
                    bias=False,
                ))
            layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
        """Build ``num_layers`` units of stride-2 conv, norm and ReLU, all
        with ``num_channel`` channels."""
        layers = []
        for _ in range(num_layers):
            layers.append(
                build_conv_layer(
                    cfg=self.conv_cfg,
                    in_channels=num_channel,
                    out_channels=num_channel,
                    kernel_size=kernel_size,
                    stride=2,
                    # Equals 1 for the default kernel size of 3; other odd
                    # kernel sizes yield the same output size as the default.
                    padding=kernel_size // 2,
                    bias=False,
                ))
            layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
            layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    @staticmethod
    def _apply_transition(transition, num_branches, features):
        """Build the inputs of a stage from the outputs of the previous one.

        Branches with a transition layer are computed from the last previous
        output; the others reuse the previous output with the same index.
        """
        return [
            transition[i](features[-1])
            if transition[i] is not None else features[i]
            for i in range(num_branches)
        ]

    @staticmethod
    def _resize_like(features, reference):
        """Bilinearly resize every tensor in ``features`` to the height and
        width (dims 2 and 3) of ``reference``."""
        size = (reference.size(2), reference.size(3))
        return [
            F.interpolate(feature,
                          size=size,
                          mode='bilinear',
                          align_corners=True) for feature in features
        ]

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor] | Tensor: The stage 4 outputs when
            ``extra['return_list']`` is true. Otherwise the outputs with
            index 0 to 3 are resized to a common size, as selected by
            ``extra['downsample']`` and ``extra['use_conv']``, and
            concatenated along dim 1.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['num_branches']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = self._apply_transition(self.transition2,
                                        self.stage3_cfg['num_branches'],
                                        y_list)
        y_list = self.stage3(x_list)

        x_list = self._apply_transition(self.transition3,
                                        self.stage4_cfg['num_branches'],
                                        y_list)
        y_list = self.stage4(x_list)

        if self.extra['return_list']:
            return y_list

        if self.extra['downsample']:
            if self.extra['use_conv']:
                # Downsampling with strided convolutions
                x1 = self.downsample_stage_1(y_list[0])
                x2 = self.downsample_stage_2(y_list[1])
                x3 = self.downsample_stage_3(y_list[2])
            else:
                # Downsampling with interpolation
                x1, x2, x3 = self._resize_like(y_list[:3], y_list[3])
            return torch.cat([x1, x2, x3, y_list[3]], 1)

        if self.extra['use_conv']:
            # Upsampling with interpolations + convolutions
            x1 = self.upsample_stage_2(y_list[1])
            x2 = self.upsample_stage_3(y_list[2])
            x3 = self.upsample_stage_4(y_list[3])
        else:
            # Upsampling with interpolation
            x1, x2, x3 = self._resize_like(y_list[1:4], y_list[0])
        return torch.cat([y_list[0], x1, x2, x3], 1)

    def train(self, mode=True):
        """Switch the training mode while keeping the normalization layers
        frozen.

        When ``mode`` is true and ``norm_eval`` is set, every ``_BatchNorm``
        submodule is put back into eval mode.

        Returns:
            PoseHighResolutionNet: self, as ``nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained.
        """
        super(PoseHighResolutionNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
class PoseHighResolutionNetExpose(PoseHighResolutionNet):
    """HRNet backbone for expose.

    On top of the parent network this class builds two strided-conv
    subsampling heads and a stack of ``Bottleneck`` blocks. ``forward``
    returns one feature vector per sample instead of feature maps.

    The constructor arguments are the same as those of
    ``PoseHighResolutionNet`` and are passed through unchanged.
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=True,
                 with_cp=False,
                 num_joints=24,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(extra=extra,
                         in_channels=in_channels,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         norm_eval=norm_eval,
                         with_cp=with_cp,
                         num_joints=num_joints,
                         zero_init_residual=zero_init_residual,
                         multiscale_output=multiscale_output,
                         pretrained=pretrained,
                         init_cfg=init_cfg)
        stage2_last = self.stage2_cfg['num_channels'][-1]
        stage3_last = self.stage3_cfg['num_channels'][-1]
        stage4_last = self.stage4_cfg['num_channels'][-1]

        # Every subsample layer doubles the channels, so the two heads end
        # up with 2**2 * stage2_last and 2**1 * stage3_last channels.
        in_dims = 2**2 * stage2_last + 2**1 * stage3_last + stage4_last
        self.conv_layers = self._make_conv_layer(in_channels=in_dims,
                                                 num_layers=5)
        self.subsample_3 = self._make_subsample_layer(in_channels=stage2_last,
                                                      num_layers=2)
        self.subsample_2 = self._make_subsample_layer(in_channels=stage3_last,
                                                      num_layers=1)

    def _make_conv_layer(self,
                         in_channels=2048,
                         num_layers=3,
                         num_filters=2048,
                         stride=1):
        """Stack ``num_layers`` Bottleneck blocks, each with a bias-free 1x1
        conv as its ``downsample`` module.

        The first block takes ``in_channels`` channels, the following ones
        ``num_filters``. ``stride`` is accepted but not used.
        """
        bottlenecks = []
        channels = in_channels
        for _ in range(num_layers):
            shortcut = nn.Conv2d(channels,
                                 num_filters,
                                 kernel_size=1,
                                 stride=1,
                                 bias=False)
            bottlenecks.append(
                Bottleneck(channels, num_filters // 4, downsample=shortcut))
            channels = num_filters
        return nn.Sequential(*bottlenecks)

    def _make_subsample_layer(self, in_channels=96, num_layers=3, stride=2):
        """Stack ``num_layers`` units of 3x3 conv (with the given stride),
        BatchNorm2d and ReLU; each unit doubles the channel count."""
        modules = []
        channels = in_channels
        for _ in range(num_layers):
            doubled = 2 * channels
            modules.extend([
                nn.Conv2d(in_channels=channels,
                          out_channels=doubled,
                          kernel_size=3,
                          stride=stride,
                          padding=1),
                nn.BatchNorm2d(doubled, momentum=0.1),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        return nn.Sequential(*modules)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor: The output of ``conv_layers`` averaged over dims 2 and 3
            and reshaped to ``(x.size(0), -1)``.
        """
        # stem and stage 1
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        x = self.layer1(x)

        # stage 2: every branch starts from the stage 1 output
        stage_inputs = []
        for idx in range(self.stage2_cfg['num_branches']):
            transition = self.transition1[idx]
            stage_inputs.append(x if transition is None else transition(x))
        stage_outputs = self.stage2(stage_inputs)

        # stage 3: new branches are derived from the last previous output
        stage_inputs = []
        for idx in range(self.stage3_cfg['num_branches']):
            transition = self.transition2[idx]
            if transition is None:
                stage_inputs.append(stage_outputs[idx])
            else:
                stage_inputs.append(transition(stage_outputs[-1]))
        stage_outputs = self.stage3(stage_inputs)

        # NOTE(review): only the transition into stage 4 is applied below;
        # ``self.stage4`` itself, although built by the parent constructor,
        # is never called here -- confirm that this is intended.
        stage_inputs = []
        for idx in range(self.stage4_cfg['num_branches']):
            transition = self.transition3[idx]
            if transition is None:
                stage_inputs.append(stage_outputs[idx])
            else:
                stage_inputs.append(transition(stage_outputs[-1]))

        # Subsample branches 1 and 2 and concatenate them with branch 3
        # along the channel dim; branch 0 is not used.
        merged = torch.cat([
            self.subsample_3(stage_inputs[1]),
            self.subsample_2(stage_inputs[2]),
            stage_inputs[3],
        ],
                           dim=1)
        pooled = self.conv_layers(merged).mean(dim=(2, 3))
        return pooled.view(pooled.size(0), -1)