# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import DepthwiseSeparableConvModule


def test_depthwise_separable_conv():
    """Exercise the configuration surface of DepthwiseSeparableConvModule.

    Covers: rejection of an explicit ``groups`` argument, the default
    configuration, per-branch and shared norm configs (``dw_norm_cfg``,
    ``pw_norm_cfg``, ``norm_cfg``), a custom layer ``order``, spectral
    norm, a non-default ``padding_mode``, and per-branch and shared
    activation configs (``dw_act_cfg``, ``pw_act_cfg``, ``act_cfg``).

    Every case also runs a forward pass on a ``(1, 3, 256, 256)`` input and
    checks the output shape.
    """
    with pytest.raises(AssertionError):
        # Passing ``groups`` explicitly is rejected; the default-config check
        # below shows the depthwise conv ends up with groups == in_channels.
        DepthwiseSeparableConvModule(4, 8, 2, groups=2)

    # test default config
    conv = DepthwiseSeparableConvModule(3, 8, 2)
    assert conv.depthwise_conv.conv.groups == 3
    assert conv.pointwise_conv.conv.kernel_size == (1, 1)
    assert not conv.depthwise_conv.with_norm
    assert not conv.pointwise_conv.with_norm
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    # kernel_size=2 with no padding shrinks each spatial dim: 256 -> 255
    assert output.shape == (1, 8, 255, 255)

    # test dw_norm_cfg: only the depthwise branch gets a norm layer
    conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert not conv.pointwise_conv.with_norm
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test pw_norm_cfg: only the pointwise branch gets a norm layer
    conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
    assert not conv.depthwise_conv.with_norm
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test norm_cfg: the shared config applies to both branches
    conv = DepthwiseSeparableConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # add test for ['norm', 'conv', 'act']
    conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act'))
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # The remaining cases reuse ``x`` from above and use kernel_size=3 with
    # padding=1, which preserves the 256x256 spatial size.

    # test with_spectral_norm: both convs should expose ``weight_orig``
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
    assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test padding_mode: 'reflect' yields an explicit ReflectionPad2d layer
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test dw_act_cfg: only the depthwise activation changes
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test pw_act_cfg: only the pointwise activation changes
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test act_cfg: the shared config applies to both branches
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)