# Copyright (c) OpenMMLab. All rights reserved. import pytest import torch import mmcv from mmcv.utils import TORCH_VERSION pytest.skip('this test not ready now', allow_module_level=True) skip_no_parrots = pytest.mark.skipif( TORCH_VERSION != 'parrots', reason='test case under parrots environment') class TestJit: def test_add_dict(self): @mmcv.jit def add_dict(oper): rets = oper['x'] + oper['y'] return {'result': rets} def add_dict_pyfunc(oper): rets = oper['x'] + oper['y'] return {'result': rets} a = torch.rand((3, 4)) b = torch.rand((3, 4)) oper = {'x': a, 'y': b} rets_t = add_dict(oper) rets = add_dict_pyfunc(oper) assert 'result' in rets assert (rets_t['result'] == rets['result']).all() def test_add_list(self): @mmcv.jit def add_list(oper, x, y): rets = {} for idx, pair in enumerate(oper): rets[f'k{idx}'] = pair['x'] + pair['y'] rets[f'k{len(oper)}'] = x + y return rets def add_list_pyfunc(oper, x, y): rets = {} for idx, pair in enumerate(oper): rets[f'k{idx}'] = pair['x'] + pair['y'] rets[f'k{len(oper)}'] = x + y return rets pair_num = 3 oper = [] for _ in range(pair_num): oper.append({'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))}) a = torch.rand((3, 4)) b = torch.rand((3, 4)) rets = add_list_pyfunc(oper, x=a, y=b) rets_t = add_list(oper, x=a, y=b) for idx in range(pair_num + 1): assert f'k{idx}' in rets_t assert (rets[f'k{idx}'] == rets_t[f'k{idx}']).all() @skip_no_parrots def test_jit_cache(self): @mmcv.jit def func(oper): if oper['const'] > 1: return oper['x'] * 2 + oper['y'] else: return oper['x'] * 2 - oper['y'] def pyfunc(oper): if oper['const'] > 1: return oper['x'] * 2 + oper['y'] else: return oper['x'] * 2 - oper['y'] assert len(func._cache._cache) == 0 oper = {'const': 2, 'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))} rets_plus = pyfunc(oper) rets_plus_t = func(oper) assert (rets_plus == rets_plus_t).all() assert len(func._cache._cache) == 1 oper['const'] = 0.5 rets_minus = pyfunc(oper) rets_minus_t = func(oper) assert (rets_minus == rets_minus_t).all() assert len(func._cache._cache) == 2 rets_a = (rets_minus_t + rets_plus_t) / 4 assert torch.allclose(oper['x'], rets_a) @skip_no_parrots def test_jit_shape(self): @mmcv.jit def func(a): return a + 1 assert len(func._cache._cache) == 0 a = torch.ones((3, 4)) r = func(a) assert r.shape == (3, 4) assert (r == 2).all() assert len(func._cache._cache) == 1 a = torch.ones((2, 3, 4)) r = func(a) assert r.shape == (2, 3, 4) assert (r == 2).all() assert len(func._cache._cache) == 2 @skip_no_parrots def test_jit_kwargs(self): @mmcv.jit def func(a, b): return torch.mean((a - b) * (a - b)) assert len(func._cache._cache) == 0 x = torch.rand((16, 32)) y = torch.rand((16, 32)) func(x, y) assert len(func._cache._cache) == 1 func(x, b=y) assert len(func._cache._cache) == 1 func(b=y, a=x) assert len(func._cache._cache) == 1 def test_jit_derivate(self): @mmcv.jit(derivate=True) def func(x, y): return (x + 2) * (y - 2) a = torch.rand((3, 4)) b = torch.rand((3, 4)) a.requires_grad = True c = func(a, b) assert c.requires_grad d = torch.empty_like(c) d.fill_(1.0) c.backward(d) assert torch.allclose(a.grad, (b - 2)) assert b.grad is None a.grad = None c = func(a, b) assert c.requires_grad d = torch.empty_like(c) d.fill_(2.7) c.backward(d) assert torch.allclose(a.grad, 2.7 * (b - 2)) assert b.grad is None def test_jit_optimize(self): @mmcv.jit(optimize=True) def func(a, b): return torch.mean((a - b) * (a - b)) def pyfunc(a, b): return torch.mean((a - b) * (a - b)) a = torch.rand((16, 32)) b = torch.rand((16, 32)) c = func(a, b) d = pyfunc(a, b) assert torch.allclose(c, d) @mmcv.skip_no_elena def test_jit_coderize(self): if not torch.cuda.is_available(): return @mmcv.jit(coderize=True) def func(a, b): return (a + b) * (a - b) def pyfunc(a, b): return (a + b) * (a - b) a = torch.rand((16, 32), device='cuda') b = torch.rand((16, 32), device='cuda') c = func(a, b) d = pyfunc(a, b) assert torch.allclose(c, d) def test_jit_value_dependent(self): @mmcv.jit def func(a, b): torch.nonzero(a) return torch.mean((a - b) * (a - b)) def pyfunc(a, b): torch.nonzero(a) return torch.mean((a - b) * (a - b)) a = torch.rand((16, 32)) b = torch.rand((16, 32)) c = func(a, b) d = pyfunc(a, b) assert torch.allclose(c, d) @skip_no_parrots def test_jit_check_input(self): def func(x): y = torch.rand_like(x) return x + y a = torch.ones((3, 4)) with pytest.raises(AssertionError): func = mmcv.jit(func, check_input=(a, )) @skip_no_parrots def test_jit_partial_shape(self): @mmcv.jit(full_shape=False) def func(a, b): return torch.mean((a - b) * (a - b)) def pyfunc(a, b): return torch.mean((a - b) * (a - b)) a = torch.rand((3, 4)) b = torch.rand((3, 4)) assert torch.allclose(func(a, b), pyfunc(a, b)) assert len(func._cache._cache) == 1 a = torch.rand((6, 5)) b = torch.rand((6, 5)) assert torch.allclose(func(a, b), pyfunc(a, b)) assert len(func._cache._cache) == 1 a = torch.rand((3, 4, 5)) b = torch.rand((3, 4, 5)) assert torch.allclose(func(a, b), pyfunc(a, b)) assert len(func._cache._cache) == 2 a = torch.rand((1, 9, 8)) b = torch.rand((1, 9, 8)) assert torch.allclose(func(a, b), pyfunc(a, b)) assert len(func._cache._cache) == 2 def test_instance_method(self): class T: def __init__(self, shape): self._c = torch.rand(shape) @mmcv.jit def test_method(self, x, y): return (x * self._c) + y shape = (16, 32) t = T(shape) a = torch.rand(shape) b = torch.rand(shape) res = (a * t._c) + b jit_res = t.test_method(a, b) assert torch.allclose(res, jit_res) t = T(shape) res = (a * t._c) + b jit_res = t.test_method(a, b) assert torch.allclose(res, jit_res)