# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmengine.model import BaseModule, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn

from mmcls.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone


class RepVGGBlock(BaseModule):
    """RepVGG block for RepVGG backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
        padding (int): Padding of the 3x3 convolution layer. Default: 1.
        dilation (int): Dilation of the 3x3 convolution layer. Default: 1.
        groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
        padding_mode (str): Padding mode of the 3x3 convolution layer.
            Default: 'zeros'.
        se_cfg (None or dict): The configuration of the SE module.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict, optional): Config dict for the convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict to construct and configure the norm
            layer. Default: dict(type='BN').
        act_cfg (dict): Config dict for the activation layer.
            Default: dict(type='ReLU').
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 se_cfg=None,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 deploy=False,
                 init_cfg=None):
        super(RepVGGBlock, self).__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.se_cfg = se_cfg
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy

        if deploy:
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            # Check whether the input and output shapes stay the same.
            # If so, add a normalized (BN-only) identity shortcut.
            if out_channels == in_channels and stride == 1 and \
                    padding == dilation:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None
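
            # Train-time RepVGG keeps three parallel branches: a 3x3 conv-BN,
            # a 1x1 conv-BN and, when the shapes allow it, the BN-only
            # identity path built above. They are summed in ``forward``.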
            self.branch_3x3 = self.create_conv_bn(
                kernel_size=3,
                dilation=dilation,
                padding=padding,
            )
            self.branch_1x1 = self.create_conv_bn(kernel_size=1)

        if se_cfg is not None:
            self.se_layer = SELayer(channels=out_channels, **se_cfg)
        else:
            self.se_layer = None

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
        """Create a conv-BN branch with the block's shared settings."""
        conv_bn = Sequential()
        conv_bn.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                groups=self.groups,
                bias=False))
        conv_bn.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
        return conv_bn

    def forward(self, x):

        def _inner_forward(inputs):
            if self.deploy:
                return self.branch_reparam(inputs)

            if self.branch_norm is None:
                branch_norm_out = 0
            else:
                branch_norm_out = self.branch_norm(inputs)
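
            # Sum the three branches; after ``switch_to_deploy`` the same
            # mapping is reproduced by a single fused 3x3 convolution.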
            inner_out = self.branch_3x3(inputs) + self.branch_1x1(
                inputs) + branch_norm_out

            if self.se_cfg is not None:
                inner_out = self.se_layer(inner_out)

            return inner_out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.act(out)

        return out

    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."

        reparam_weight, reparam_bias = self.reparameterize()
        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias
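
        # The fused conv now reproduces the summed three-branch output, so the
        # training-time branches can be detached and removed below.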
        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_3x3')
        delattr(self, 'branch_1x1')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse all the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. The first element is the weight and the second is
                the bias.
        """
        weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
        weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
        # Zero-pad the fused 1x1 kernel to a 3x3 kernel.
        weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)
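        # With the extra entries set to zero, a 3x3 conv using the padded
        # kernel gives the same result as the original 1x1 branch, so the
        # kernels of all branches can simply be added element-wise below.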

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_3x3 + weight_1x1 + weight_norm,
                bias_3x3 + bias_1x1 + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmengine.model.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
        """
        if branch is None:
            return 0, 0
        conv_weight = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps
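
        # Fold BN into the preceding conv. BN computes
        #   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta,
        # which equals a conv with weight (gamma / std) * W and bias
        # beta - gamma * mean / std, where std = sqrt(var + eps).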
        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
        fused_bias = -running_mean * gamma / std + beta

        return fused_weight, fused_bias

    def _norm_to_conv3x3(self, branch_norm):
        """Convert a norm layer to a conv3x3-bn sequence.

        Args:
            branch_norm (nn.BatchNorm2d): A branch with only a norm layer.

        Returns:
            tmp_conv3x3 (mmengine.model.Sequential): A sequential module with
                a conv3x3 and the given norm layer.
        """
        input_dim = self.in_channels // self.groups
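        # Build an identity 3x3 kernel: each output channel copies its own
        # input channel (within its group) through a single 1 at the kernel
        # center, so this conv leaves the input unchanged and only the
        # attached BN contributes to the branch.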
        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
                                  dtype=branch_norm.weight.dtype)
        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, 1, 1] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
        tmp_conv3x3.conv.weight.data = conv_weight
        tmp_conv3x3.norm = branch_norm
        return tmp_conv3x3


class MTSPPF(BaseModule):
    """MTSPPF block for YOLOX-PAI RepVGG backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        norm_cfg (dict): Config dict to construct and configure the norm
            layer. Default: dict(type='BN').
        act_cfg (dict): Config dict for the activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of pooling. Default: 5.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 kernel_size=5):
        super().__init__()
        hidden_features = in_channels // 2  # hidden channels
        self.conv1 = ConvModule(
            in_channels,
            hidden_features,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv2 = ConvModule(
            hidden_features * 4,
            out_channels,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.conv1(x)
        y1 = self.maxpool(x)
        y2 = self.maxpool(y1)
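        # Concatenating x with one, two and three cascaded 5x5 max-pools
        # mimics pooling with 5x5, 9x9 and 13x13 windows (SPPF-style) at a
        # lower cost than running the three large pools separately.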
        return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))


@MODELS.register_module()
class RepVGG(BaseBackbone):
    """RepVGG backbone.

    A PyTorch impl of: `RepVGG: Making VGG-style ConvNets Great Again
    <https://arxiv.org/abs/2101.03697>`_

    Args:
        arch (str | dict): RepVGG architecture. If it is a string, choose from
            'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2',
            'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If it is a dict, it should
            contain the keys below:

            - **num_blocks** (Sequence[int]): Number of blocks in each stage.
            - **width_factor** (Sequence[float]): Width factor of each stage.
            - **group_layer_map** (dict | None): Mapping from the index of a
              RepVGG block to the number of groups of its convolutions;
              blocks not in the map use normal (group=1) convolutions.
            - **se_cfg** (dict | None): SE Layer config.
            - **stem_channels** (int, optional): The stem channels. The final
              stem channels will be
              ``min(stem_channels, base_channels * width_factor[0])``.
              If not set here, 64 is used by default in the code.

        in_channels (int): Number of input image channels. Defaults to 3.
        base_channels (int): Base channels of RepVGG backbone, which works
            together with width_factor. Defaults to 64.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(2, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        frozen_stages (int): Stages to be frozen (all params fixed). -1 means
            not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for the activation layer.
            Defaults to ``dict(type='ReLU')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Defaults to False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        add_ppf (bool): Whether to use the MTSPPF block. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
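
    Example:
        >>> # Illustrative usage only; assumes mmcls is installed so that
        >>> # RepVGG is importable from ``mmcls.models``.
        >>> import torch
        >>> from mmcls.models import RepVGG
        >>> model = RepVGG(arch='A0')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> len(outputs)  # one feature map per entry in ``out_indices``
        1
        >>> model.switch_to_deploy()  # fuse each block into a single 3x3 conv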
""" | |
groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] | |
g2_layer_map = {layer: 2 for layer in groupwise_layers} | |
g4_layer_map = {layer: 4 for layer in groupwise_layers} | |
arch_settings = { | |
'A0': | |
dict( | |
num_blocks=[2, 4, 14, 1], | |
width_factor=[0.75, 0.75, 0.75, 2.5], | |
group_layer_map=None, | |
se_cfg=None), | |
'A1': | |
dict( | |
num_blocks=[2, 4, 14, 1], | |
width_factor=[1, 1, 1, 2.5], | |
group_layer_map=None, | |
se_cfg=None), | |
'A2': | |
dict( | |
num_blocks=[2, 4, 14, 1], | |
width_factor=[1.5, 1.5, 1.5, 2.75], | |
group_layer_map=None, | |
se_cfg=None), | |
'B0': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[1, 1, 1, 2.5], | |
group_layer_map=None, | |
se_cfg=None, | |
stem_channels=64), | |
'B1': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2, 2, 2, 4], | |
group_layer_map=None, | |
se_cfg=None), | |
'B1g2': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2, 2, 2, 4], | |
group_layer_map=g2_layer_map, | |
se_cfg=None), | |
'B1g4': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2, 2, 2, 4], | |
group_layer_map=g4_layer_map, | |
se_cfg=None), | |
'B2': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2.5, 2.5, 2.5, 5], | |
group_layer_map=None, | |
se_cfg=None), | |
'B2g2': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2.5, 2.5, 2.5, 5], | |
group_layer_map=g2_layer_map, | |
se_cfg=None), | |
'B2g4': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[2.5, 2.5, 2.5, 5], | |
group_layer_map=g4_layer_map, | |
se_cfg=None), | |
'B3': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[3, 3, 3, 5], | |
group_layer_map=None, | |
se_cfg=None), | |
'B3g2': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[3, 3, 3, 5], | |
group_layer_map=g2_layer_map, | |
se_cfg=None), | |
'B3g4': | |
dict( | |
num_blocks=[4, 6, 16, 1], | |
width_factor=[3, 3, 3, 5], | |
group_layer_map=g4_layer_map, | |
se_cfg=None), | |
'D2se': | |
dict( | |
num_blocks=[8, 14, 24, 1], | |
width_factor=[2.5, 2.5, 2.5, 5], | |
group_layer_map=None, | |
se_cfg=dict(ratio=16, divisor=1)), | |
'yolox-pai-small': | |
dict( | |
num_blocks=[3, 5, 7, 3], | |
width_factor=[1, 1, 1, 1], | |
group_layer_map=None, | |
se_cfg=None, | |
stem_channels=32), | |
} | |

    def __init__(self,
                 arch,
                 in_channels=3,
                 base_channels=64,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 deploy=False,
                 norm_eval=False,
                 add_ppf=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])
        if arch['group_layer_map'] is not None:
            assert max(arch['group_layer_map'].keys()) <= sum(
                arch['num_blocks'])

        if arch['se_cfg'] is not None:
            assert isinstance(arch['se_cfg'], dict)

        self.base_channels = base_channels
        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval

        # Defaults to 64 to avoid breaking backwards compatibility if
        # ``stem_channels`` is not in the arch dict; the stem channels should
        # not be larger than those of stage 1.
        channels = min(
            arch.get('stem_channels', 64),
            int(self.base_channels * self.arch['width_factor'][0]))
        self.stem = RepVGGBlock(
            self.in_channels,
            channels,
            stride=2,
            se_cfg=arch['se_cfg'],
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)

        next_create_block_idx = 1
        self.stages = []
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(self.base_channels * 2**i *
                               self.arch['width_factor'][i])
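            # ``out_channels`` doubles the base width per stage
            # (base_channels * 2**i) and scales it by the stage's width
            # factor from the arch setting.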
            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

            channels = out_channels

        if add_ppf:
            self.ppf = MTSPPF(
                out_channels,
                out_channels,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                kernel_size=5)
        else:
            self.ppf = nn.Identity()

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks
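        # ``strides`` above assigns the stage stride to the first block only,
        # so only that block downsamples; the remaining blocks use stride 1.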

        blocks = []
        for i in range(num_blocks):
            groups = self.arch['group_layer_map'].get(
                next_create_block_idx,
                1) if self.arch['group_layer_map'] is not None else 1
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    groups=groups,
                    se_cfg=self.arch['se_cfg'],
                    with_cp=self.with_cp,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
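            # The (optional) MTSPPF module is applied only after the last
            # stage; with ``add_ppf=False`` it is an ``nn.Identity``.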
            if i + 1 == len(self.stages):
                x = self.ppf(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i+1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        """Switch every RepVGG block in the backbone to deployment mode."""
        for m in self.modules():
            if isinstance(m, RepVGGBlock):
                m.switch_to_deploy()
        self.deploy = True