Spaces:
Starting
on
L40S
Starting
on
L40S
File size: 633 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import digit_version, is_jit_tracing
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
    """Check that ``is_jit_tracing`` tells eager calls apart from tracing.

    A small helper branches on ``is_jit_tracing()``: it hands back the
    tensor untouched when the flag is set and a plain Python list
    otherwise. The helper is exercised once directly and once through
    ``torch.jit.trace``; the test expects a list from the direct call and
    a tensor from the traced callable.
    """

    def foo(x):
        # Outside of tracing, convert the tensor to a Python list.
        if not is_jit_tracing():
            return x.tolist()
        # While tracing, keep the value as a tensor.
        return x

    sample = torch.rand(3)

    # Direct (eager) call: the list-returning branch is expected.
    eager_result = foo(sample)
    assert isinstance(eager_result, list)

    # Trace the helper with an example input, then run the traced
    # callable: the tensor-returning branch is expected to be recorded.
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    traced_result = traced_foo(sample)
    assert isinstance(traced_result, torch.Tensor)
|