# Hosting-page residue captured with this file (not part of the module):
# KyanChen's picture | init | f549064 | raw | history blame | 25.3 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
def conv_bn(in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            dilation=1,
            norm_cfg=dict(type='BN')):
    """Build a bias-free 2D convolution followed by a norm layer.

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        padding (int | None): Padding of the convolution. ``None`` selects
            ``kernel_size // 2``.
        groups (int): Groups of the convolution.
        dilation (int): Dilation of the convolution. Defaults to 1.
        norm_cfg (dict): Config handed to ``build_norm_layer``.
            Defaults to ``dict(type='BN')``.

    Returns:
        nn.Sequential: Container holding the submodules ``conv`` and ``bn``
        (registered in that order).
    """
    pad = kernel_size // 2 if padding is None else padding

    # The convolution carries no bias; the following norm layer is the only
    # place an offset can come from.
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=pad,
        dilation=dilation,
        groups=groups,
        bias=False)
    # Only element 1 of the value returned by build_norm_layer is kept.
    norm = build_norm_layer(norm_cfg, out_channels)[1]

    seq = nn.Sequential()
    seq.add_module('conv', conv)
    seq.add_module('bn', norm)
    return seq
def conv_bn_relu(in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups,
                 dilation=1,
                 norm_cfg=dict(type='BN')):
    """Build a bias-free 2D convolution, a norm layer and a ReLU.

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        padding (int | None): Padding of the convolution. ``None`` selects
            ``kernel_size // 2`` (resolved inside ``conv_bn``).
        groups (int): Groups of the convolution.
        dilation (int): Dilation of the convolution. Defaults to 1.
        norm_cfg (dict): Config of the norm layer, forwarded to ``conv_bn``.
            Defaults to ``dict(type='BN')``, which matches the behaviour
            before this argument existed.

    Returns:
        nn.Sequential: The ``conv_bn`` container with an extra ``nonlinear``
        (``nn.ReLU``) submodule appended.
    """
    # ``padding=None`` is passed through untouched: conv_bn applies the
    # ``kernel_size // 2`` fallback itself, so it is not duplicated here.
    result = conv_bn(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        dilation=dilation,
        norm_cfg=norm_cfg)
    result.add_module('nonlinear', nn.ReLU())
    return result
