# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmcv.ops import SAConv2d


def _assert_same_shape_as_conv2d(x, in_channels, out_channels, **conv_kwargs):
    """Assert that SAConv2d and nn.Conv2d produce equally shaped outputs.

    Both layers are built with the same channel counts and ``conv_kwargs``,
    moved to the device of ``x`` and applied to ``x``.

    Args:
        x (torch.Tensor): Input fed to both layers.
        in_channels (int): Number of input channels of both layers.
        out_channels (int): Number of output channels of both layers.
        **conv_kwargs: Keyword arguments shared by both constructors
            (e.g. ``kernel_size``, ``padding``, ``dilation``, ``groups``).
            ``use_deform`` is accepted as well, but is forwarded to
            ``SAConv2d`` only because ``nn.Conv2d`` has no such argument.

    Raises:
        AssertionError: If the two output shapes differ.
    """
    # Strip the SAConv2d-only switch before building the reference layer.
    refer_kwargs = {
        key: value
        for key, value in conv_kwargs.items() if key != 'use_deform'
    }
    saconv = SAConv2d(in_channels, out_channels, **conv_kwargs).to(x.device)
    refer_conv = nn.Conv2d(in_channels, out_channels,
                           **refer_kwargs).to(x.device)
    assert saconv(x).shape == refer_conv(x).shape


def test_sacconv():
    """Compare SAConv2d output shapes against nn.Conv2d in four setups.

    Covers the plain configuration, ``dilation=2``, ``use_deform=True`` and
    ``groups=2``. Only output shapes are compared, not values.
    """
    # test with normal case
    x = torch.rand(1, 3, 256, 256)
    _assert_same_shape_as_conv2d(x, 3, 5, kernel_size=3, padding=1)

    # test with dilation >= 2
    _assert_same_shape_as_conv2d(
        x, 3, 5, kernel_size=3, padding=2, dilation=2)

    # test with deform: run on CUDA when it is available, otherwise on CPU.
    # The layer is built once, directly for the device that will be used.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    deform_x = torch.rand(1, 3, 256, 256, device=device)
    _assert_same_shape_as_conv2d(
        deform_x, 3, 5, kernel_size=3, padding=1, use_deform=True)

    # test with groups >= 2
    x = torch.rand(1, 4, 256, 256)
    _assert_same_shape_as_conv2d(x, 4, 4, kernel_size=3, padding=1, groups=2)