|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
import torch.utils.checkpoint as checkpoint |
|
from se_block import SEBlock |
|
import torch |
|
import numpy as np |
|
|
|
|
|
def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    """Build a ``conv -> bn -> relu`` sequence.

    The returned ``nn.Sequential`` exposes its layers under the names
    ``conv`` (bias-free ``nn.Conv2d``), ``bn`` (``nn.BatchNorm2d`` over
    ``out_channels``) and ``relu`` (``nn.ReLU``).
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, groups=groups, bias=False)
    named_layers = (
        ('conv', conv),
        ('bn', nn.BatchNorm2d(out_channels)),
        ('relu', nn.ReLU()),
    )
    block = nn.Sequential()
    for layer_name, layer in named_layers:
        block.add_module(layer_name, layer)
    return block
|
|
|
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    """Build a ``conv -> bn`` sequence without an activation.

    The returned ``nn.Sequential`` exposes a bias-free ``nn.Conv2d`` as
    ``conv`` and an ``nn.BatchNorm2d`` over ``out_channels`` as ``bn``.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, groups=groups, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    block = nn.Sequential()
    for layer_name, layer in (('conv', conv), ('bn', norm)):
        block.add_module(layer_name, layer)
    return block
|
|
|
class RepVGGplusBlock(nn.Module):
    """Multi-branch 3x3 block that can be fused into a single convolution.

    In training form (``deploy=False``) the block sums three branches:
    a 3x3 conv+BN (``rbr_dense``), a 1x1 conv+BN (``rbr_1x1``) and, only
    when ``in_channels == out_channels`` and ``stride == 1``, a plain BN
    (``rbr_identity``). The sum goes through ReLU and then ``post_se``
    (an ``SEBlock`` when ``use_post_se`` is true, otherwise identity).

    In deploy form the block holds a single biased conv ``rbr_reparam``.
    ``switch_to_deploy`` converts a training-form block in place.

    Only ``kernel_size == 3`` with ``padding == 1`` is accepted.
    ``dilation`` and ``padding_mode`` are only forwarded to the deploy-form
    convolution created in ``__init__``; the training branches do not
    receive them.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros',
                 deploy=False,
                 use_post_se=False):
        super().__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        # The fusion helpers below hard-code a 3x3 kernel with a centred 1x1.
        assert kernel_size == 3
        assert padding == 1

        self.nonlinearity = nn.ReLU()

        if use_post_se:
            self.post_se = SEBlock(out_channels, internal_neurons=out_channels // 4)
        else:
            self.post_se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)
        else:
            if out_channels == in_channels and stride == 1:
                self.rbr_identity = nn.BatchNorm2d(num_features=out_channels)
            else:
                self.rbr_identity = None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=kernel_size, stride=stride,
                                     padding=padding, groups=groups)
            # Padding for the 1x1 branch so its output aligns with the 3x3 branch.
            padding_11 = padding - kernel_size // 2
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels,
                                   kernel_size=1, stride=stride,
                                   padding=padding_11, groups=groups)

    def forward(self, x, L2):
        """Run the block and accumulate its weight-regularisation term.

        Args:
            x: input feature map.
            L2: running regularisation total from earlier blocks. It is
                ignored in deploy form.

        Returns:
            ``(out, L2_new)``. In deploy form ``L2_new`` is ``None``.
            Otherwise it is ``L2`` plus two terms computed from this block's
            conv weights: the squared 3x3 weights outside the centre tap,
            and the squared BN-scaled centre kernel (3x3 centre plus 1x1)
            divided by the sum of the squared BN scales.
        """
        if self.deploy:
            return self.post_se(self.nonlinearity(self.rbr_reparam(x))), None

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(x)
        out = self.rbr_dense(x) + self.rbr_1x1(x) + id_out
        out = self.post_se(self.nonlinearity(out))

        # Per-output-channel BN scale (gamma / std). Detached so the
        # regulariser only produces gradients for the conv kernels.
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight

        # Everything in the 3x3 kernel except its centre tap.
        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()
        # Centre tap of the fused kernel: 3x3 centre and 1x1, each BN-scaled.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()

        return out, L2 + l2_loss_circle + l2_loss_eq_kernel

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of the single 3x3 conv equal to all branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3; ``None`` maps to ``0``."""
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _identity_kernel(self, device):
        """Return the 3x3 kernel that reproduces an identity mapping.

        The kernel is cached on ``self.id_tensor``. Because that attribute is
        a plain tensor (not a parameter or buffer), it is not moved when the
        module is moved. The cache is therefore rebuilt whenever its device
        differs from ``device``.
        """
        cached = getattr(self, 'id_tensor', None)
        if cached is None or cached.device != device:
            input_dim = self.in_channels // self.groups
            kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
            channels = np.arange(self.in_channels)
            # Output channel i reads input slot (i % input_dim) of its group
            # at the centre tap.
            kernel_value[channels, channels % input_dim, 1, 1] = 1
            cached = torch.from_numpy(kernel_value).to(device)
            self.id_tensor = cached
        return cached

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNorm statistics into a conv kernel and bias.

        Args:
            branch: ``None``, an ``nn.Sequential`` with ``conv`` and ``bn``
                children, or a bare ``nn.BatchNorm2d`` (identity branch).

        Returns:
            ``(kernel, bias)``; ``(0, 0)`` when ``branch`` is ``None``.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            bn = branch.bn
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            kernel = self._identity_kernel(branch.weight.device)
            bn = branch
        running_mean, running_var = bn.running_mean, bn.running_var
        gamma, beta, eps = bn.weight, bn.bias, bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        """Replace the training branches with one fused conv, in place.

        Does nothing if ``rbr_reparam`` already exists.
        """
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        dense_conv = self.rbr_dense.conv
        self.rbr_reparam = nn.Conv2d(in_channels=dense_conv.in_channels,
                                     out_channels=dense_conv.out_channels,
                                     kernel_size=dense_conv.kernel_size, stride=dense_conv.stride,
                                     padding=dense_conv.padding, dilation=dense_conv.dilation,
                                     groups=dense_conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True
|
|
|
|
|
|
|
class RepVGGplusStage(nn.Module):
    """A chain of ``RepVGGplusBlock`` modules run one after another.

    The first block uses the requested ``stride`` and maps ``in_planes`` to
    ``planes``; every later block uses stride 1 and keeps ``planes``
    channels. When ``use_checkpoint`` is true each block is executed through
    ``torch.utils.checkpoint``.
    """

    def __init__(self, in_planes, planes, num_blocks, stride, use_checkpoint, use_post_se=False, deploy=False):
        super().__init__()
        self.in_planes = in_planes
        self.use_checkpoint = use_checkpoint

        stage_blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage_blocks.append(
                RepVGGplusBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                stride=block_stride, padding=1, groups=1,
                                deploy=deploy, use_post_se=use_post_se)
            )
            # Blocks after the first consume the stage's output width.
            self.in_planes = planes
        self.blocks = nn.ModuleList(stage_blocks)

    def forward(self, x, L2):
        """Thread ``(x, L2)`` through every block and return the final pair."""
        for block in self.blocks:
            x, L2 = checkpoint.checkpoint(block, x, L2) if self.use_checkpoint else block(x, L2)
        return x, L2
