# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.cnn.bricks import GeneralizedAttention


def test_context_block():
    """Check ``GeneralizedAttention`` construction and forward output shape.

    Despite its historical name, this test exercises
    ``mmcv.cnn.bricks.GeneralizedAttention`` (not a context block). For each
    configuration it:

    * asserts that the sub-modules / parameters expected for that
      configuration were created, and
    * runs a forward pass and asserts the output shape equals the input
      shape.

    The final half-precision case only runs when a CUDA device is available.
    """
    # test attention_type='1000'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    # Previously key_conv was asserted twice; the remaining projection to
    # check is the value projection.
    assert gen_attention_block.value_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    # NOTE(review): 8 is presumably half of a positional-embedding dim that
    # defaults to in_channels (16) -- confirm against GeneralizedAttention.
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0010'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0001'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test spatial_range >= 0
    # NOTE(review): this case uses 256 channels rather than 16; presumably
    # the local-constraint path only supports specific channel counts --
    # confirm against GeneralizedAttention before changing it.
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test kv_stride > 1
    # NOTE(review): if kv_stride=2 is the constructor default, this case does
    # not differ from a default-constructed block -- confirm.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test fp16 with all four attention_type flags set ('1111'); GPU only.
    if torch.cuda.is_available():
        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
        gen_attention_block = GeneralizedAttention(
            16,
            spatial_range=-1,
            num_heads=8,
            attention_type='1111',
            kv_stride=2)
        gen_attention_block.cuda().type(torch.half)
        out = gen_attention_block(imgs)
        assert out.shape == imgs.shape