import warnings
from collections import abc
from contextlib import contextmanager
from functools import wraps

import torch
from mmengine.logging import MMLogger


def cast_tensor_type(inputs, src_type=None, dst_type=None):
"""Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``. |
|
|
|
    Args:
        inputs: Inputs that are to be cast.
        src_type (torch.dtype | torch.device): Source type.
        dst_type (torch.dtype | torch.device): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
|
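    Examples:
        >>> # A minimal illustrative sketch (not from the original source),
        >>> # assuming only that ``torch`` is importable: cast a dict of
        >>> # tensors to FP16.
        >>> data = {'feat': torch.zeros(2, 3, dtype=torch.float32)}
        >>> out = cast_tensor_type(data, dst_type=torch.half)
        >>> out['feat'].dtype
        torch.float16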
""" |
|
    assert dst_type is not None
    if isinstance(inputs, torch.Tensor):
        if isinstance(dst_type, torch.device):
            # convert the Tensor to dst_type (a device)
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'device') and \
                    (inputs.device == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
        else:
            # convert the Tensor to dst_type (a dtype)
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'dtype') and \
                    (inputs.dtype == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
            for item in inputs)
    else:
        return inputs


@contextmanager
def _ignore_torch_cuda_oom():
    """A context manager that ignores the CUDA OOM exception from PyTorch.
|
|
|
    Code is modified from
    <https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py>  # noqa: E501
|
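    Examples:
        >>> # Illustrative sketch (not from the original source): a CUDA OOM
        >>> # RuntimeError raised inside the block is swallowed, while any
        >>> # other exception still propagates.
        >>> with _ignore_torch_cuda_oom():
        >>>     output = torch.ones(2, 2) * 2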
""" |
|
    try:
        yield
    except RuntimeError as e:
        # only swallow CUDA OOM errors; re-raise everything else
        if 'CUDA out of memory. ' in str(e):
            pass
        else:
            raise
|
|
class AvoidOOM:
    """Try to convert inputs to FP16 or CPU if a PyTorch CUDA Out of Memory
    error is encountered. It takes the following steps:

    1. First retry after calling `torch.cuda.empty_cache()`.
    2. If that still fails, retry by converting inputs to FP16.
    3. If that still fails, retry by converting inputs to CPU.
       In this case, it expects the function to dispatch to a
       CPU implementation.

    Args:
        to_cpu (bool): Whether to convert inputs to CPU and retry there if
            a CUDA OOM error persists. This will slow down the code
            significantly. Defaults to True.
        test (bool): Skip the `_ignore_torch_cuda_oom` retries so that the
            FP16 and CPU fallback paths can be exercised with lightweight
            data; only used in unit tests. Defaults to False.

    Examples:
        >>> from mmdet.utils.memory import AvoidOOM
        >>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
        >>>     some_torch_function)(input1, input2)
        >>> # To use as a decorator
        >>> # from mmdet.utils import AvoidCUDAOOM
        >>> @AvoidCUDAOOM.retry_if_cuda_oom
        >>> def function(*args, **kwargs):
        >>>     return None

    Note:
        1. The output may be on CPU even if the inputs are on GPU. Processing
           on CPU will slow down the code significantly.
        2. When converting inputs to CPU, it will only look at each argument
           and check if it has `.device` and `.to` for conversion. Nested
           structures of tensors are not supported.
        3. Since the function might be called more than once, it has to be
           stateless.
    """
|
    def __init__(self, to_cpu=True, test=False):
        self.to_cpu = to_cpu
        self.test = test
|
    def retry_if_cuda_oom(self, func):
        """Make a function retry itself after encountering PyTorch's CUDA OOM
        error.

        The implementation logic is adapted from
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py

        Args:
            func: a stateless callable that takes tensor-like objects
                as arguments.

        Returns:
            func: a callable which retries `func` if OOM is encountered.
|
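        Examples:
            >>> # Illustrative sketch (not part of the original docstring):
            >>> # wrap ``torch.cdist`` so that a large distance computation
            >>> # retries with FP16/CPU inputs instead of crashing on OOM.
            >>> safe_cdist = AvoidOOM().retry_if_cuda_oom(torch.cdist)
            >>> dists = safe_cdist(torch.rand(1000, 3), torch.rand(1000, 3))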
""" |
|
|
|
        @wraps(func)
        def wrapped(*args, **kwargs):

            # raw function call
            if not self.test:
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

                # clear the cache and retry
                torch.cuda.empty_cache()
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)
|
            # get the dtype and device of the first tensor argument
            dtype, device = None, None
            values = args + tuple(kwargs.values())
            for value in values:
                if isinstance(value, torch.Tensor):
                    dtype = value.dtype
                    device = value.device
                    break
            if dtype is None or device is None:
                raise ValueError('There is no tensor in the inputs, '
                                 'cannot get dtype and device.')
|
            # convert the inputs to FP16
            fp16_args = cast_tensor_type(args, dst_type=torch.half)
            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            logger = MMLogger.get_current_instance()
            logger.warning(f'Attempting to copy inputs of {str(func)} '
                           'to FP16 due to CUDA OOM')
|
            # the output dtype is cast back to that of the first input tensor
            with _ignore_torch_cuda_oom():
                output = func(*fp16_args, **fp16_kwargs)
                output = cast_tensor_type(
                    output, src_type=torch.half, dst_type=dtype)
                if not self.test:
                    return output
            logger.warning('Using FP16 still meets CUDA OOM')
|
            # try on CPU; this will slow down the code significantly,
            # therefore print a notice
            if self.to_cpu:
                logger.warning(f'Attempting to copy inputs of {str(func)} '
                               'to CPU due to CUDA OOM')
                cpu_device = torch.empty(0).device
                cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # try to convert the outputs back to the original GPU device
                with _ignore_torch_cuda_oom():
                    logger.warning(f'Convert outputs to GPU (device={device})')
                    output = func(*cpu_args, **cpu_kwargs)
                    output = cast_tensor_type(
                        output, src_type=cpu_device, dst_type=device)
                    return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output needs to interact with '
                              'GPU data in subsequent operations')
                logger.warning('Cannot convert output to GPU due to '
                               'CUDA OOM, the output is on CPU now.')

                return func(*cpu_args, **cpu_kwargs)
            else:
                # may still trigger a CUDA OOM error
                return func(*args, **kwargs)

        return wrapped
|
|
|
|
|
|
|
AvoidCUDAOOM = AvoidOOM() |
|
|