|
|
|
|
|
class RepVGGplus(nn.Module):
    """Full classifier: a stem block, five stages, pooling and a linear head.

    In training form (``deploy=False``) three auxiliary classifiers are
    attached after ``stage1``, ``stage2`` and ``stage3_first``, and
    ``forward`` returns a dict of logits plus the accumulated ``L2`` term.
    In deploy form ``forward`` returns only the main logits.

    ``num_blocks`` and ``width_multiplier`` are indexed 0..3. Stage 3 is
    split into two halves of ``num_blocks[2] // 2`` blocks each, so an odd
    ``num_blocks[2]`` yields one block fewer in total.
    ``override_groups_map`` is stored on the instance; the stages built here
    always use ``groups=1``.
    """

    def __init__(self, num_blocks, num_classes,
                 width_multiplier, override_groups_map=None,
                 deploy=False,
                 use_post_se=False,
                 use_checkpoint=False):
        super().__init__()

        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        self.use_post_se = use_post_se
        self.use_checkpoint = use_checkpoint
        self.num_classes = num_classes
        self.nonlinear = 'relu'

        # Channel widths of the four stage groups.
        width1 = int(64 * width_multiplier[0])
        width2 = int(128 * width_multiplier[1])
        width3 = int(256 * width_multiplier[2])
        width4 = int(512 * width_multiplier[3])
        half_stage3 = num_blocks[2] // 2
        shared = dict(use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)

        # The stem is capped at 64 channels regardless of the multiplier.
        self.in_planes = min(64, width1)
        self.stage0 = RepVGGplusBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3,
                                      stride=2, padding=1, deploy=self.deploy, use_post_se=use_post_se)
        self.cur_layer_idx = 1
        self.stage1 = RepVGGplusStage(self.in_planes, width1, num_blocks[0], stride=2, **shared)
        self.stage2 = RepVGGplusStage(width1, width2, num_blocks[1], stride=2, **shared)
        self.stage3_first = RepVGGplusStage(width2, width3, half_stage3, stride=2, **shared)
        self.stage3_second = RepVGGplusStage(width3, width3, half_stage3, stride=1, **shared)
        self.stage4 = RepVGGplusStage(width3, width4, num_blocks[3], stride=2, **shared)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(width4, num_classes)

        if not self.deploy:
            self.stage1_aux = self._build_aux_for_stage(self.stage1)
            self.stage2_aux = self._build_aux_for_stage(self.stage2)
            self.stage3_first_aux = self._build_aux_for_stage(self.stage3_first)

    def _build_aux_for_stage(self, stage):
        """Create an auxiliary classifier sized to ``stage``'s output channels.

        The head is: stride-2 3x3 conv-bn-relu, global average pool,
        flatten, linear to ``num_classes``. It reads the channel count from
        the last block's ``rbr_dense`` conv, so it needs a training-form
        stage.
        """
        last_block = list(stage.blocks.children())[-1]
        channels = last_block.rbr_dense.conv.out_channels
        reduce = conv_bn_relu(in_channels=channels, out_channels=channels,
                              kernel_size=3, stride=2, padding=1)
        classifier = nn.Linear(channels, self.num_classes, bias=True)
        return nn.Sequential(reduce, nn.AdaptiveAvgPool2d(1), nn.Flatten(), classifier)

    def forward(self, x):
        """Return main logits (deploy) or a dict of logits and ``L2`` (training).

        The training dict has the keys ``'main'``, ``'stage1_aux'``,
        ``'stage2_aux'``, ``'stage3_first_aux'`` and ``'L2'``.
        """
        if self.deploy:
            feat = x
            pipeline = (self.stage0, self.stage1, self.stage2,
                        self.stage3_first, self.stage3_second, self.stage4)
            for stage in pipeline:
                feat, _ = stage(feat, L2=None)
            pooled = self.gap(feat)
            return self.linear(pooled.view(pooled.size(0), -1))

        feat, reg = self.stage0(x, L2=0.0)
        # 'main' is reserved first so the key order of the result stays fixed.
        outputs = {'main': None}
        supervised = (
            ('stage1_aux', self.stage1, self.stage1_aux),
            ('stage2_aux', self.stage2, self.stage2_aux),
            ('stage3_first_aux', self.stage3_first, self.stage3_first_aux),
        )
        for key, stage, head in supervised:
            feat, reg = stage(feat, L2=reg)
            outputs[key] = head(feat)
        for stage in (self.stage3_second, self.stage4):
            feat, reg = stage(feat, L2=reg)
        pooled = self.gap(feat)
        outputs['main'] = self.linear(pooled.view(pooled.size(0), -1))
        outputs['L2'] = reg
        return outputs

    def switch_repvggplus_to_deploy(self):
        """Fuse every block, disable checkpointing and drop the aux heads."""
        for m in self.modules():
            if hasattr(m, 'switch_to_deploy'):
                m.switch_to_deploy()
            if hasattr(m, 'use_checkpoint'):
                m.use_checkpoint = False
        for aux_name in ('stage1_aux', 'stage2_aux', 'stage3_first_aux'):
            if hasattr(self, aux_name):
                self.__delattr__(aux_name)
        self.deploy = True
|
|
|
|
|
|
|
|
|
|
|
def create_RepVGGplus_L2pse(deploy=False, use_checkpoint=False):
    """Build the ``RepVGGplus-L2pse`` configuration.

    Fixed settings: block counts ``[8, 14, 24, 1]``, width multipliers
    ``[2.5, 2.5, 2.5, 5]``, 1000 classes and post-SE enabled. ``deploy``
    and ``use_checkpoint`` are passed through to ``RepVGGplus``.
    """
    config = dict(
        num_blocks=[8, 14, 24, 1],
        num_classes=1000,
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=None,
        use_post_se=True,
    )
    return RepVGGplus(deploy=deploy, use_checkpoint=use_checkpoint, **config)
|
|
|
# Registry mapping a model name to the factory function that builds it.
repvggplus_func_dict = {'RepVGGplus-L2pse': create_RepVGGplus_L2pse}
|
def get_RepVGGplus_func_by_name(name):
    """Look up the factory registered under ``name``.

    Raises ``KeyError`` when ``name`` is not in ``repvggplus_func_dict``.
    """
    factory = repvggplus_func_dict[name]
    return factory