# Copyright (c) OpenMMLab. All rights reserved.
import copy

import pytest
import torch

from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
                                         BaseTransformerLayer,
                                         MultiheadAttention, PatchEmbed,
                                         PatchMerging,
                                         TransformerLayerSequence)
from mmcv.runner import ModuleList


def test_adaptive_padding():
    """Check the output spatial size of ``AdaptivePadding``.

    Both the 'same' and 'corner' modes are exercised with int and tuple
    arguments, and any other ``padding`` value must be rejected.
    """
    for padding in ('same', 'corner'):
        kernel_size = 16
        stride = 16
        dilation = 1
        x = torch.rand(1, 1, 15, 17)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)
        x = torch.rand(1, 1, 16, 17)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)

        kernel_size = (2, 2)
        stride = (2, 2)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # padding to divisible by 2
        assert (out.shape[2], out.shape[3]) == (12, 14)

        kernel_size = (2, 2)
        stride = (10, 10)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 10, 13)
        out = adap_pad(x)
        # no padding
        assert (out.shape[2], out.shape[3]) == (10, 13)

        kernel_size = (11, 11)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # all padding
        assert (out.shape[2], out.shape[3]) == (21, 21)

        # A (4, 5) kernel with dilation (2, 2) spans (7, 9) pixels, so it
        # must be padded exactly like a plain (7, 9) kernel.
        x = torch.rand(1, 1, 11, 13)
        stride = (3, 4)
        kernel_size = (4, 5)
        dilation = (2, 2)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        dilation_out = adap_pad(x)
        assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)

        kernel_size = (7, 9)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        kernel79_out = adap_pad(x)
        assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
        assert kernel79_out.shape == dilation_out.shape

    # assert only support "same" "corner"
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=1)


def test_patch_embed():
    """Check ``PatchEmbed`` output shapes and the reported output size.

    Covers plain int padding, dilation, a norm layer together with
    ``input_size`` (which populates ``init_out_size``) and the adaptive
    'same' / 'corner' padding modes.
    """
    B = 2
    H = 3
    W = 4
    C = 3
    embed_dims = 10
    kernel_size = 3
    stride = 1
    dummy_input = torch.rand(B, C, H, W)
    patch_embed_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=1,
        norm_cfg=None)

    x1, shape = patch_embed_1(dummy_input)
    # test out shape
    assert x1.shape == (2, 2, 10)
    # test outsize is correct
    assert shape == (1, 2)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x1.shape[1]

    B = 2
    H = 10
    W = 10
    C = 3
    embed_dims = 10
    kernel_size = 5
    stride = 2
    dummy_input = torch.rand(B, C, H, W)
    # test dilation
    patch_embed_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=None,
    )

    x2, shape = patch_embed_2(dummy_input)
    # test out shape
    assert x2.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x2.shape[1]

    stride = 2
    input_size = (10, 10)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    x3, shape = patch_embed_3(dummy_input)
    # test out shape
    assert x3.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x3.shape[1]

    # test the init_out_size with nn.Unfold
    # dilation * (kernel_size - 1) == 2 * 4, padding == 0 and stride == 2;
    # each axis is checked against its own entry of `input_size`.
    assert patch_embed_3.init_out_size[0] == (input_size[0] - 2 * 4 -
                                              1) // 2 + 1
    assert patch_embed_3.init_out_size[1] == (input_size[1] - 2 * 4 -
                                              1) // 2 + 1

    H = 11
    W = 12
    input_size = (H, W)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    _, shape = patch_embed_3(dummy_input)
    # when input_size equal to real input
    # the out_size should be equal to `init_out_size`
    assert shape == patch_embed_3.init_out_size

    # test adap padding
    for padding in ('same', 'corner'):
        in_c = 2
        embed_dims = 3
        B = 2

        # test stride is 1
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (1, 1)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 25, 3)
        assert out_size == (5, 5)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 1, 3)
        assert out_size == (1, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (6, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 2, 3)
        assert out_size == (2, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test different kernel_size with different stride
        input_size = (6, 5)
        kernel_size = (6, 2)
        stride = (6, 2)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 3, 3)
        assert out_size == (1, 3)
        assert x_out.size(1) == out_size[0] * out_size[1]


def test_patch_merging():
    """Check ``PatchMerging`` output shapes and the reported output size.

    Covers int padding (with and without dilation) and the adaptive
    'same' / 'corner' padding modes.
    """
    # Test the model with int padding
    in_c = 3
    out_c = 4
    kernel_size = 3
    stride = 3
    padding = 1
    dilation = 1
    bias = False
    # test the case `pad_to_stride` is False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 3
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 16, 4)
    assert out_size == (4, 4)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    in_c = 4
    out_c = 5
    kernel_size = 6
    stride = 3
    padding = 2
    dilation = 2
    bias = False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 4
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 4, 5)
    assert out_size == (2, 2)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    # Test with adaptive padding
    for padding in ('same', 'corner'):
        in_c = 2
        out_c = 3
        B = 2

        # test stride is 1
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (1, 1)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 25, 3)
        assert out_size == (5, 5)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 1, 3)
        assert out_size == (1, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (6, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 2, 3)
        assert out_size == (2, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test different kernel_size with different stride
        input_size = (6, 5)
        kernel_size = (6, 2)
        stride = (6, 2)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 3, 3)
        assert out_size == (1, 3)
        assert x_out.size(1) == out_size[0] * out_size[1]


def test_multiheadattention():
    """Check ``MultiheadAttention`` gives the same result in both layouts.

    A batch-first and a query-first module sharing the same weights must
    produce equal output sums, with and without an explicit key, and when
    the identity is passed through either the deprecated ``residual``
    keyword or the ``identity`` keyword.
    """
    # Construction with a plain 'Dropout' dropout layer must succeed.
    MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='Dropout', drop_prob=0.),
        batch_first=True)
    batch_dim = 2
    embed_dim = 5
    num_query = 100
    attn_batch_first = MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=True)

    attn_query_first = MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=False)

    # Copy the weights so both modules compute the same function.
    param_dict = dict(attn_query_first.named_parameters())
    for n, v in attn_batch_first.named_parameters():
        param_dict[n].data = v.data

    input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    input_query_first = input_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first).sum(),
        attn_batch_first(input_batch_first).sum())

    key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    key_query_first = key_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first, key_query_first).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum())

    identity = torch.ones_like(input_query_first)

    # check deprecated arguments can be used normally
    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, residual=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())

    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, identity=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())


