# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                             Linear, MaxPool2d, MaxPool3d)
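
# ``torch.__version__`` is patched to an old value ('1.1') in the tests below
# so that the mmcv wrappers take their own empty-tensor code path rather than
# deferring to a recent torch, which handles empty batches natively.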
if torch.__version__ != 'parrots':
torch_version = '1.1'
else:
torch_version = 'parrots'
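

# Illustrative sketch (not executed here): with the wrapper branch active, an
# empty batch flows through and keeps the correct trailing shape, e.g.
#   Conv2d(3, 8, 3)(torch.randn(0, 3, 32, 32)).shape == (0, 8, 30, 30)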
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
[(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
padding, dilation):
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv2d
"""
    # train mode
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_h, in_w)
torch.manual_seed(0)
wrapper = Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
torch.manual_seed(0)
ref = nn.Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_channel, in_h, in_w)
wrapper = Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper.eval()
wrapper(x_empty)
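

# Same parity check as test_conv2d, for the 3D conv wrapper.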
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
[(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
padding, dilation):
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv3d
"""
    # train mode
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
torch.manual_seed(0)
wrapper = Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_t, in_h,
in_w).requires_grad_(True)
torch.manual_seed(0)
ref = nn.Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
wrapper = Conv3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
wrapper.eval()
wrapper(x_empty)
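

# Transposed conv adds ``output_padding``, which torch requires to stay below
# stride or dilation; otherwise the check mirrors the conv tests above.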
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
[(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
stride, padding, dilation):
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    # output_padding must be smaller than either stride or dilation
    # (a torch constraint), so pick a value below both
    op = min(stride, dilation) - 1
if torch.__version__ == 'parrots':
op = 0
torch.manual_seed(0)
wrapper = ConvTranspose2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_h, in_w)
torch.manual_seed(0)
ref = nn.ConvTranspose2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_channel, in_h, in_w)
wrapper = ConvTranspose2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper.eval()
wrapper(x_empty)
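

# 3D variant of the transposed-conv check above.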
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
[(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
kernel_size, stride, padding, dilation):
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    # output_padding must be smaller than either stride or dilation
    # (a torch constraint), so pick a value below both
    op = min(stride, dilation) - 1
torch.manual_seed(0)
wrapper = ConvTranspose3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
torch.manual_seed(0)
ref = nn.ConvTranspose3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
wrapper = ConvTranspose3d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
output_padding=op)
wrapper.eval()
wrapper(x_empty)
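

# MaxPool2d has no learnable parameters, so only output shapes and values are
# compared against the native op.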
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
[(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
padding, dilation):
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
wrapper = MaxPool2d(
kernel_size, stride=stride, padding=padding, dilation=dilation)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_h, in_w)
ref = nn.MaxPool2d(
kernel_size, stride=stride, padding=padding, dilation=dilation)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
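

# MaxPool3d variant; on parrots this op needs CUDA, hence the skipif and the
# ``.cuda()`` calls below.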
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation', # noqa: E501
[(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
@pytest.mark.skipif(
torch.__version__ == 'parrots' and not torch.cuda.is_available(),
reason='parrots requires CUDA support')
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
stride, padding, dilation):
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
wrapper = MaxPool3d(
kernel_size, stride=stride, padding=padding, dilation=dilation)
if torch.__version__ == 'parrots':
x_empty = x_empty.cuda()
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
ref = nn.MaxPool3d(
kernel_size, stride=stride, padding=padding, dilation=dilation)
if torch.__version__ == 'parrots':
x_normal = x_normal.cuda()
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
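

# Linear wrapper: empty and normal batches through a fully-connected layer.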
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
(20, 20, 3, 3)])
def test_linear(in_w, in_h, in_feature, out_feature):
    # wrapper op with an empty (zero-batch) input
x_empty = torch.randn(0, in_feature, requires_grad=True)
torch.manual_seed(0)
wrapper = Linear(in_feature, out_feature)
wrapper_out = wrapper(x_empty)
    # native torch op with a batch of 3 as shape reference
x_normal = torch.randn(3, in_feature)
torch.manual_seed(0)
ref = nn.Linear(in_feature, out_feature)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
wrapper_out.sum().backward()
assert wrapper.weight.grad is not None
assert wrapper.weight.grad.shape == wrapper.weight.shape
assert torch.equal(wrapper(x_normal), ref_out)
# eval mode
x_empty = torch.randn(0, in_feature)
wrapper = Linear(in_feature, out_feature)
wrapper.eval()
wrapper(x_empty)
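

# With a recent TORCH_VERSION patched in, the wrappers should skip their
# empty-tensor branch entirely and delegate to the native ``torch.nn``
# forward for empty and normal inputs alike; the mocks verify that call.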
@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
def test_nn_op_forward_called():
for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
            # empty input
x_empty = torch.randn(0, 3, 10, 10)
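            # look up the wrapper class (Conv2d, etc.) by name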
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
            # normal input
x_normal = torch.randn(1, 3, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
for m in ['Conv3d', 'ConvTranspose3d', 'MaxPool3d']:
with patch(f'torch.nn.{m}.forward') as nn_module_forward:
            # empty input
x_empty = torch.randn(0, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
            # normal input
x_normal = torch.randn(1, 3, 10, 10, 10)
wrapper = eval(m)(3, 2, 1)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)
with patch('torch.nn.Linear.forward') as nn_module_forward:
        # empty input
x_empty = torch.randn(0, 3)
wrapper = Linear(3, 3)
wrapper(x_empty)
nn_module_forward.assert_called_with(x_empty)
        # normal input
x_normal = torch.randn(1, 3)
wrapper = Linear(3, 3)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)