Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import pytest | |
import torch | |
from mmcv.cnn.bricks.drop import DropPath | |
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding, | |
BaseTransformerLayer, | |
MultiheadAttention, PatchEmbed, | |
PatchMerging, | |
TransformerLayerSequence) | |
from mmcv.runner import ModuleList | |
def test_adaptive_padding():
    """Check that AdaptivePadding pads H/W so the conv windows tile them.

    Every geometry below is exercised with both supported modes
    ('same' and 'corner'); an unsupported ``padding`` value must then be
    rejected with an AssertionError.
    """

    def _padded_shape(layer, height, width):
        # Feed a single-channel dummy image and report the padded shape.
        return layer(torch.rand(1, 1, height, width)).shape

    # (kernel_size, stride, dilation, input (H, W), expected padded (H, W))
    geometries = (
        ((2, 2), (2, 2), (1, 1), (11, 13), (12, 14)),  # round up to 2
        ((2, 2), (10, 10), (1, 1), (10, 13), (10, 13)),  # already fits
        ((11, 11), (10, 10), (1, 1), (11, 13), (21, 21)),  # pad everything
    )

    for mode in ('same', 'corner'):
        # scalar arguments: both inputs are padded up to a multiple of 16
        pad16 = AdaptivePadding(
            kernel_size=16, stride=16, dilation=1, padding=mode)
        assert _padded_shape(pad16, 15, 17)[2:] == (16, 32)
        assert _padded_shape(pad16, 16, 17)[2:] == (16, 32)

        for kernel, step, dil, (in_h, in_w), expected in geometries:
            layer = AdaptivePadding(
                kernel_size=kernel, stride=step, dilation=dil, padding=mode)
            assert _padded_shape(layer, in_h, in_w)[2:] == expected

        # A (4, 5) kernel dilated by 2 covers a (7, 9) window, so it has to
        # pad exactly like an undilated (7, 9) kernel with the same stride.
        dilated = AdaptivePadding(
            kernel_size=(4, 5), stride=(3, 4), dilation=(2, 2), padding=mode)
        dilated_shape = _padded_shape(dilated, 11, 13)
        assert dilated_shape[2:] == (16, 21)
        plain = AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=mode)
        plain_shape = _padded_shape(plain, 11, 13)
        assert plain_shape[2:] == (16, 21)
        assert plain_shape == dilated_shape

    # only 'same' and 'corner' are accepted as padding modes
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=1)
def test_patch_embed():
    """Check PatchEmbed token shapes and the reported output size.

    ``PatchEmbed(x)`` returns ``(tokens, out_size)``; the assertions below
    expect ``tokens`` to be ``(B, out_h * out_w, embed_dims)`` and
    ``out_size`` to be ``(out_h, out_w)``.
    """
    B = 2
    C = 3
    embed_dims = 10

    def _check_adaptive(padding, input_size, kernel_size, stride, expected):
        # Build a bias-free PatchEmbed with adaptive `padding` and check the
        # token tensor and out_size it produces for a random input.
        in_c, dims = 2, 3
        layer = PatchEmbed(
            in_channels=in_c,
            embed_dims=dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=1,
            bias=False)
        x_out, out_size = layer(torch.rand(B, in_c, *input_size))
        assert x_out.size() == (B, expected[0] * expected[1], dims)
        assert out_size == expected
        assert x_out.size(1) == out_size[0] * out_size[1]

    # kernel 3 / stride 1 / no padding: (3, 4) -> (1, 2)
    patch_embed_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        norm_cfg=None)
    x1, shape = patch_embed_1(torch.rand(B, C, 3, 4))
    assert x1.shape == (2, 2, 10)
    assert shape == (1, 2)
    # the number of tokens must equal out_h * out_w
    assert shape[0] * shape[1] == x1.shape[1]

    # dilation: a 5x5 kernel dilated by 2 spans 9x9, so (10, 10) -> (1, 1)
    H = W = 10
    kernel_size = 5
    stride = 2
    dilation = 2
    # extra extent introduced by the dilated kernel: dilation * (k - 1)
    span = dilation * (kernel_size - 1)
    patch_embed_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=None,
    )
    x2, shape = patch_embed_2(torch.rand(B, C, H, W))
    assert x2.shape == (2, 1, 10)
    assert shape == (1, 1)
    assert shape[0] * shape[1] == x2.shape[1]

    # same geometry with LayerNorm and a declared input_size
    input_size = (H, W)
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=dict(type='LN'),
        input_size=input_size)
    x3, shape = patch_embed_3(torch.rand(B, C, H, W))
    assert x3.shape == (2, 1, 10)
    assert shape == (1, 1)
    assert shape[0] * shape[1] == x3.shape[1]
    # init_out_size follows the conv / nn.Unfold output-size formula,
    # each axis computed from its own input dimension
    assert patch_embed_3.init_out_size[0] == (
        input_size[0] - span - 1) // stride + 1
    assert patch_embed_3.init_out_size[1] == (
        input_size[1] - span - 1) // stride + 1

    # non-square input: init_out_size must be computed per axis, and when
    # the real input matches input_size the runtime out_size must equal it
    H, W = 11, 12
    input_size = (H, W)
    patch_embed_4 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=dict(type='LN'),
        input_size=input_size)
    _, shape = patch_embed_4(torch.rand(B, C, H, W))
    assert patch_embed_4.init_out_size == (
        (H - span - 1) // stride + 1, (W - span - 1) // stride + 1)
    assert shape == patch_embed_4.init_out_size

    # adaptive padding: (padding, input_size, kernel, stride, out_size)
    for padding in ('same', 'corner'):
        # stride 1 keeps the spatial size
        _check_adaptive(padding, (5, 5), (5, 5), (1, 1), (5, 5))
        # kernel_size == stride on an exactly divisible input
        _check_adaptive(padding, (5, 5), (5, 5), (5, 5), (1, 1))
        # kernel_size == stride with H not divisible by the stride
        _check_adaptive(padding, (6, 5), (5, 5), (5, 5), (2, 1))
        # different kernel_size / stride per axis
        _check_adaptive(padding, (6, 5), (6, 2), (6, 2), (1, 3))
def test_patch_merging():
    """Check PatchMerging on flattened ``(B, L, C)`` token sequences.

    ``patch_merge(x, input_size)`` is expected to return tokens shaped
    ``(B, out_h * out_w, out_channels)`` plus ``out_size == (out_h, out_w)``,
    for integer padding and for both adaptive padding modes.
    """

    def _merge(in_channels, out_channels, kernel_size, stride, padding,
               dilation, batch, input_size):
        # Build a bias-free PatchMerging and run it on random tokens laid
        # out as (batch, H * W, in_channels); returns (tokens, out_size).
        layer = PatchMerging(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False)
        seq_len = input_size[0] * input_size[1]
        tokens = torch.rand(batch, seq_len, in_channels)
        return layer(tokens, input_size)

    # integer padding: kernel 3 / stride 3 / pad 1 turns 10x10 into 4x4
    merged, out_hw = _merge(3, 4, 3, 3, 1, 1, 1, (10, 10))
    assert merged.size() == (1, 16, 4)
    assert out_hw == (4, 4)
    # the token count must agree with the reported spatial size
    assert merged.size(1) == out_hw[0] * out_hw[1]

    # integer padding with dilation: kernel 6 / stride 3 / pad 2 / dil 2
    merged, out_hw = _merge(4, 5, 6, 3, 2, 2, 1, (10, 10))
    assert merged.size() == (1, 4, 5)
    assert out_hw == (2, 2)
    assert merged.size(1) == out_hw[0] * out_hw[1]

    # adaptive padding: (input_size, kernel_size, stride, expected out_size)
    adaptive_cases = [
        ((5, 5), (5, 5), (1, 1), (5, 5)),  # stride 1 keeps the size
        ((5, 5), (5, 5), (5, 5), (1, 1)),  # kernel == stride, divisible
        ((6, 5), (5, 5), (5, 5), (2, 1)),  # kernel == stride, H not
        ((6, 5), (6, 2), (6, 2), (1, 3)),  # per-axis kernel and stride
    ]
    for mode in ('same', 'corner'):
        for size_in, kernel, step, size_out in adaptive_cases:
            merged, out_hw = _merge(2, 3, kernel, step, mode, 1, 2, size_in)
            assert merged.size() == (2, size_out[0] * size_out[1], 3)
            assert out_hw == size_out
            assert merged.size(1) == out_hw[0] * out_hw[1]
def test_multiheadattention():
    """Check that MultiheadAttention agrees across both tensor layouts.

    Two modules with shared weights, one ``batch_first=True`` and one
    ``batch_first=False``, are compared on the same data (transposed for
    the latter). The comparison is made with and without an explicit key,
    and with the deprecated ``residual`` argument as well as ``identity``.
    """
    # smoke test: a plain Dropout dropout_layer must also be constructible
    MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='Dropout', drop_prob=0.),
        batch_first=True)
    batch_dim = 2
    embed_dim = 5
    num_query = 100
    attn_batch_first = MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=True)

    attn_query_first = MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=False)

    # share the weights so both modules compute the same function
    param_dict = dict(attn_query_first.named_parameters())
    for n, v in attn_batch_first.named_parameters():
        param_dict[n].data = v.data

    input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    input_query_first = input_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first).sum(),
        attn_batch_first(input_batch_first).sum())

    key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    key_query_first = key_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first, key_query_first).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum())

    identity = torch.ones_like(input_query_first)

    # The deprecated `residual` argument must still be accepted and behave
    # like `identity`: the expected sum swaps the query shortcut for it.
    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, residual=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())

    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, identity=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())
