"""Thin wrappers hiding the API differences between PyTorch and the Parrots
framework, so that callers can import one consistent set of names regardless
of which framework is installed."""

from functools import partial

import torch

TORCH_VERSION = torch.__version__
|
|
|
|
|
def is_rocm_pytorch() -> bool: |
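    """Check whether this is a ROCm (HIP) build of PyTorch.

    Returns False under the Parrots framework, where the check does not
    apply, and when ``ROCM_HOME`` cannot be imported.
    """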
|
    is_rocm = False
    if TORCH_VERSION != 'parrots':
        try:
            from torch.utils.cpp_extension import ROCM_HOME
            is_rocm = torch.version.hip is not None and ROCM_HOME is not None
        except ImportError:
            pass
    return is_rocm
|
|
|
|
|
def _get_cuda_home(): |
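    """Return the CUDA (or ROCm) installation root used when compiling
    extensions: Parrots' ``CUDA_HOME``, PyTorch's ``ROCM_HOME`` on ROCm
    builds, or PyTorch's ``CUDA_HOME`` otherwise."""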
|
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import CUDA_HOME
    elif is_rocm_pytorch():
        from torch.utils.cpp_extension import ROCM_HOME
        CUDA_HOME = ROCM_HOME
    else:
        from torch.utils.cpp_extension import CUDA_HOME
    return CUDA_HOME
|
|
|
|
|
def get_build_config(): |
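    """Return the build configuration string of the underlying framework.

    On PyTorch this is simply ``torch.__config__.show()``.
    """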
|
    if TORCH_VERSION == 'parrots':
        from parrots.config import get_build_info
        return get_build_info()
    else:
        return torch.__config__.show()
|
|
|
|
|
def _get_conv(): |
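    """Return the base conv classes ``(_ConvNd, _ConvTransposeMixin)`` of the
    underlying framework."""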
|
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
    else:
        from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
    return _ConvNd, _ConvTransposeMixin
|
|
|
|
|
def _get_dataloader(): |
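    """Return ``(DataLoader, PoolDataLoader)``.

    Plain PyTorch has no ``PoolDataLoader``, so ``DataLoader`` is returned
    for both names in that case.
    """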
|
    if TORCH_VERSION == 'parrots':
        from torch.utils.data import DataLoader, PoolDataLoader
    else:
        from torch.utils.data import DataLoader
        PoolDataLoader = DataLoader
    return DataLoader, PoolDataLoader
|
|
|
|
|
def _get_extension(): |
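    """Return the ``(BuildExtension, CppExtension, CUDAExtension)`` helpers.

    Under Parrots, ``CppExtension`` and ``CUDAExtension`` are derived from a
    single ``Extension`` factory via ``functools.partial``.
    """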
|
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import BuildExtension, Extension
        CppExtension = partial(Extension, cuda=False)
        CUDAExtension = partial(Extension, cuda=True)
    else:
        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                               CUDAExtension)
    return BuildExtension, CppExtension, CUDAExtension
|
|
|
|
|
def _get_pool(): |
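    """Return the base pooling classes of the underlying framework."""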
|
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                             _AdaptiveMaxPoolNd, _AvgPoolNd,
                                             _MaxPoolNd)
    else:
        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
|
|
|
|
|
def _get_norm(): |
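    """Return ``(_BatchNorm, _InstanceNorm, SyncBatchNorm_)``.

    Parrots exposes synchronized batch norm as ``SyncBatchNorm2d``; PyTorch
    exposes it as ``torch.nn.SyncBatchNorm``.
    """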
|
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
        SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
    else:
        from torch.nn.modules.instancenorm import _InstanceNorm
        from torch.nn.modules.batchnorm import _BatchNorm
        SyncBatchNorm_ = torch.nn.SyncBatchNorm
    return _BatchNorm, _InstanceNorm, SyncBatchNorm_
|
|
|
|
|
# Resolve the framework-specific classes once at import time so that they can
# be imported directly from this module.
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
|
|
|
|
class SyncBatchNorm(SyncBatchNorm_): |
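    """Synchronized batch norm that works under both PyTorch and Parrots.

    Under Parrots the input-dimension check is implemented here directly;
    under PyTorch it is delegated to the parent ``torch.nn.SyncBatchNorm``.
    """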
|
|
|
    def _check_input_dim(self, input):
        if TORCH_VERSION == 'parrots':
            if input.dim() < 2:
                raise ValueError(
                    f'expected at least 2D input (got {input.dim()}D input)')
        else:
            super()._check_input_dim(input)