# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.cnn.bricks import GeneralizedAttention


def test_context_block():
    """Check ``GeneralizedAttention`` construction and forward shapes.

    NOTE: despite its name, this test exercises ``GeneralizedAttention``.
    The name is kept unchanged so existing test selections keep working.

    Each section builds a block with one configuration, asserts that the
    sub-modules/attributes expected for that configuration are present,
    and verifies that the forward pass returns a tensor with the same
    shape as its input.
    """
    # test attention_type='1000'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    # Check the value projection as well, rather than asserting on
    # ``key_conv`` a second time.
    assert gen_attention_block.value_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0010'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0001'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test spatial_range >= 0
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test kv_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test fp16 with attention_type='1111'
    # Only runs when a CUDA device is available; skipped silently otherwise.
    if torch.cuda.is_available():
        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
        gen_attention_block = GeneralizedAttention(
            16,
            spatial_range=-1,
            num_heads=8,
            attention_type='1111',
            kv_stride=2)
        gen_attention_block.cuda().type(torch.half)
        out = gen_attention_block(imgs)
        assert out.shape == imgs.shape