Spaces:
Running
on
L40S
Running
on
L40S
File size: 3,472 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DepthwiseSeparableConvModule
def test_depthwise_separable_conv():
    """Test construction and forward pass of DepthwiseSeparableConvModule.

    Covers the ``groups`` guard, the default config, the per-branch
    (``dw_*`` / ``pw_*``) and shared norm / activation configs, a custom
    layer order, spectral norm and a non-default padding mode.
    """
    # ``groups`` must not be supplied by the caller: the depthwise conv sets
    # its own groups (equal to in_channels, checked below), so passing it
    # is rejected with an AssertionError.
    with pytest.raises(AssertionError):
        DepthwiseSeparableConvModule(4, 8, 2, groups=2)

    # test default config: no norm layers, ReLU after both convs
    conv = DepthwiseSeparableConvModule(3, 8, 2)
    assert conv.depthwise_conv.conv.groups == 3
    assert conv.pointwise_conv.conv.kernel_size == (1, 1)
    assert not conv.depthwise_conv.with_norm
    assert not conv.pointwise_conv.with_norm
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    # kernel_size=2 with no padding shrinks each spatial dim by one
    assert output.shape == (1, 8, 255, 255)

    # test dw_norm_cfg: norm only on the depthwise branch
    conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert not conv.pointwise_conv.with_norm
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test pw_norm_cfg: norm only on the pointwise branch
    conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
    assert not conv.depthwise_conv.with_norm
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test norm_cfg: shared norm config applies to both branches
    conv = DepthwiseSeparableConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test a non-default layer order: ('norm', 'conv', 'act')
    conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act'))
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test with_spectral_norm: both convs are wrapped, exposing weight_orig.
    # kernel_size=3 with padding=1 keeps the 256x256 size; ``x`` from above
    # is reused here and in the remaining cases.
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
    assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test padding_mode='reflect': realised through an explicit padding layer
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test dw_act_cfg: overrides only the depthwise activation
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, dw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test pw_act_cfg: overrides only the pointwise activation
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, pw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test act_cfg: shared activation config applies to both branches
    conv = DepthwiseSeparableConvModule(
        3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)
|