# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import abc
from contextlib import contextmanager
from functools import wraps

import torch
from mmengine.logging import MMLogger

def cast_tensor_type(inputs, src_type=None, dst_type=None):
    """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype | torch.device): Source type.
        dst_type (torch.dtype | torch.device): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
    """
    assert dst_type is not None
    if isinstance(inputs, torch.Tensor):
        if isinstance(dst_type, torch.device):
            # convert Tensor to dst_device
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'device') and \
                    (inputs.device == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
        else:
            # convert Tensor to dst_dtype
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'dtype') and \
                    (inputs.dtype == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
    # we need to ensure that the type of each input to be cast matches
    # the argument `src_type`
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
            for item in inputs)
    # TODO: Currently not supported
    # elif isinstance(inputs, InstanceData):
    #     for key, value in inputs.items():
    #         inputs[key] = cast_tensor_type(
    #             value, src_type=src_type, dst_type=dst_type)
    #     return inputs
    else:
        return inputs
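

# A minimal usage sketch (not part of the original module) showing how
# ``cast_tensor_type`` recurses into nested containers; the sample data is
# hypothetical:
#
#   >>> data = {'feat': torch.ones(2), 'metas': [torch.zeros(2)]}
#   >>> out = cast_tensor_type(data, dst_type=torch.half)
#   >>> (out['feat'].dtype, out['metas'][0].dtype)
#   (torch.float16, torch.float16)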


@contextmanager
def _ignore_torch_cuda_oom():
    """A context manager that ignores the CUDA OOM exception from PyTorch.

    Code is modified from
    <https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py>  # noqa: E501
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the string may change?
        if 'CUDA out of memory. ' in str(e):
            pass
        else:
            raise
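
# A minimal sketch (illustrative only) of how this context manager is used:
# a CUDA OOM ``RuntimeError`` raised inside the block is swallowed, so
# execution falls through to whatever follows the ``with`` statement.
# ``some_cuda_heavy_op`` is a hypothetical callable:
#
#   >>> with _ignore_torch_cuda_oom():
#   ...     output = some_cuda_heavy_op()
#   >>> # reaching here means success, or an OOM that was silently ignored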


class AvoidOOM:
    """Try to convert inputs to FP16 and then to CPU if a PyTorch CUDA Out of
    Memory error is encountered. It performs the following steps:

    1. First retry after calling `torch.cuda.empty_cache()`.
    2. If that still fails, retry by converting inputs to FP16.
    3. If that still fails, try to convert inputs to CPU. In this case,
       it expects the function to dispatch to a CPU implementation.

    Args:
        to_cpu (bool): Whether to convert outputs to CPU if an OOM error is
            encountered. This will slow down the code significantly.
            Defaults to True.
        test (bool): Skip the `_ignore_torch_cuda_oom` early returns so that
            all conversion branches can be exercised with lightweight data;
            only used in unit tests. Defaults to False.

    Examples:
        >>> from mmdet.utils.memory import AvoidOOM
        >>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
        >>>     some_torch_function)(input1, input2)
        >>> # To use as a decorator
        >>> # from mmdet.utils import AvoidCUDAOOM
        >>> @AvoidCUDAOOM.retry_if_cuda_oom
        >>> def function(*args, **kwargs):
        >>>     return None

    Note:
        1. The output may be on CPU even if inputs are on GPU. Processing
           on CPU will slow down the code significantly.
        2. When converting inputs to CPU, it will only look at each argument
           and check if it has `.device` and `.to` for conversion. Nested
           structures of tensors are not supported.
        3. Since the function might be called more than once, it has to be
           stateless.
    """

    def __init__(self, to_cpu=True, test=False):
        self.to_cpu = to_cpu
        self.test = test

    def retry_if_cuda_oom(self, func):
        """Make a function retry itself after encountering PyTorch's CUDA OOM
        error.

        The implementation logic is adapted from
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py

        Args:
            func: a stateless callable that takes tensor-like objects
                as arguments.

        Returns:
            func: a callable which retries `func` if OOM is encountered.
        """  # noqa: W605

        @wraps(func)
        def wrapped(*args, **kwargs):
            # try the raw function first
            if not self.test:
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

                # clear cache and retry
                torch.cuda.empty_cache()
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

            # get the dtype and device of the first tensor in the inputs
            dtype, device = None, None
            values = args + tuple(kwargs.values())
            for value in values:
                if isinstance(value, torch.Tensor):
                    dtype = value.dtype
                    device = value.device
                    break
            if dtype is None or device is None:
                raise ValueError('There is no tensor in the inputs, '
                                 'cannot get dtype and device.')

            # Convert to FP16
            fp16_args = cast_tensor_type(args, dst_type=torch.half)
            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            logger = MMLogger.get_current_instance()
            logger.warning(f'Attempting to copy inputs of {str(func)} '
                           'to FP16 due to CUDA OOM')

            # the output dtype will be the same as the dtype of the
            # first input tensor
            with _ignore_torch_cuda_oom():
                output = func(*fp16_args, **fp16_kwargs)
                output = cast_tensor_type(
                    output, src_type=torch.half, dst_type=dtype)
                if not self.test:
                    return output
            logger.warning('Using FP16 still meets CUDA OOM')

            # Try on CPU. This will slow down the code significantly,
            # therefore print a notice.
            if self.to_cpu:
                logger.warning(f'Attempting to copy inputs of {str(func)} '
                               'to CPU due to CUDA OOM')
                cpu_device = torch.empty(0).device
                cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # run on CPU and try to convert outputs back to GPU
                with _ignore_torch_cuda_oom():
                    logger.warning(
                        f'Convert outputs to GPU (device={device})')
                    output = func(*cpu_args, **cpu_kwargs)
                    output = cast_tensor_type(
                        output, src_type=cpu_device, dst_type=device)
                    return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output needs to interact with '
                              'GPU data in subsequent operations')
                logger.warning('Cannot convert output to GPU due to '
                               'CUDA OOM, the output is on CPU now.')

                return func(*cpu_args, **cpu_kwargs)
            else:
                # may still get CUDA OOM error
                return func(*args, **kwargs)

        return wrapped


# To use AvoidOOM as a decorator
AvoidCUDAOOM = AvoidOOM()
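
# A minimal usage sketch (illustrative only): wrap a memory-hungry op so it
# is retried after ``torch.cuda.empty_cache()``, then with FP16 inputs, and
# finally with CPU inputs when CUDA OOM occurs. The tensors below are
# hypothetical:
#
#   >>> a = torch.randn(4096, 4096, device='cuda')
#   >>> b = torch.randn(4096, 4096, device='cuda')
#   >>> safe_mm = AvoidCUDAOOM.retry_if_cuda_oom(torch.mm)
#   >>> out = safe_mm(a, b)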