# Path: AiOS/mmcv/tests/test_ops/test_fused_bias_leakyrelu.py
# Last commit: d7e58f0 ("update") by ttxskk; file size 1.38 kB.
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
# Use the parrots autograd backend when it can be imported; otherwise fall
# back to PyTorch, which additionally supplies the second-order
# ``gradgradcheck`` helper. ``_USING_PARROTS`` records which branch was taken.
try:
    from parrots.autograd import gradcheck
    _USING_PARROTS = True
except ImportError:
    from torch.autograd import gradcheck, gradgradcheck
    _USING_PARROTS = False
class TestFusedBiasLeakyReLU:
    """Gradient checks for ``mmcv.ops.FusedBiasLeakyReLU`` on CUDA.

    Every test is skipped when CUDA is unavailable. The second-order check
    additionally runs only on the PyTorch backend, because this module
    imports ``gradgradcheck`` only from ``torch.autograd``.
    """

    @classmethod
    def setup_class(cls):
        """Create the CUDA input tensors shared by all tests in the class.

        Does nothing when CUDA is unavailable; every test is skipped in
        that case, so the attributes are never read.
        """
        if not torch.cuda.is_available():
            return
        # Allocate directly on the GPU so the tensors are autograd leaves;
        # ``torch.randn(..., requires_grad=True).cuda()`` would first build
        # a throwaway CPU tensor and then return a non-leaf copy of it.
        cls.input_tensor = torch.randn((2, 2, 2, 2),
                                       device='cuda',
                                       requires_grad=True)
        # NOTE(review): not read by any test in this class; kept so that
        # external code relying on the attribute keeps working.
        cls.bias = torch.zeros(2, device='cuda', requires_grad=True)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_gradient(self):
        """First-order gradient check of ``FusedBiasLeakyReLU(2)``."""
        from mmcv.ops import FusedBiasLeakyReLU

        if _USING_PARROTS:
            # The parrots gradcheck is called with ``delta``/``pt_atol``
            # rather than torch's ``eps``/``atol`` keywords.
            gradcheck(
                FusedBiasLeakyReLU(2).cuda(),
                self.input_tensor,
                delta=1e-4,
                pt_atol=1e-3)
        else:
            gradcheck(
                FusedBiasLeakyReLU(2).cuda(),
                self.input_tensor,
                eps=1e-4,
                atol=1e-3)

    # Two separate markers so the reported skip reason matches the actual
    # cause (missing CUDA vs. the parrots backend).
    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    @pytest.mark.skipif(
        _USING_PARROTS,
        reason='gradgradcheck is not imported when using parrots')
    def test_gradgradient(self):
        """Second-order gradient check of ``FusedBiasLeakyReLU(2)``."""
        from mmcv.ops import FusedBiasLeakyReLU

        gradgradcheck(
            FusedBiasLeakyReLU(2).cuda(),
            self.input_tensor,
            eps=1e-4,
            atol=1e-3)