File size: 2,307 Bytes
45b4aa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
def scatter(inputs, target_gpus, dim=0):
"""Scatter inputs to target gpus.
The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_gpus != [-1]:
return OrigScatter.apply(target_gpus, None, dim, obj)
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_gpus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_gpus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for targets in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
|