Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import collections.abc | |
import warnings | |
from itertools import repeat | |
import torch | |
from mmengine.utils import digit_version | |
def is_tracing() -> bool: | |
"""Determine whether the model is called during the tracing of code with | |
``torch.jit.trace``.""" | |
if digit_version(torch.__version__) >= digit_version('1.6.0'): | |
on_trace = torch.jit.is_tracing() | |
# In PyTorch 1.6, torch.jit.is_tracing has a bug. | |
# Refers to https://github.com/pytorch/pytorch/issues/42448 | |
if isinstance(on_trace, bool): | |
return on_trace | |
else: | |
return torch._C._is_tracing() | |
else: | |
warnings.warn( | |
'torch.jit.is_tracing is only supported after v1.6.0. ' | |
'Therefore is_tracing returns False automatically. Please ' | |
'set on_trace manually if you are using trace.', UserWarning) | |
return False | |
# From PyTorch internals | |
def _ntuple(n): | |
"""A `to_tuple` function generator. | |
It returns a function, this function will repeat the input to a tuple of | |
length ``n`` if the input is not an Iterable object, otherwise, return the | |
input directly. | |
Args: | |
n (int): The number of the target length. | |
""" | |
def parse(x): | |
if isinstance(x, collections.abc.Iterable): | |
return x | |
return tuple(repeat(x, n)) | |
return parse | |
to_1tuple = _ntuple(1) | |
to_2tuple = _ntuple(2) | |
to_3tuple = _ntuple(3) | |
to_4tuple = _ntuple(4) | |
to_ntuple = _ntuple | |