# AiOS / mmcv / tests / test_ops / test_masked_conv2d.py
# ttxskk — update — d7e58f0
# (raw | history | blame | 488 Bytes)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
class TestMaskedConv2d:
    """Smoke test for ``mmcv.ops.MaskedConv2d`` (requires a CUDA device)."""

    def test_masked_conv2d(self):
        """Run one forward pass of ``MaskedConv2d`` on random CUDA tensors.

        Raises:
            unittest.SkipTest: if no CUDA device is available, so the test
                runner reports the case as skipped rather than as a pass
                that never exercised the op.
        """
        import unittest

        if not torch.cuda.is_available():
            # A bare ``return`` here would make the test "pass" on CPU-only
            # machines even though nothing was checked; skip explicitly.
            raise unittest.SkipTest('MaskedConv2d test requires CUDA')

        # Deferred import: only reached once a CUDA device is known to exist.
        from mmcv.ops import MaskedConv2d

        inputs = torch.randn(1, 3, 16, 16, requires_grad=True, device='cuda')
        # Standard-normal values, so roughly half of the 16x16 entries are
        # positive. NOTE(review): assumes the op treats positive mask entries
        # as the active positions — confirm against mmcv.ops.
        mask = torch.randn(1, 16, 16, requires_grad=True, device='cuda')
        conv = MaskedConv2d(3, 3, 3).cuda()
        output = conv(inputs, mask)

        assert isinstance(output, torch.Tensor)
        # Batch size (1) and out_channels (3) follow directly from the
        # inputs and the module constructed above.
        assert output.dim() == 4
        assert output.shape[:2] == (1, 3)