def test_ffn():
    """Check ``FFN`` argument validation and its identity handling."""
    with pytest.raises(AssertionError):
        # num_fcs should be no less than 2
        FFN(num_fcs=1)
    # Construction with the `add_residual` keyword must succeed.
    FFN(dropout=0, add_residual=True)
    ffn = FFN(dropout=0, add_identity=True)

    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())

    # `torch.allclose` only returns a bool, so its result has to be
    # asserted for these comparisons to be able to fail.
    residual = torch.rand_like(input_tensor)
    assert torch.allclose(
        ffn(input_tensor, residual=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():
    """Check deep-copied ``BaseTransformerLayer`` modules still run on CUDA."""
    # To test if the BaseTransformerLayer's behaviour remains
    # consistent after being deepcopied
    operation_order = ('self_attn', 'ffn')
    baselayer = BaseTransformerLayer(
        operation_order=operation_order,
        batch_first=True,
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
        ),
    )
    baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
    baselayers.to('cuda')
    x = torch.rand(2, 10, 256).cuda()
    for m in baselayers:
        x = m(x)
        assert x.shape == torch.Size([2, 10, 256])


@pytest.mark.parametrize('embed_dims', [False, 256])
def test_basetransformerlayer(embed_dims):
    """Check ``BaseTransformerLayer`` with deprecated args and batch_first.

    Args:
        embed_dims (int | bool): When truthy it is written into the FFN
            config explicitly; otherwise the FFN config carries no
            ``embed_dims`` entry.
    """
    attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8)
    if embed_dims:
        ffn_cfgs = dict(
            type='FFN',
            embed_dims=embed_dims,
            feedforward_channels=1024,
            num_fcs=2,
            ffn_drop=0.,
            act_cfg=dict(type='ReLU', inplace=True),
        )
    else:
        ffn_cfgs = dict(
            type='FFN',
            feedforward_channels=1024,
            num_fcs=2,
            ffn_drop=0.,
            act_cfg=dict(type='ReLU', inplace=True),
        )

    feedforward_channels = 2048
    ffn_dropout = 0.1
    operation_order = ('self_attn', 'norm', 'ffn', 'norm')

    # test deprecated_args
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        ffn_cfgs=ffn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order)
    assert baselayer.batch_first is False
    assert baselayer.ffns[0].feedforward_channels == feedforward_channels

    # A fresh attention config is used for the batch_first layer.
    attn_cfgs = dict(type='MultiheadAttention', num_heads=8, embed_dims=256)
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order,
        batch_first=True)
    assert baselayer.attentions[0].batch_first
    in_tensor = torch.rand(2, 10, 256)
    baselayer(in_tensor)


def test_transformerlayersequence():
    """Check ``TransformerLayerSequence`` construction and validation."""
    squeue = TransformerLayerSequence(
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                dict(type='MultiheadAttention', embed_dims=256, num_heads=4)
            ],
            feedforward_channels=1024,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
                             'norm')))
    assert len(squeue.layers) == 6
    assert squeue.pre_norm is False
    with pytest.raises(AssertionError):
        # if transformerlayers is a list, len(transformerlayers)
        # should be equal to num_layers
        TransformerLayerSequence(
            num_layers=6,
            transformerlayers=[
                dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.1),
                        dict(type='MultiheadAttention', embed_dims=256)
                    ],
                    feedforward_channels=1024,
                    ffn_dropout=0.1,
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm'))
            ])


def test_drop_path():
    """Check when ``DropPath`` hands back the very same input tensor."""
    # With drop_prob == 0 the input object itself is returned.
    drop_path = DropPath(drop_prob=0)
    test_in = torch.rand(2, 3, 4, 5)
    assert test_in is drop_path(test_in)

    # With training switched off the input object itself is returned too.
    drop_path = DropPath(drop_prob=0.1)
    drop_path.training = False
    test_in = torch.rand(2, 3, 4, 5)
    assert test_in is drop_path(test_in)

    # In training mode with a non-zero drop_prob a new tensor is produced.
    drop_path.training = True
    assert test_in is not drop_path(test_in)