Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
import pytest | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn.bricks import DepthwiseSeparableConvModule | |
def test_depthwise_separable_conv():
    """Exercise ``DepthwiseSeparableConvModule`` across its config options.

    Checks that an explicit ``groups`` argument is rejected, then builds the
    module with the default config and with the norm, order, spectral-norm,
    padding-mode and activation options, asserting the attributes of the
    ``depthwise_conv`` / ``pointwise_conv`` sub-modules and the shape of the
    output of a forward pass on a ``(1, 3, 256, 256)`` input.
    """
    with pytest.raises(AssertionError):
        # an explicit ``groups`` argument must be rejected; presumably the
        # module derives the depthwise groups from in_channels itself (see
        # the ``groups == 3`` check below)
        DepthwiseSeparableConvModule(4, 8, 2, groups=2)

    # test default config
    conv = DepthwiseSeparableConvModule(3, 8, 2)
    assert conv.depthwise_conv.conv.groups == 3
    assert conv.pointwise_conv.conv.kernel_size == (1, 1)
    assert not conv.depthwise_conv.with_norm
    assert not conv.pointwise_conv.with_norm
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    # kernel size 2 without padding shrinks each spatial dim by one
    assert output.shape == (1, 8, 255, 255)

    # test dw_norm_cfg
    conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert not conv.pointwise_conv.with_norm
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test pw_norm_cfg
    conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
    assert not conv.depthwise_conv.with_norm
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test norm_cfg
    conv = DepthwiseSeparableConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test order=('norm', 'conv', 'act')
    conv = DepthwiseSeparableConvModule(
        3, 8, 2, order=('norm', 'conv', 'act'))
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # The remaining cases reuse the input ``x`` created just above.

    # test with_spectral_norm
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
    assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
    output = conv(x)
    # kernel size 3 with padding 1 keeps the spatial size
    assert output.shape == (1, 8, 256, 256)

    # test padding_mode
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test dw_act_cfg
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test pw_act_cfg
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test act_cfg
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)