Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Sequence | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from mmdet.registry import MODELS | |
from mmdet.utils import ConfigType, OptMultiConfig | |
from ..layers import ResLayer | |
from .resnet import BasicBlock | |
class HourglassModule(BaseModule): | |
"""Hourglass Module for HourglassNet backbone. | |
Generate module recursively and use BasicBlock as the base unit. | |
Args: | |
depth (int): Depth of current HourglassModule. | |
stage_channels (list[int]): Feature channels of sub-modules in current | |
and follow-up HourglassModule. | |
stage_blocks (list[int]): Number of sub-modules stacked in current and | |
follow-up HourglassModule. | |
norm_cfg (ConfigType): Dictionary to construct and config norm layer. | |
Defaults to `dict(type='BN', requires_grad=True)` | |
upsample_cfg (ConfigType): Config dict for interpolate layer. | |
Defaults to `dict(mode='nearest')` | |
init_cfg (dict or ConfigDict, optional): the config to control the | |
initialization. | |
""" | |
def __init__(self, | |
depth: int, | |
stage_channels: List[int], | |
stage_blocks: List[int], | |
norm_cfg: ConfigType = dict(type='BN', requires_grad=True), | |
upsample_cfg: ConfigType = dict(mode='nearest'), | |
init_cfg: OptMultiConfig = None) -> None: | |
super().__init__(init_cfg) | |
self.depth = depth | |
cur_block = stage_blocks[0] | |
next_block = stage_blocks[1] | |
cur_channel = stage_channels[0] | |
next_channel = stage_channels[1] | |
self.up1 = ResLayer( | |
BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg) | |
self.low1 = ResLayer( | |
BasicBlock, | |
cur_channel, | |
next_channel, | |
cur_block, | |
stride=2, | |
norm_cfg=norm_cfg) | |
if self.depth > 1: | |
self.low2 = HourglassModule(depth - 1, stage_channels[1:], | |
stage_blocks[1:]) | |
else: | |
self.low2 = ResLayer( | |
BasicBlock, | |
next_channel, | |
next_channel, | |
next_block, | |
norm_cfg=norm_cfg) | |
self.low3 = ResLayer( | |
BasicBlock, | |
next_channel, | |
cur_channel, | |
cur_block, | |
norm_cfg=norm_cfg, | |
downsample_first=False) | |
self.up2 = F.interpolate | |
self.upsample_cfg = upsample_cfg | |
def forward(self, x: torch.Tensor) -> nn.Module: | |
"""Forward function.""" | |
up1 = self.up1(x) | |
low1 = self.low1(x) | |
low2 = self.low2(low1) | |
low3 = self.low3(low2) | |
# Fixing `scale factor` (e.g. 2) is common for upsampling, but | |
# in some cases the spatial size is mismatched and error will arise. | |
if 'scale_factor' in self.upsample_cfg: | |
up2 = self.up2(low3, **self.upsample_cfg) | |
else: | |
shape = up1.shape[2:] | |
up2 = self.up2(low3, size=shape, **self.upsample_cfg) | |
return up1 + up2 | |
class HourglassNet(BaseModule): | |
"""HourglassNet backbone. | |
Stacked Hourglass Networks for Human Pose Estimation. | |
More details can be found in the `paper | |
<https://arxiv.org/abs/1603.06937>`_ . | |
Args: | |
downsample_times (int): Downsample times in a HourglassModule. | |
num_stacks (int): Number of HourglassModule modules stacked, | |
1 for Hourglass-52, 2 for Hourglass-104. | |
stage_channels (Sequence[int]): Feature channel of each sub-module in a | |
HourglassModule. | |
stage_blocks (Sequence[int]): Number of sub-modules stacked in a | |
HourglassModule. | |
feat_channel (int): Feature channel of conv after a HourglassModule. | |
norm_cfg (norm_cfg): Dictionary to construct and config norm layer. | |
init_cfg (dict or ConfigDict, optional): the config to control the | |
initialization. | |
Example: | |
>>> from mmdet.models import HourglassNet | |
>>> import torch | |
>>> self = HourglassNet() | |
>>> self.eval() | |
>>> inputs = torch.rand(1, 3, 511, 511) | |
>>> level_outputs = self.forward(inputs) | |
>>> for level_output in level_outputs: | |
... print(tuple(level_output.shape)) | |
(1, 256, 128, 128) | |
(1, 256, 128, 128) | |
""" | |
def __init__(self, | |
downsample_times: int = 5, | |
num_stacks: int = 2, | |
stage_channels: Sequence = (256, 256, 384, 384, 384, 512), | |
stage_blocks: Sequence = (2, 2, 2, 2, 2, 4), | |
feat_channel: int = 256, | |
norm_cfg: ConfigType = dict(type='BN', requires_grad=True), | |
init_cfg: OptMultiConfig = None) -> None: | |
assert init_cfg is None, 'To prevent abnormal initialization ' \ | |
'behavior, init_cfg is not allowed to be set' | |
super().__init__(init_cfg) | |
self.num_stacks = num_stacks | |
assert self.num_stacks >= 1 | |
assert len(stage_channels) == len(stage_blocks) | |
assert len(stage_channels) > downsample_times | |
cur_channel = stage_channels[0] | |
self.stem = nn.Sequential( | |
ConvModule( | |
3, cur_channel // 2, 7, padding=3, stride=2, | |
norm_cfg=norm_cfg), | |
ResLayer( | |
BasicBlock, | |
cur_channel // 2, | |
cur_channel, | |
1, | |
stride=2, | |
norm_cfg=norm_cfg)) | |
self.hourglass_modules = nn.ModuleList([ | |
HourglassModule(downsample_times, stage_channels, stage_blocks) | |
for _ in range(num_stacks) | |
]) | |
self.inters = ResLayer( | |
BasicBlock, | |
cur_channel, | |
cur_channel, | |
num_stacks - 1, | |
norm_cfg=norm_cfg) | |
self.conv1x1s = nn.ModuleList([ | |
ConvModule( | |
cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) | |
for _ in range(num_stacks - 1) | |
]) | |
self.out_convs = nn.ModuleList([ | |
ConvModule( | |
cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg) | |
for _ in range(num_stacks) | |
]) | |
self.remap_convs = nn.ModuleList([ | |
ConvModule( | |
feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None) | |
for _ in range(num_stacks - 1) | |
]) | |
self.relu = nn.ReLU(inplace=True) | |
def init_weights(self) -> None: | |
"""Init module weights.""" | |
# Training Centripetal Model needs to reset parameters for Conv2d | |
super().init_weights() | |
for m in self.modules(): | |
if isinstance(m, nn.Conv2d): | |
m.reset_parameters() | |
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: | |
"""Forward function.""" | |
inter_feat = self.stem(x) | |
out_feats = [] | |
for ind in range(self.num_stacks): | |
single_hourglass = self.hourglass_modules[ind] | |
out_conv = self.out_convs[ind] | |
hourglass_feat = single_hourglass(inter_feat) | |
out_feat = out_conv(hourglass_feat) | |
out_feats.append(out_feat) | |
if ind < self.num_stacks - 1: | |
inter_feat = self.conv1x1s[ind]( | |
inter_feat) + self.remap_convs[ind]( | |
out_feat) | |
inter_feat = self.inters[ind](self.relu(inter_feat)) | |
return out_feats | |