|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from annotator.mmpkg.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init |
|
from annotator.mmpkg.mmcv.ops.deform_conv import deform_conv2d |
|
from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version |
|
|
|
|
|
@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution).

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The forward pass works in four stages:

    1. *Pre-context*: a globally average-pooled copy of the input is passed
       through a 1x1 convolution and added back to the input.
    2. *Switch*: a 5x5 average-pooled copy of the result is passed through a
       1x1 convolution to produce a single-channel switch map.
    3. *SAC*: the convolution is evaluated twice, once with the configured
       padding/dilation and the weight returned by ``self._get_weight`` and
       once with 3x padding/dilation and that weight plus ``weight_diff``.
       The two outputs are blended as
       ``switch * out_s + (1 - switch) * out_l``.
    4. *Post-context*: same as the pre-context, applied to the output.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform: If ``True``, replace convolution with deformable
            convolution. Default: ``False``.

    Note:
        ``padding_mode`` is not a constructor argument of this layer; it is
        not forwarded to the parent class.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv producing the single-channel switch map. It uses the same
        # stride as the main convolution.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive weight delta used by the large-rate branch. Allocated
        # uninitialised here and zeroed in ``init_weights``.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # NOTE(review): 18 offset channels presumably correspond to
            # 2 * 3 * 3, i.e. a 3x3 kernel with one deformable group --
            # confirm before using ``use_deform`` with another kernel size.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the SAC-specific parameters with constants.

        The switch conv is initialised with value 0 and bias 1, the weight
        delta is zeroed, and the context convs (plus the offset convs when
        ``use_deform`` is set) are initialised with value 0.
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _conv_branch(self, x, avg_x, weight, zero_bias, offset_conv):
        """Run one SAC branch with the module's current padding/dilation.

        Args:
            x: Input of the convolution.
            avg_x: Locally averaged input, used only to predict the offsets
                when ``use_deform`` is set.
            weight: Convolution weight for this branch.
            zero_bias: All-zero bias passed to ``_conv_forward`` on
                torch >= 1.8.0.
            offset_conv: Offset-predicting conv for this branch, or ``None``
                when ``use_deform`` is not set.

        Returns:
            The branch output.
        """
        if self.use_deform:
            offset = offset_conv(avg_x)
            return deform_conv2d(x, offset, weight, self.stride, self.padding,
                                 self.dilation, self.groups, 1)
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is passed explicitly to _conv_forward on torch >= 1.8.0.
            # NOTE(review): a zero bias is passed here, so ``self.bias`` is
            # not applied on this path -- confirm this is intended.
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x: Input tensor. It is reflect-padded by 2 on the last two
                dimensions, so those dimensions must be large enough for
                reflect padding of that size.

        Returns:
            The blended output of the small-rate and large-rate branches with
            the post-context term added.
        """
        # pre-context
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x

        # switch: 5x5 local average (reflect-padded so the size is kept)
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)

        # sac
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)

        # Small-rate branch: configured padding/dilation.
        out_s = self._conv_branch(
            x, avg_x, weight, zero_bias,
            self.offset_s if self.use_deform else None)

        # Large-rate branch: the parent's conv helpers read padding/dilation
        # from the module, so they are temporarily tripled. ``finally``
        # guarantees they are restored even if the convolution raises;
        # otherwise a single failure would leave the module permanently
        # misconfigured.
        ori_p = self.padding
        ori_d = self.dilation
        try:
            self.padding = tuple(3 * p for p in ori_p)
            self.dilation = tuple(3 * d for d in ori_d)
            weight = weight + self.weight_diff
            out_l = self._conv_branch(
                x, avg_x, weight, zero_bias,
                self.offset_l if self.use_deform else None)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l

        # post-context
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
|
|