File size: 899 Bytes
18dd6ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os

from .parrots_wrapper import TORCH_VERSION
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
from parrots.jit import pat as jit
else:
def jit(func=None,
        check_input=None,
        full_shape=True,
        derivate=False,
        coderize=False,
        optimize=False):
    """No-op fallback used when ``parrots.jit.pat`` is not imported as ``jit``.

    Works both as a bare decorator (``@jit``) and as a decorator factory
    (``@jit(full_shape=False)``). Nothing is compiled here; the decorated
    function simply runs as ordinary Python.

    Args:
        func (callable, optional): Function being decorated. ``None`` when
            ``jit`` is called with options only.
        check_input: Accepted and ignored by this fallback.
        full_shape (bool): Accepted and ignored by this fallback.
        derivate (bool): Accepted and ignored by this fallback.
        coderize (bool): Accepted and ignored by this fallback.
        optimize (bool): Accepted and ignored by this fallback.

    Returns:
        callable: ``func`` itself when it is given; otherwise a decorator
        that wraps its target in a pass-through function.
    """

    def wrapper(func):

        # Keep the target's __name__/__doc__/signature visible so that
        # tracebacks and introspection do not report ``wrapper_inner``.
        @functools.wraps(func)
        def wrapper_inner(*args, **kargs):
            return func(*args, **kargs)

        return wrapper_inner

    if func is None:
        return wrapper
    return func
if TORCH_VERSION == 'parrots':
from parrots.utils.tester import skip_no_elena
else:
def skip_no_elena(func):
    """Pass-through fallback for ``skip_no_elena``.

    Defined when ``TORCH_VERSION`` is not ``'parrots'``, in which case the
    decorator of the same name from ``parrots.utils.tester`` is not
    imported. This version never skips anything: it only forwards the call.

    Args:
        func (callable): Function to decorate.

    Returns:
        callable: A wrapper that calls ``func`` with the same positional and
        keyword arguments and returns its result.
    """

    # Copy func's metadata onto the wrapper so the decorated callable keeps
    # its own name, docstring and (through ``__wrapped__``) its signature.
    @functools.wraps(func)
    def wrapper(*args, **kargs):
        return func(*args, **kargs)

    return wrapper
|