def test_ffn():
    """Check FFN construction, layout invariance and the identity shortcut."""
    with pytest.raises(AssertionError):
        # num_fcs should be no less than 2
        FFN(num_fcs=1)
    # deprecated argument names must still be accepted
    FFN(dropout=0, add_residual=True)
    ffn = FFN(dropout=0, add_identity=True)

    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    # the test expects FFN to act on the last dim only, so swapping the
    # first two dims must leave the total sum unchanged
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())

    residual = torch.rand_like(input_tensor)
    # An explicit identity (deprecated name: `residual`) replaces the input
    # as the shortcut term, so the sums differ by residual - input.
    assert torch.allclose(
        ffn(input_tensor, residual=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='CUDA is not available')
def test_basetransformerlayer_cuda():
    """Check that deep-copied BaseTransformerLayers still run on CUDA.

    Two deep copies of one layer are wrapped in a ModuleList, moved to the
    GPU and applied in sequence; each output must keep the input shape
    (2, 10, 256). Skipped when no CUDA device is present.
    """
    operation_order = ('self_attn', 'ffn')
    baselayer = BaseTransformerLayer(
        operation_order=operation_order,
        batch_first=True,
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
        ),
    )
    baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
    baselayers.to('cuda')
    x = torch.rand(2, 10, 256).cuda()
    for m in baselayers:
        x = m(x)
        # check after every copy, not just the last one
        assert x.shape == torch.Size([2, 10, 256])