def fuse_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that precedes it.

    The running statistics and affine parameters of ``bn`` are read as they
    are at call time; neither module is modified.

    Args:
        conv (nn.Conv2d): The convolution module to fuse. Only its ``weight``
            is read.
        bn (nn.BatchNorm2d): The batch normalization to fuse. Its
            ``running_mean``, ``running_var``, ``weight``, ``bias`` and
            ``eps`` are read.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(weight, bias)`` of the fused
        convolution. The weight has the shape of ``conv.weight``; the bias
        has one entry per output channel.
    """
    std = (bn.running_var + bn.eps).sqrt()
    # One scale per output channel, shaped to broadcast over the remaining
    # three kernel dimensions.
    per_channel_scale = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_weight = conv.weight * per_channel_scale
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_weight, fused_bias
class ReparamLargeKernelConv(BaseModule):
    """Large-kernel convolution with an optional parallel small kernel.

    In training structure the module holds a large conv+norm branch
    (``lkb_origin``) and, when ``small_kernel`` is given, a small conv+norm
    branch (``small_conv``) whose outputs are summed. In deployment structure
    both are replaced by a single biased convolution (``lkb_reparam``).

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): Kernel size of the large convolution.
        stride (int): Stride of the large convolution.
        groups (int): Groups of the large convolution.
        small_kernel (int | None): Kernel size of the small convolution.
            ``None`` disables the small branch.
        small_kernel_merged (bool): Whether to build the module directly in
            deployment structure (a single merged convolution).
            Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups,
                 small_kernel,
                 small_kernel_merged=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.small_kernel_merged = small_kernel_merged

        # Padding is k // 2 for both branches, on the assumption that the
        # convolution should not change the feature map size. A different
        # large-kernel padding requires adapting the small branch as well.
        if small_kernel_merged:
            self.lkb_reparam = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=kernel_size // 2,
                dilation=1,
                groups=groups,
                bias=True)
            return

        self.lkb_origin = conv_bn(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            dilation=1,
            groups=groups)
        if small_kernel is None:
            return

        assert small_kernel <= kernel_size
        self.small_conv = conv_bn(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=small_kernel,
            stride=stride,
            padding=small_kernel // 2,
            groups=groups,
            dilation=1)

    def forward(self, inputs):
        """Apply the merged convolution, or the sum of the branches."""
        if hasattr(self, 'lkb_reparam'):
            return self.lkb_reparam(inputs)

        out = self.lkb_origin(inputs)
        if hasattr(self, 'small_conv'):
            out += self.small_conv(inputs)
        return out

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of the single conv equal to all branches.

        Each branch is fused with its norm layer through ``fuse_bn``; the
        small kernel is zero-padded on every side by
        ``(kernel_size - small_kernel) // 2`` before being added.
        """
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if not hasattr(self, 'small_conv'):
            return eq_k, eq_b

        small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
        border = (self.kernel_size - self.small_kernel) // 2
        eq_b += small_b
        # Place the small kernel in the central part of the large one.
        eq_k += nn.functional.pad(small_k, [border] * 4)
        return eq_k, eq_b

    def merge_kernel(self):
        """Switch the model structure from training mode to deployment mode.

        Does nothing when the module is already merged.
        """
        if self.small_kernel_merged:
            return

        eq_k, eq_b = self.get_equivalent_kernel_bias()
        origin = self.lkb_origin.conv
        self.lkb_reparam = nn.Conv2d(
            in_channels=origin.in_channels,
            out_channels=origin.out_channels,
            kernel_size=origin.kernel_size,
            stride=origin.stride,
            padding=origin.padding,
            dilation=origin.dilation,
            groups=origin.groups,
            bias=True)
        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b

        # The training-time branches are no longer needed once merged.
        del self.lkb_origin
        if hasattr(self, 'small_conv'):
            del self.small_conv
        self.small_kernel_merged = True
class ConvFFN(BaseModule):
    """Feed-forward block built from two 1x1 convolutions.

    The input is normalized, expanded by ``pw1``, passed through the
    activation, projected by ``pw2`` and added back to the input through a
    drop-path. The residual sum means ``out_channels`` has to be compatible
    with ``in_channels``.

    Args:
        in_channels (int): Dimension of input features.
        internal_channels (int): Dimension of hidden features.
        out_channels (int): Dimension of output features.
        drop_path (float): Stochastic depth rate. A value that is not
            greater than 0 disables drop-path.
        norm_cfg (dict): Config of the norm layer applied to the input
            (``preffn_bn``). The norm layers inside ``pw1`` / ``pw2`` use the
            default of ``conv_bn``. Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='GELU')``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 internal_channels,
                 out_channels,
                 drop_path,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_prob=drop_path)
        else:
            self.drop_path = nn.Identity()

        self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1]
        # Pointwise expand / project pair.
        self.pw1 = conv_bn(in_channels, internal_channels, 1, 1, 0, groups=1)
        self.pw2 = conv_bn(internal_channels, out_channels, 1, 1, 0, groups=1)
        self.nonlinear = build_activation_layer(act_cfg)

    def forward(self, x):
        """Return ``x`` plus the (drop-path'ed) feed-forward branch."""
        hidden = self.pw1(self.preffn_bn(x))
        branch = self.pw2(self.nonlinear(hidden))
        return x + self.drop_path(branch)
