# Scraped page residue (not part of the source): "Spaces: Sleeping Sleeping"
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmcv.ops import SAConv2d
def test_sacconv():
    """Check that ``SAConv2d`` output shapes match an equivalent ``nn.Conv2d``.

    Four configurations are exercised: a plain 3x3 convolution, a dilated
    convolution (``dilation=2``), the deformable variant (``use_deform=True``)
    and a grouped convolution (``groups=2``). For each one, an ``nn.Conv2d``
    built with the same channel/kernel/padding/dilation/groups arguments is
    the reference. Only output shapes are compared, not values (both modules
    are randomly initialised).
    """
    # test with normal case
    x = torch.rand(1, 3, 256, 256)
    saconv = SAConv2d(3, 5, kernel_size=3, padding=1)
    sac_out = saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1)
    refer_out = refer_conv(x)
    assert sac_out.shape == refer_out.shape

    # test with dilation >= 2
    dilated_saconv = SAConv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    dilated_sac_out = dilated_saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    refer_out = refer_conv(x)
    assert dilated_sac_out.shape == refer_out.shape

    # test with deform: run on the GPU when one is available, otherwise on
    # the CPU. The module and input are placed on the chosen device once,
    # rather than building a throwaway CPU module first and rebuilding it.
    # NOTE(review): the CPU path assumes the installed mmcv build provides a
    # CPU implementation of the deformable op — confirm for your build.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = x.to(device)
    deform_saconv = SAConv2d(
        3, 5, kernel_size=3, padding=1, use_deform=True).to(device)
    deform_sac_out = deform_saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device)
    refer_out = refer_conv(x)
    assert deform_sac_out.shape == refer_out.shape

    # test with groups >= 2 (channel counts must be divisible by groups)
    x = torch.rand(1, 4, 256, 256)
    group_saconv = SAConv2d(4, 4, kernel_size=3, padding=1, groups=2)
    group_sac_out = group_saconv(x)
    refer_conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=2)
    refer_out = refer_conv(x)
    assert group_sac_out.shape == refer_out.shape