AiOS / mmcv /tests /test_ops /test_saconv.py
ttxskk
update
d7e58f0
raw
history blame
1.72 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.ops import SAConv2d
def _assert_same_shape_as_conv2d(x, in_channels, out_channels,
                                 use_deform=False, **conv_kwargs):
    """Assert that SAConv2d and nn.Conv2d produce equally shaped outputs.

    Both layers are built from the same channel counts and ``conv_kwargs``
    (e.g. ``kernel_size``, ``padding``, ``dilation``, ``groups``) and are
    moved to ``x.device`` before the forward pass, so the caller controls
    the device by choosing where ``x`` lives.

    Args:
        x (torch.Tensor): Input fed to both layers.
        in_channels (int): Input channels for both layers.
        out_channels (int): Output channels for both layers.
        use_deform (bool): Forwarded to ``SAConv2d`` only; ``nn.Conv2d``
            has no such option.
        **conv_kwargs: Forwarded to both ``SAConv2d`` and ``nn.Conv2d``.
    """
    device = x.device
    saconv = SAConv2d(
        in_channels, out_channels, use_deform=use_deform,
        **conv_kwargs).to(device)
    refer_conv = nn.Conv2d(
        in_channels, out_channels, **conv_kwargs).to(device)
    assert saconv(x).shape == refer_conv(x).shape


def test_sacconv():
    """Check SAConv2d output shapes against an equivalent nn.Conv2d.

    Covers four cases: plain convolution, dilation >= 2, deformable
    convolution, and grouped convolution. The deformable case runs on CUDA
    when it is available and on the CPU otherwise.
    """
    x = torch.rand(1, 3, 256, 256)

    # test with normal cast
    _assert_same_shape_as_conv2d(x, 3, 5, kernel_size=3, padding=1)

    # test with dilation >= 2
    _assert_same_shape_as_conv2d(
        x, 3, 5, kernel_size=3, padding=2, dilation=2)

    # test with deform; prefer the GPU path when one is present.
    # The layers are built directly on x's device, so no throwaway CPU
    # module is created.
    deform_x = x.cuda() if torch.cuda.is_available() else x
    _assert_same_shape_as_conv2d(
        deform_x, 3, 5, use_deform=True, kernel_size=3, padding=1)

    # test with groups >= 2 (channel counts must be divisible by groups)
    group_x = torch.rand(1, 4, 256, 256)
    _assert_same_shape_as_conv2d(
        group_x, 4, 4, kernel_size=3, padding=1, groups=2)