Spaces:
Sleeping
Sleeping
File size: 1,724 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.ops import SAConv2d
def _assert_matches_conv2d(x, in_channels, out_channels, use_deform=False,
                           **conv_kwargs):
    """Assert ``SAConv2d`` and ``nn.Conv2d`` give same-shaped outputs on ``x``.

    Args:
        x (torch.Tensor): Input fed to both modules. Both modules are moved
            to ``x.device`` before the forward pass, so the check runs on
            whatever device ``x`` lives on.
        in_channels (int): Input channels passed to both modules.
        out_channels (int): Output channels passed to both modules.
        use_deform (bool): Forwarded to ``SAConv2d`` only; the reference
            ``nn.Conv2d`` has no deformable variant.
        **conv_kwargs: Convolution arguments shared by both modules
            (e.g. ``kernel_size``, ``padding``, ``dilation``, ``groups``).
    """
    saconv = SAConv2d(
        in_channels, out_channels, use_deform=use_deform,
        **conv_kwargs).to(x.device)
    refer_conv = nn.Conv2d(in_channels, out_channels,
                           **conv_kwargs).to(x.device)
    sac_out = saconv(x)
    refer_out = refer_conv(x)
    assert sac_out.shape == refer_out.shape


def test_sacconv():
    """Check ``SAConv2d`` output shapes against an equivalent ``nn.Conv2d``."""
    # test with normal cast
    x = torch.rand(1, 3, 256, 256)
    _assert_matches_conv2d(x, 3, 5, kernel_size=3, padding=1)

    # test with dilation >= 2
    _assert_matches_conv2d(x, 3, 5, kernel_size=3, padding=2, dilation=2)

    # test with deform: run on GPU when one is available, otherwise on CPU.
    # The module is constructed once, directly for the chosen device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.rand(1, 3, 256, 256, device=device)
    _assert_matches_conv2d(
        x, 3, 5, use_deform=True, kernel_size=3, padding=1)

    # test with groups >= 2
    x = torch.rand(1, 4, 256, 256)
    _assert_matches_conv2d(x, 4, 4, kernel_size=3, padding=1, groups=2)
|