class RepLKBlock(BaseModule):
    """RepLKBlock for RepLKNet backbone.

    Pipeline: norm -> 1x1 conv/norm/ReLU -> large-kernel conv (one group per
    channel) -> activation -> 1x1 conv/norm, added back to the input through
    a drop-path.

    Args:
        in_channels (int): The input channels of the block.
        dw_channels (int): The intermediate channels of the block,
            i.e., input channels of the large kernel convolution.
        block_lk_size (int): Size of the super large kernel.
        small_kernel (int | None): Size of the parallel small kernel.
        drop_path (float): Stochastic depth rate. A value that is not
            greater than 0 disables drop-path.
        small_kernel_merged (bool): Whether to build the large-kernel conv
            in deployment structure (small kernel merged into the large
            one). Defaults to False.
        norm_cfg (dict): Config of the norm layer applied to the input
            (``prelkb_bn``). Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for the activation after the
            large-kernel conv. Defaults to ``dict(type='ReLU')``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 block_lk_size,
                 small_kernel,
                 drop_path,
                 small_kernel_merged=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg)

        # Pointwise convs entering and leaving the depth-wise width.
        self.pw1 = conv_bn_relu(
            in_channels=in_channels,
            out_channels=dw_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1)
        self.pw2 = conv_bn(
            in_channels=dw_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1)

        # groups == channels: every channel gets its own large kernel.
        self.large_kernel = ReparamLargeKernelConv(
            in_channels=dw_channels,
            out_channels=dw_channels,
            kernel_size=block_lk_size,
            stride=1,
            groups=dw_channels,
            small_kernel=small_kernel,
            small_kernel_merged=small_kernel_merged)
        self.lk_nonlinear = build_activation_layer(act_cfg)
        self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1]

        if drop_path > 0.:
            self.drop_path = DropPath(drop_prob=drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return ``x`` plus the (drop-path'ed) large-kernel branch."""
        branch = self.pw1(self.prelkb_bn(x))
        branch = self.lk_nonlinear(self.large_kernel(branch))
        branch = self.pw2(branch)
        return x + self.drop_path(branch)
class RepLKNetStage(BaseModule):
    """One stage of RepLKNet: alternating RepLKBlock and ConvFFN modules.

    ``forward`` runs the blocks only. The stage additionally owns ``norm``,
    which the caller may apply to the stage output.

    Args:
        channels (int): The input channels of the stage.
        num_blocks (int): The number of (RepLKBlock, ConvFFN) pairs.
        stage_lk_size (int): Size of the super large kernel, shared by all
            blocks of the stage.
        drop_path (float | list[float] | tuple[float]): Stochastic depth
            rate. A sequence provides one rate per block pair and is indexed
            by block position; a scalar is used for every block.
        small_kernel (int | None): Size of the parallel small kernel.
        dw_ratio (float): The intermediate channels expansion ratio of the
            block. Defaults to 1.
        ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Defaults to False.
        small_kernel_merged (bool): Whether to build the blocks in
            deployment structure (small kernel merged into the large one).
            Defaults to False.
        norm_intermediate_features (bool): If True, ``self.norm`` is a norm
            layer built from ``norm_cfg``; otherwise it is ``nn.Identity``.
            Defaults to False.
        norm_cfg (dict): Config of ``self.norm``. It is not forwarded to the
            blocks. Defaults to ``dict(type='BN')``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
            self,
            channels,
            num_blocks,
            stage_lk_size,
            drop_path,
            small_kernel,
            dw_ratio=1,
            ffn_ratio=4,
            with_cp=False,  # train with torch.utils.checkpoint to save memory
            small_kernel_merged=False,
            norm_intermediate_features=False,
            norm_cfg=dict(type='BN'),
            init_cfg=None):
        super(RepLKNetStage, self).__init__(init_cfg)
        self.with_cp = with_cp

        # Tuples are accepted alongside lists; previously a tuple was
        # mistaken for a scalar rate.
        per_block_rates = isinstance(drop_path, (list, tuple))

        blks = []
        for i in range(num_blocks):
            block_drop_path = drop_path[i] if per_block_rates else drop_path
            # Assume all RepLK Blocks within a stage share the same lk_size.
            # You may tune it on your own model.
            replk_block = RepLKBlock(
                in_channels=channels,
                dw_channels=int(channels * dw_ratio),
                block_lk_size=stage_lk_size,
                small_kernel=small_kernel,
                drop_path=block_drop_path,
                small_kernel_merged=small_kernel_merged)
            convffn_block = ConvFFN(
                in_channels=channels,
                internal_channels=int(channels * ffn_ratio),
                out_channels=channels,
                drop_path=block_drop_path)
            blks.append(replk_block)
            blks.append(convffn_block)
        self.blocks = nn.ModuleList(blks)

        if norm_intermediate_features:
            self.norm = build_norm_layer(norm_cfg, channels)[1]
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Run every block of the stage in order (``self.norm`` excluded)."""
        for blk in self.blocks:
            # Checkpoint only when the input takes part in autograd. With
            # reentrant checkpointing, an input that does not require grad
            # (e.g. frozen preceding layers, or inference) would leave the
            # checkpointed block without gradients.
            if self.with_cp and x.requires_grad:
                x = checkpoint.checkpoint(blk, x)  # Save training memory
            else:
                x = blk(x)
        return x
