|
|
|
from torch import nn |
|
from torch.autograd import Function |
|
from torch.nn.modules.utils import _pair |
|
|
|
from ..utils import ext_loader |
|
|
|
# Handle to the '_ext' extension, loaded with the two ops this file calls:
# `psamask_forward` (used in PSAMaskFunction.forward) and `psamask_backward`
# (used in PSAMaskFunction.backward).
ext_module = ext_loader.load_ext('_ext',
                                 ['psamask_forward', 'psamask_backward'])
|
|
|
|
|
class PSAMaskFunction(Function):
    """Autograd function around the ``psamask_forward``/``psamask_backward`` ops.

    ``forward`` takes a 4-D tensor of size
    ``(batch_size, h_mask * w_mask, h_feature, w_feature)`` and returns a
    tensor of size ``(batch_size, h_feature * w_feature, h_feature,
    w_feature)``. The actual computation is delegated to ``ext_module``.
    """

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
        """Emit the ONNX node for this op.

        Creates a ``mmcv::MMCVPSAMask`` node with ``input`` as its only
        tensor input and ``psa_type`` / ``mask_size`` attached as integer
        attributes (the ``_i`` suffix).
        """
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input, psa_type, mask_size):
        """Run the forward op.

        Args:
            ctx: Autograd context; receives ``psa_type``, the paired
                ``mask_size`` and the saved ``input`` for ``backward``.
            input: 4-D tensor whose channel dimension must equal
                ``h_mask * w_mask``.
            psa_type: Mode selector passed through to the extension op
                (``PSAMask`` passes 0 for 'collect' and 1 for 'distribute').
            mask_size: An int or an ``(h_mask, w_mask)`` pair; normalised
                with ``_pair``.

        Returns:
            A tensor of size ``(batch_size, h_feature * w_feature,
            h_feature, w_feature)`` created with ``input.new_zeros`` and
            handed to ``psamask_forward`` as the output buffer.

        Raises:
            AssertionError: If the channel count differs from
                ``h_mask * w_mask``.
        """
        ctx.psa_type = psa_type
        ctx.mask_size = _pair(mask_size)
        ctx.save_for_backward(input)

        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        assert channels == h_mask * w_mask, (
            f'input has {channels} channels, but mask_size '
            f'{h_mask}x{w_mask} requires {h_mask * w_mask}')
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))

        # The op receives the pre-allocated buffer; it is expected to write
        # its result into `output` in place.
        ext_module.psamask_forward(
            input,
            output,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward op.

        Args:
            ctx: Context populated by ``forward``.
            grad_output: Gradient with respect to the forward output.

        Returns:
            A 3-tuple with one entry per ``forward`` input: the gradient
            for ``input`` followed by ``None`` for the non-differentiable
            ``psa_type`` and ``mask_size``.
        """
        input = ctx.saved_tensors[0]
        psa_type = ctx.psa_type
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        grad_input = grad_output.new_zeros(
            (batch_size, channels, h_feature, w_feature))
        # As in forward, `grad_input` is the buffer handed to the op.
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        # forward has exactly three inputs after ctx, so return exactly
        # three gradients (the original returned a surplus fourth None).
        return grad_input, None, None
|
|
|
|
|
# Functional entry point: `psa_mask(input, psa_type, mask_size)` runs
# PSAMaskFunction through autograd's `Function.apply`.
psa_mask = PSAMaskFunction.apply
|
|
|
|
|
class PSAMask(nn.Module):
    """Module wrapper that applies ``psa_mask`` with a fixed configuration.

    Args:
        psa_type (str): Either 'collect' or 'distribute'. Stored as given
            and also translated to the integer code passed to ``psa_mask``
            (0 for 'collect', 1 for 'distribute').
        mask_size: Value forwarded unchanged to ``psa_mask`` on every call.
            NOTE(review): the default is ``None``, which the forward op
            does not appear able to handle -- callers presumably always
            pass a size explicitly; confirm.
    """

    def __init__(self, psa_type, mask_size=None):
        super().__init__()
        assert psa_type in ['collect', 'distribute']
        # Integer code handed to the op in place of the string name.
        self.psa_type_enum = 0 if psa_type == 'collect' else 1
        self.mask_size = mask_size
        self.psa_type = psa_type

    def forward(self, input):
        """Apply ``psa_mask`` to ``input`` using the stored settings."""
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        """Return ``ClassName(psa_type=..., mask_size=...)``."""
        cls_name = self.__class__.__name__
        return (f'{cls_name}(psa_type={self.psa_type}, '
                f'mask_size={self.mask_size})')
|
|