# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

import mmcv
from mmcv.utils import TORCH_VERSION

# The whole module is skipped at collection time for now.
pytest.skip('this test not ready now', allow_module_level=True)

# Marker for tests that poke at parrots-only jit internals, such as the trace
# cache exposed via `func._cache._cache`.
skip_no_parrots = pytest.mark.skipif(
    TORCH_VERSION != 'parrots', reason='test case under parrots environment')
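

# Each test pairs a `mmcv.jit`-wrapped function with an identical plain Python
# function and asserts that both produce the same values.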
class TestJit:

    def test_add_dict(self):

        @mmcv.jit
        def add_dict(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        def add_dict_pyfunc(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        oper = {'x': a, 'y': b}

        rets_t = add_dict(oper)
        rets = add_dict_pyfunc(oper)
        assert 'result' in rets
        assert (rets_t['result'] == rets['result']).all()
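
    # A list of tensor dicts plus extra keyword tensors exercises argument
    # flattening through the jit interface.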
    def test_add_list(self):

        @mmcv.jit
        def add_list(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        def add_list_pyfunc(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        pair_num = 3
        oper = []
        for _ in range(pair_num):
            oper.append({'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))})

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        rets = add_list_pyfunc(oper, x=a, y=b)
        rets_t = add_list(oper, x=a, y=b)
        for idx in range(pair_num + 1):
            assert f'k{idx}' in rets_t
            assert (rets[f'k{idx}'] == rets_t[f'k{idx}']).all()
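
    # The trace cache is keyed (at least) by non-tensor values that steer
    # control flow, so crossing the `> 1` boundary should add a second entry.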
    @skip_no_parrots
    def test_jit_cache(self):

        @mmcv.jit
        def func(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        def pyfunc(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        assert len(func._cache._cache) == 0

        oper = {'const': 2, 'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))}
        rets_plus = pyfunc(oper)
        rets_plus_t = func(oper)
        assert (rets_plus == rets_plus_t).all()
        assert len(func._cache._cache) == 1

        oper['const'] = 0.5
        rets_minus = pyfunc(oper)
        rets_minus_t = func(oper)
        assert (rets_minus == rets_minus_t).all()
        assert len(func._cache._cache) == 2

        # (2x + y) + (2x - y) == 4x, so averaging both branches recovers x.
        rets_a = (rets_minus_t + rets_plus_t) / 4
        assert torch.allclose(oper['x'], rets_a)
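
    # Inputs of a different rank should trigger a re-trace and a new cache
    # entry.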
    @skip_no_parrots
    def test_jit_shape(self):

        @mmcv.jit
        def func(a):
            return a + 1

        assert len(func._cache._cache) == 0

        a = torch.ones((3, 4))
        r = func(a)
        assert r.shape == (3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 1

        a = torch.ones((2, 3, 4))
        r = func(a)
        assert r.shape == (2, 3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 2
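
    # Positional and keyword calling conventions should normalize to the same
    # cache key.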
    @skip_no_parrots
    def test_jit_kwargs(self):

        @mmcv.jit
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        assert len(func._cache._cache) == 0

        x = torch.rand((16, 32))
        y = torch.rand((16, 32))
        func(x, y)
        assert len(func._cache._cache) == 1
        func(x, b=y)
        assert len(func._cache._cache) == 1
        func(b=y, a=x)
        assert len(func._cache._cache) == 1
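
    # With `derivate=True` the traced graph stays differentiable: for
    # c = (a + 2) * (b - 2), dc/da = (b - 2), scaled by the upstream gradient.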
    def test_jit_derivate(self):

        @mmcv.jit(derivate=True)
        def func(x, y):
            return (x + 2) * (y - 2)

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        a.requires_grad = True

        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(1.0)
        c.backward(d)
        assert torch.allclose(a.grad, (b - 2))
        # `b` never required grad, so its gradient stays unset.
        assert b.grad is None

        a.grad = None
        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(2.7)
        c.backward(d)
        assert torch.allclose(a.grad, 2.7 * (b - 2))
        assert b.grad is None
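
    # `optimize=True` turns on graph optimization; the optimized result must
    # stay numerically equivalent to eager execution.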
    def test_jit_optimize(self):

        @mmcv.jit(optimize=True)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)
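
    # `coderize=True` is exercised on CUDA tensors only; the output must still
    # match the plain function.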
    def test_jit_coderize(self):
        if not torch.cuda.is_available():
            return

        @mmcv.jit(coderize=True)
        def func(a, b):
            return (a + b) * (a - b)

        def pyfunc(a, b):
            return (a + b) * (a - b)

        a = torch.rand((16, 32), device='cuda')
        b = torch.rand((16, 32), device='cuda')

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)
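
    # `torch.nonzero` produces a value-dependent output shape, a classic hard
    # case for tracing.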
    def test_jit_value_dependent(self):

        @mmcv.jit
        def func(a, b):
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)
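
    # `check_input` validates the traced function on the given sample input;
    # a non-deterministic body (`rand_like`) is expected to fail that check
    # with an AssertionError.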
    @skip_no_parrots
    def test_jit_check_input(self):

        def func(x):
            y = torch.rand_like(x)
            return x + y

        a = torch.ones((3, 4))
        with pytest.raises(AssertionError):
            func = mmcv.jit(func, check_input=(a, ))
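
    # With `full_shape=False`, traces appear to be reused across inputs that
    # merely share a rank: both 2-D shapes hit one cache entry, both 3-D
    # shapes a second.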
    @skip_no_parrots
    def test_jit_partial_shape(self):

        @mmcv.jit(full_shape=False)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        a = torch.rand((6, 5))
        b = torch.rand((6, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        a = torch.rand((3, 4, 5))
        b = torch.rand((3, 4, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2

        a = torch.rand((1, 9, 8))
        b = torch.rand((1, 9, 8))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2
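
    # A jitted bound method must pick up per-instance state: a fresh `T` holds
    # a different `_c`, and the cached trace must not bake in the old tensor.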
    def test_instance_method(self):

        class T:

            def __init__(self, shape):
                self._c = torch.rand(shape)

            @mmcv.jit
            def test_method(self, x, y):
                return (x * self._c) + y

        shape = (16, 32)
        t = T(shape)
        a = torch.rand(shape)
        b = torch.rand(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)

        t = T(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)