@pytest.mark.parametrize('embed_dims', [False, 256])
def test_basetransformerlayer(embed_dims):
    """Check BaseTransformerLayer config handling and a forward pass.

    Args:
        embed_dims: value placed in ``ffn_cfgs['embed_dims']``. When falsy
            the key is omitted, so the layer must supply the FFN width
            itself (presumably from the attention config — not verified
            here).
    """
    # one attention config per attention op in operation_order
    attn_cfgs = [
        dict(type='MultiheadAttention', embed_dims=256, num_heads=8)
    ]
    ffn_cfgs = dict(
        type='FFN',
        feedforward_channels=1024,
        num_fcs=2,
        ffn_drop=0.,
        act_cfg=dict(type='ReLU', inplace=True),
    )
    if embed_dims:
        ffn_cfgs['embed_dims'] = embed_dims
    feedforward_channels = 2048
    ffn_dropout = 0.1
    operation_order = ('self_attn', 'norm', 'ffn', 'norm')

    # deprecated FFN kwargs passed to the layer must override the matching
    # entries of ffn_cfgs (2048 wins over the 1024 in the config)
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        ffn_cfgs=ffn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order)
    assert baselayer.batch_first is False
    assert baselayer.ffns[0].feedforward_channels == feedforward_channels

    attn_cfgs = [
        dict(type='MultiheadAttention', num_heads=8, embed_dims=256)
    ]
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order,
        batch_first=True)
    # the layer-level batch_first must reach its attention modules
    assert baselayer.attentions[0].batch_first
    in_tensor = torch.rand(2, 10, 256)
    out = baselayer(in_tensor)
    assert out.shape == in_tensor.shape
def test_transformerlayersequence():
    """Check TransformerLayerSequence construction from layer configs.

    A single dict config is replicated ``num_layers`` times; a list of
    configs whose length differs from ``num_layers`` must be rejected.
    """

    def _layer_cfg(cross_attn_cfg):
        # Return a fresh BaseTransformerLayer config (new dicts every call)
        # whose second attention entry is `cross_attn_cfg`.
        self_attn_cfg = dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
            dropout=0.1)
        return dict(
            type='BaseTransformerLayer',
            attn_cfgs=[self_attn_cfg, cross_attn_cfg],
            feedforward_channels=1024,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))

    sequence = TransformerLayerSequence(
        num_layers=6,
        transformerlayers=_layer_cfg(
            dict(type='MultiheadAttention', embed_dims=256, num_heads=4)))
    assert len(sequence.layers) == 6
    assert sequence.pre_norm is False

    with pytest.raises(AssertionError):
        # a list of layer configs must contain exactly num_layers entries
        TransformerLayerSequence(
            num_layers=6,
            transformerlayers=[
                _layer_cfg(dict(type='MultiheadAttention', embed_dims=256))
            ])
def test_drop_path():
    """Check when DropPath hands back its input object unchanged.

    With ``drop_prob=0``, or in eval mode, the module must return the input
    tensor itself; in training mode with ``drop_prob > 0`` it must not.
    """
    never_drop = DropPath(drop_prob=0)
    sample = torch.rand(2, 3, 4, 5)
    assert never_drop(sample) is sample

    sometimes_drop = DropPath(drop_prob=0.1)
    sample = torch.rand(2, 3, 4, 5)
    sometimes_drop.eval()
    assert sometimes_drop(sample) is sample
    sometimes_drop.train()
    assert sometimes_drop(sample) is not sample