File size: 2,899 Bytes
d7e58f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.cnn.bricks import GeneralizedAttention


def test_context_block():
    """Smoke-test ``GeneralizedAttention`` over its main configurations.

    For each of the four single-term attention types ('1000', '0100',
    '0010', '0001'), plus spatial_range, q_stride, kv_stride and an fp16
    run, build the block, check the expected submodules exist, and verify
    the output shape matches the input shape.

    NOTE(review): the function name says "context_block" but it exercises
    ``GeneralizedAttention`` — consider renaming in a follow-up.
    """

    # test attention_type='1000' (query/key content term)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100' (query/relative-position term);
    # the geometry FCs embed 8-dim positional features per axis
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0010' (key content bias term)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0001' (relative-position bias term)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test spatial_range >= 0: a local-constraint map restricts attention
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test q_stride > 1: queries are downsampled before attention
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test kv_stride > 1: keys/values are downsampled before attention
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test fp16 with attention_type='1111' (all four terms enabled);
    # only runs where a GPU is available since half precision is
    # exercised on CUDA
    if torch.cuda.is_available():
        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
        gen_attention_block = GeneralizedAttention(
            16,
            spatial_range=-1,
            num_heads=8,
            attention_type='1111',
            kv_stride=2)
        gen_attention_block.cuda().type(torch.half)
        out = gen_attention_block(imgs)
        assert out.shape == imgs.shape