@MODELS.register_module()
class RepLKNet(BaseBackbone):
    """RepLKNet backbone.

    A PyTorch impl of :
    `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
    <https://arxiv.org/abs/2203.06717>`_

    Args:
        arch (str | dict): The parameter of RepLKNet. A string selects one
            of ``arch_settings`` ('31B', '31L', 'XL').
            If it's a dict, it should contain the following keys:

            - large_kernel_sizes (Sequence[int]):
              Large kernel size in each stage.
            - layers (Sequence[int]): Number of blocks in each stage.
            - channels (Sequence[int]): Number of channels in each stage.
            - small_kernel (int | None): size of the parallel small kernel.
            - dw_ratio (float): The intermediate channels
              expansion ratio of the block.
        in_channels (int): Number of input image channels. Defaults to 3.
        ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to (3, ).
        strides (Sequence[int]): Only its length is validated against the
            number of stages; the value is stored but not used to build
            layers here. Defaults to (2, 2, 2, 2).
        dilations (Sequence[int]): Only its length is validated against the
            number of stages; the value is stored but not used to build
            layers here. Defaults to (1, 1, 1, 1).
        frozen_stages (int): Stages to be frozen (all param fixed). Any
            value >= 0 also freezes the stem. -1 means not freezing any
            parameters. Defaults to -1.
        conv_cfg (dict | None): Stored on the instance; not used to build
            layers here. Defaults to None.
        norm_cfg (dict): Stored on the instance; not used to build layers
            here. Defaults to ``dict(type='BN')``.
        act_cfg (dict): Stored on the instance; not used to build layers
            here. Defaults to ``dict(type='ReLU')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Defaults to False.
        drop_path_rate (float): Maximum stochastic depth rate. Block rates
            increase linearly from 0 to this value over all blocks.
            Defaults to 0.3.
        small_kernel_merged (bool): Whether to build the model in deployment
            structure (small kernels merged into the large ones).
            Defaults to False.
        norm_intermediate_features (bool): Stored on the instance. The
            output norm of a stage is created for every stage listed in
            ``out_indices``, independent of this flag. Defaults to False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    arch_settings = {
        '31B':
        dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[128, 256, 512, 1024],
            small_kernel=5,
            dw_ratio=1),
        '31L':
        dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[192, 384, 768, 1536],
            small_kernel=5,
            dw_ratio=1),
        'XL':
        dict(
            large_kernel_sizes=[27, 27, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[256, 512, 1024, 2048],
            small_kernel=None,
            dw_ratio=1.5),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 ffn_ratio=4,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 drop_path_rate=0.3,
                 small_kernel_merged=False,
                 norm_intermediate_features=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepLKNet, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['layers']) == len(
            arch['channels']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['layers'])

        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.drop_path_rate = drop_path_rate
        self.small_kernel_merged = small_kernel_merged
        self.norm_eval = norm_eval
        self.norm_intermediate_features = norm_intermediate_features

        depths = self.arch['layers']
        widths = self.arch['channels']
        base_width = widths[0]
        self.num_stages = len(depths)

        # Stem: two stride-2 convolutions, with a per-channel 3x3 and a
        # pointwise 1x1 in between.
        self.stem = nn.ModuleList([
            conv_bn_relu(
                in_channels=in_channels,
                out_channels=base_width,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=1),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=base_width),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=base_width)
        ])

        # stochastic depth. We set block-wise drop-path rate.
        # The higher level blocks are more likely to be dropped.
        # This implementation follows Swin.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()
        first_block = 0  # index into dpr of the stage's first block
        for stage_idx in range(self.num_stages):
            last_block = first_block + depths[stage_idx]
            layer = RepLKNetStage(
                channels=widths[stage_idx],
                num_blocks=depths[stage_idx],
                stage_lk_size=self.arch['large_kernel_sizes'][stage_idx],
                drop_path=dpr[first_block:last_block],
                small_kernel=self.arch['small_kernel'],
                dw_ratio=self.arch['dw_ratio'],
                ffn_ratio=ffn_ratio,
                with_cp=with_cp,
                small_kernel_merged=small_kernel_merged,
                norm_intermediate_features=(stage_idx in out_indices))
            self.stages.append(layer)
            first_block = last_block

            if stage_idx < self.num_stages - 1:
                next_width = widths[stage_idx + 1]
                # Transition: pointwise channel change, then a stride-2
                # per-channel 3x3 for downsampling.
                transition = nn.Sequential(
                    conv_bn_relu(
                        in_channels=widths[stage_idx],
                        out_channels=next_width,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1),
                    conv_bn_relu(
                        in_channels=next_width,
                        out_channels=next_width,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=next_width))
                self.transitions.append(transition)

    def forward_features(self, x):
        """Return the list of (normalized) outputs of the selected stages."""
        # The first stem layer is never checkpointed.
        x = self.stem[0](x)
        for stem_layer in self.stem[1:]:
            # Checkpoint only when the input takes part in autograd. With
            # reentrant checkpointing, an input that does not require grad
            # (frozen stem, inference) would leave the layer without
            # gradients.
            if self.with_cp and x.requires_grad:
                x = checkpoint.checkpoint(stem_layer, x)  # save memory
            else:
                x = stem_layer(x)

        # Need the intermediate feature maps
        outs = []
        for stage_idx in range(self.num_stages):
            x = self.stages[stage_idx](x)
            if stage_idx in self.out_indices:
                # For RepLKNet-XL normalize the features
                # before feeding them into the heads
                outs.append(self.stages[stage_idx].norm(x))
            if stage_idx < self.num_stages - 1:
                x = self.transitions[stage_idx](x)
        return outs

    def forward(self, x):
        """Return the selected stage outputs as a tuple."""
        x = self.forward_features(x)
        return tuple(x)

    def _freeze_stages(self):
        """Put the stem and the first ``frozen_stages`` stages in eval mode
        and disable gradients for their parameters."""
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode, keeping frozen parts and (optionally) all
        batch-norm layers in eval mode.

        Returns:
            RepLKNet: ``self``, matching the ``nn.Module.train`` contract so
            that ``model.train()`` / ``model.eval()`` can be chained or
            assigned.
        """
        super(RepLKNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def switch_to_deploy(self):
        """Merge the kernels of every submodule exposing ``merge_kernel``."""
        for m in self.modules():
            if hasattr(m, 'merge_kernel'):
                m.merge_kernel()
        self.small_kernel_merged = True