# AiOS / mmcv / tests / test_cnn / test_conv_module.py
# Commit d7e58f0 ("update") by ttxskk; file size 7.69 kB (raw / history / blame view).
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module()
class ExampleConv(nn.Module):
    """Minimal user-registered conv layer for testing ``ConvModule``.

    It wraps a single ``nn.Conv2d`` (``conv0``) and stores the usual conv
    hyper-parameters as attributes (``in_channels``, ``stride``,
    ``transposed``, ...). It also defines its own ``init_weights``, which
    zero-initialises ``conv0.weight``. This lets a test check that a
    ``ConvModule`` built from ``conv_cfg=dict(type='ExampleConv')`` keeps
    that initialisation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of the
            input. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Default: 1.
        groups (int): Number of blocked connections from input to output
            channels. Default: 1.
        bias (bool): Whether ``conv0`` has a learnable bias. Default: True.
        norm_cfg (dict | None): Only stored on the instance; this layer does
            not use it. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm_cfg=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.norm_cfg = norm_cfg
        self.output_padding = (0, 0, 0)
        self.transposed = False
        # Forward every recorded hyper-parameter so the wrapped conv really
        # computes what the attributes above advertise. Previously only
        # ``kernel_size`` was passed and the others were silently ignored.
        self.conv0 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.init_weights()

    def forward(self, x):
        """Apply the wrapped convolution ``conv0`` to ``x``."""
        x = self.conv0(x)
        return x

    def init_weights(self):
        """Zero-initialise ``conv0.weight`` (the bias is left untouched)."""
        nn.init.constant_(self.conv0.weight, 0)
def test_conv_module():
    """Check ``ConvModule`` config validation, layer composition, custom conv
    initialisation, padding modes and the activation built from ``act_cfg``.
    """
    # Invalid configs: conv_cfg / norm_cfg must be a dict or None, and the
    # 'softmax' activation type is not supported.
    with pytest.raises(AssertionError):
        ConvModule(3, 8, 2, conv_cfg='conv')
    with pytest.raises(AssertionError):
        ConvModule(3, 8, 2, norm_cfg='norm')
    with pytest.raises(KeyError):
        ConvModule(3, 8, 2, act_cfg=dict(type='softmax'))

    # Basic compositions: (extra kwargs, expects norm, expects activation).
    basic_cases = [
        (dict(norm_cfg=dict(type='BN')), True, True),  # conv + norm + act
        (dict(), False, True),  # conv + act
        (dict(act_cfg=None), False, False),  # conv only
    ]
    for extra_kwargs, has_norm, has_act in basic_cases:
        module = ConvModule(3, 8, 2, **extra_kwargs)
        assert bool(module.with_activation) == has_act
        assert hasattr(module, 'activate') == has_act
        assert bool(module.with_norm) == has_norm
        if has_norm:
            assert hasattr(module, 'norm')
        else:
            assert module.norm is None
        inputs = torch.rand(1, 3, 256, 256)
        assert module(inputs).shape == (1, 8, 255, 255)

    # A conv layer that provides its own `init_weights` keeps that init
    # (ExampleConv zero-initialises its weight).
    custom = ConvModule(
        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
    assert torch.equal(custom.conv.conv0.weight, torch.zeros(8, 3, 2, 2))

    # with_spectral_norm=True wraps the conv weight.
    module = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(module.conv, 'weight_orig')
    assert module(inputs).shape == (1, 8, 256, 256)

    # padding_mode='reflect' inserts an explicit padding layer.
    module = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(module.padding_layer, nn.ReflectionPad2d)
    assert module(inputs).shape == (1, 8, 256, 256)

    # An unknown padding mode is rejected.
    with pytest.raises(KeyError):
        ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # HSwish maps to torch's native Hardswish from torch 1.7 on (not parrots).
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.7')):
        hswish_cls = HSwish
    else:
        hswish_cls = nn.Hardswish

    # Each supported activation type and the layer class it should build.
    activation_cases = [
        ('LeakyReLU', nn.LeakyReLU),
        ('Tanh', nn.Tanh),
        ('Sigmoid', nn.Sigmoid),
        ('PReLU', nn.PReLU),
        ('HSwish', hswish_cls),
        ('HSigmoid', HSigmoid),
    ]
    for act_type, expected_cls in activation_cases:
        module = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type=act_type))
        assert isinstance(module.activate, expected_cls), act_type
        assert module(inputs).shape == (1, 8, 256, 256), act_type
def test_bias():
    """Check how ``ConvModule`` resolves ``bias`` and when it warns about a
    redundant conv bias placed before a normalisation layer."""
    # bias='auto' (the default) keeps the conv bias only when there is no norm.
    assert ConvModule(3, 8, 2).conv.bias is not None
    assert ConvModule(3, 8, 2, norm_cfg=dict(type='BN')).conv.bias is None
    # An explicit bias=False is honoured.
    assert ConvModule(3, 8, 2, bias=False).conv.bias is None

    # An explicit bias=True before batch or instance norm emits exactly one
    # UserWarning with a fixed message.
    for norm_type in ('BN', 'IN'):
        with pytest.warns(UserWarning) as record:
            ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type=norm_type))
        assert len(record) == 1
        assert record[0].message.args[0] == (
            'Unnecessary conv bias before batch/instance norm')

    # Any other norm (here GN) must not warn. `pytest.warns` requires at
    # least one UserWarning, so emit a sentinel warning and check that it is
    # the only one recorded.
    with pytest.warns(UserWarning) as record:
        ConvModule(
            3, 8, 2, bias=True, norm_cfg=dict(type='GN', num_groups=1))
        warnings.warn('No warnings')
    assert len(record) == 1
    assert record[0].message.args[0] == 'No warnings'
def conv_forward(self, x):
    """Replacement for ``nn.Conv2d.forward``: append the ``'_conv'`` tag."""
    tag = '_conv'
    return x + tag
def bn_forward(self, x):
    """Replacement for ``nn.BatchNorm2d.forward``: append the ``'_bn'`` tag."""
    tag = '_bn'
    return x + tag
def relu_forward(self, x):
    """Replacement for ``nn.ReLU.forward``: append the ``'_relu'`` tag."""
    tag = '_relu'
    return x + tag
@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():
    """Check ``ConvModule``'s ``order`` validation and that ``forward`` runs
    conv / norm / act in the configured order.

    The ``forward`` methods of ``nn.Conv2d``, ``nn.BatchNorm2d`` and
    ``nn.ReLU`` are patched to append a tag to a string input. The returned
    string therefore records which layers ran, and in what sequence.
    """
    with pytest.raises(AssertionError):
        # order must be a tuple
        ConvModule(3, 8, 2, order=['conv', 'norm', 'act'])

    with pytest.raises(AssertionError):
        # length of order must be 3
        ConvModule(3, 8, 2, order=('conv', 'norm'))

    with pytest.raises(AssertionError):
        # order must be a permutation of 'conv', 'norm', 'act'
        ConvModule(3, 8, 2, order=('conv', 'norm', 'norm'))

    with pytest.raises(AssertionError):
        # order must be a permutation of 'conv', 'norm', 'act'
        ConvModule(3, 8, 2, order=('conv', 'norm', 'something'))

    # ('conv', 'norm', 'act') -- the default order
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv('input') == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    assert conv('input') == 'input_bn_conv_relu'

    # ('act', 'conv', 'norm')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('act', 'conv', 'norm'))
    assert conv('input') == 'input_relu_conv_bn'

    # ('conv', 'norm', 'act'), activate=False skips the activation
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv('input', activate=False) == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False skips the norm layer
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv('input', norm=False) == 'input_conv_relu'

    # ('conv', 'norm', 'act'), norm=False and activate=False: conv only
    assert conv('input', activate=False, norm=False) == 'input_conv'