# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.utils import digit_version, is_jit_tracing
# NOTE(review): the 1.6.0 threshold assumes ``is_jit_tracing`` depends on
# ``torch.jit.is_tracing``, which older PyTorch releases do not provide --
# confirm against the implementation in ``mmcv.utils``.
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
    """Check ``is_jit_tracing`` in eager mode and under ``torch.jit.trace``.

    The helper ``foo`` returns its input tensor unchanged when
    ``is_jit_tracing()`` is truthy and a plain Python list otherwise, so the
    type of its result reveals which branch was taken.
    """

    def foo(x):
        # Branch on the tracing state; the return type tells the caller
        # which branch ran.
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)

    # test without trace: the eager call must take the ``tolist`` branch
    assert isinstance(foo(x), list)

    # test with trace: the branch is chosen while tracing, so the traced
    # function must return a tensor
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    assert isinstance(traced_foo(x), torch.Tensor)