Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
from unittest.mock import patch | |
import pytest | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, | |
Linear, MaxPool2d, MaxPool3d) | |
# Version string used by this module. Parrots builds keep the literal string
# 'parrots'; every other build uses '1.1'.
# NOTE(review): presumably consumed by ``patch('torch.__version__', ...)``
# decorators on the tests below (not visible in this copy) -- confirm.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.1'
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """Check that the ``Conv2d`` wrapper handles an empty batch.

    Train mode:

    * the wrapper is fed a batch of size 0;
    * its output shape (apart from the batch dimension) is compared with
      ``nn.Conv2d`` applied to a batch of 3 built with the same
      hyper-parameters;
    * the empty output is back-propagated and the weight must receive a
      gradient of the right shape;
    * on the non-empty batch the wrapper must match ``nn.Conv2d`` exactly,
      since both are created right after ``torch.manual_seed(0)``.

    Eval mode: a fresh wrapper must also produce an empty-batch output of
    the same shape.

    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    # train mode
    # wrapper op with an empty (batch size 0) input
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with a batch of 3 samples as shape reference; re-seeding
    # before construction lets it start from the same weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    eval_out = wrapper(x_empty)
    # verify the eval-mode result, not merely that the call does not raise
    assert eval_out.shape[0] == 0
    assert eval_out.shape[1:] == ref_out.shape[1:]
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """Check that the ``Conv3d`` wrapper handles an empty batch.

    Train mode:

    * the wrapper is fed a batch of size 0;
    * its output shape (apart from the batch dimension) is compared with
      ``nn.Conv3d`` applied to a batch of 3 built with the same
      hyper-parameters;
    * the empty output is back-propagated and the weight must receive a
      gradient of the right shape;
    * on the non-empty batch the wrapper must match ``nn.Conv3d`` exactly,
      since both are created right after ``torch.manual_seed(0)``.

    Eval mode: a fresh wrapper must also produce an empty-batch output of
    the same shape.

    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv3d
    """
    # train mode
    # wrapper op with an empty (batch size 0) input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with a batch of 3 samples as shape reference; re-seeding
    # before construction lets it start from the same weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_t, in_h,
                           in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    eval_out = wrapper(x_empty)
    # verify the eval-mode result, not merely that the call does not raise
    assert eval_out.shape[0] == 0
    assert eval_out.shape[1:] == ref_out.shape[1:]
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
                            stride, padding, dilation):
    """Check that the ``ConvTranspose2d`` wrapper handles an empty batch.

    Train mode:

    * the wrapper is fed a batch of size 0;
    * its output shape (apart from the batch dimension) is compared with
      ``nn.ConvTranspose2d`` applied to a batch of 3 built with the same
      hyper-parameters, including ``output_padding``;
    * the empty output is back-propagated and the weight must receive a
      gradient of the right shape;
    * on the non-empty batch the wrapper must match the torch op exactly.

    Eval mode: a fresh wrapper must also produce an empty-batch output of
    the same shape.
    """
    # train mode
    # wrapper op with an empty (batch size 0) input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    # NOTE(review): presumably parrots does not support a non-zero
    # output_padding here -- confirm
    if torch.__version__ == 'parrots':
        op = 0
    torch.manual_seed(0)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with a batch of 3 samples as shape reference; re-seeding
    # before construction lets it start from the same weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    eval_out = wrapper(x_empty)
    # verify the eval-mode result, not merely that the call does not raise
    assert eval_out.shape[0] == 0
    assert eval_out.shape[1:] == ref_out.shape[1:]
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
                            kernel_size, stride, padding, dilation):
    """Check that the ``ConvTranspose3d`` wrapper handles an empty batch.

    Train mode:

    * the wrapper is fed a batch of size 0;
    * its output shape (apart from the batch dimension) is compared with
      ``nn.ConvTranspose3d`` applied to a batch of 3 built with the same
      hyper-parameters, including ``output_padding``;
    * the empty output is back-propagated and the weight must receive a
      gradient of the right shape;
    * on the non-empty batch the wrapper must match the torch op exactly.

    Eval mode: a fresh wrapper must also produce an empty-batch output of
    the same shape.
    """
    # train mode
    # wrapper op with an empty (batch size 0) input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with a batch of 3 samples as shape reference; re-seeding
    # before construction lets it start from the same weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    eval_out = wrapper(x_empty)
    # verify the eval-mode result, not merely that the call does not raise
    assert eval_out.shape[0] == 0
    assert eval_out.shape[1:] == ref_out.shape[1:]
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                     padding, dilation):
    """Compare the ``MaxPool2d`` wrapper on an empty batch with ``nn.MaxPool2d``.

    The wrapper's output for a batch of size 0 must have an empty batch
    dimension and otherwise the same shape as ``nn.MaxPool2d`` applied to a
    batch of 3. On that non-empty batch the two ops must agree exactly.
    ``out_channel`` is accepted but not used by this test.
    """
    pool_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

    # empty (batch size 0) input through the wrapper
    empty_batch = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    pooling = MaxPool2d(kernel_size, **pool_kwargs)
    empty_result = pooling(empty_batch)

    # batch of 3 through the plain torch op, serving as the shape reference
    full_batch = torch.randn(3, in_channel, in_h, in_w)
    expected = nn.MaxPool2d(kernel_size, **pool_kwargs)(full_batch)

    assert empty_result.shape[0] == 0
    assert empty_result.shape[1:] == expected.shape[1:]
    # on real data the wrapper must behave exactly like the torch op
    assert torch.equal(pooling(full_batch), expected)
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                     stride, padding, dilation):
    """Compare the ``MaxPool3d`` wrapper on an empty batch with ``nn.MaxPool3d``.

    The wrapper's output for a batch of size 0 must have an empty batch
    dimension and otherwise the same shape as ``nn.MaxPool3d`` applied to a
    batch of 3. On that non-empty batch the two ops must agree exactly.
    Under parrots the inputs are moved to CUDA first. ``out_channel`` is
    accepted but not used by this test.
    """
    pool_kwargs = dict(stride=stride, padding=padding, dilation=dilation)
    on_parrots = torch.__version__ == 'parrots'

    def to_device(tensor):
        # parrots runs these pooling checks on CUDA tensors
        return tensor.cuda() if on_parrots else tensor

    # empty (batch size 0) input through the wrapper
    empty_batch = torch.randn(0, in_channel, in_t, in_h, in_w,
                              requires_grad=True)
    pooling = MaxPool3d(kernel_size, **pool_kwargs)
    empty_result = pooling(to_device(empty_batch))

    # batch of 3 through the plain torch op, serving as the shape reference
    full_batch = to_device(torch.randn(3, in_channel, in_t, in_h, in_w))
    expected = nn.MaxPool3d(kernel_size, **pool_kwargs)(full_batch)

    assert empty_result.shape[0] == 0
    assert empty_result.shape[1:] == expected.shape[1:]
    # on real data the wrapper must behave exactly like the torch op
    assert torch.equal(pooling(full_batch), expected)
def test_linear(in_w, in_h, in_feature, out_feature):
    """Check that the ``Linear`` wrapper handles an empty batch.

    Train mode:

    * the wrapper is fed a ``(0, in_feature)`` input;
    * its output shape (apart from the batch dimension) is compared with
      ``nn.Linear`` applied to a ``(3, in_feature)`` input;
    * the empty output is back-propagated and the weight must receive a
      gradient of the right shape;
    * on the non-empty input the wrapper must match ``nn.Linear`` exactly,
      since both are created right after ``torch.manual_seed(0)``.

    Eval mode: a fresh wrapper must also produce an empty-batch output of
    the same shape.

    ``in_w`` and ``in_h`` are accepted but unused.
    NOTE(review): presumably they come from a shared parametrization --
    confirm.
    """
    # train mode
    # wrapper op with an empty (batch size 0) input
    x_empty = torch.randn(0, in_feature, requires_grad=True)
    torch.manual_seed(0)
    wrapper = Linear(in_feature, out_feature)
    wrapper_out = wrapper(x_empty)

    # torch op with a 2-D (3, in_feature) input as shape reference
    x_normal = torch.randn(3, in_feature)
    torch.manual_seed(0)
    ref = nn.Linear(in_feature, out_feature)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_feature)
    wrapper = Linear(in_feature, out_feature)
    wrapper.eval()
    eval_out = wrapper(x_empty)
    # verify the eval-mode result, not merely that the call does not raise
    assert eval_out.shape[0] == 0
    assert eval_out.shape[1:] == ref_out.shape[1:]
def test_nn_op_forward_called():
    """Check that each wrapper delegates to the matching ``torch.nn`` forward.

    For every wrapper, ``torch.nn.<Name>.forward`` is replaced by a mock.
    The wrapper is called on an empty (batch size 0) input and then on a
    non-empty input, and each time the mock must have been called with that
    exact tensor.

    The positional arguments ``(3, 2, 1)`` are passed straight to each
    wrapper's constructor and are interpreted by that class's signature.
    """
    # Spatial input size -> {torch.nn class name: wrapper class}. An explicit
    # mapping avoids eval() on the class name and keeps the patch target
    # string next to the class it belongs to.
    wrappers_by_spatial_size = {
        (10, 10): {
            'Conv2d': Conv2d,
            'ConvTranspose2d': ConvTranspose2d,
            'MaxPool2d': MaxPool2d,
        },
        (10, 10, 10): {
            'Conv3d': Conv3d,
            'ConvTranspose3d': ConvTranspose3d,
            'MaxPool3d': MaxPool3d,
        },
    }
    for spatial_size, wrappers in wrappers_by_spatial_size.items():
        for name, wrapper_cls in wrappers.items():
            with patch(f'torch.nn.{name}.forward') as nn_module_forward:
                # empty (batch size 0) input
                x_empty = torch.randn(0, 3, *spatial_size)
                wrapper = wrapper_cls(3, 2, 1)
                wrapper(x_empty)
                nn_module_forward.assert_called_with(x_empty)

                # non-empty input
                x_normal = torch.randn(1, 3, *spatial_size)
                wrapper = wrapper_cls(3, 2, 1)
                wrapper(x_normal)
                nn_module_forward.assert_called_with(x_normal)

    with patch('torch.nn.Linear.forward') as nn_module_forward:
        # empty (batch size 0) input
        x_empty = torch.randn(0, 3)
        wrapper = Linear(3, 3)
        wrapper(x_empty)
        nn_module_forward.assert_called_with(x_empty)

        # non-empty input
        x_normal = torch.randn(1, 3)
        wrapper = Linear(3, 3)
        wrapper(x_normal)
        nn_module_forward.assert_called_with(x_normal)