|
|
|
import torch |
|
from torch.nn.parallel._functions import _get_stream |
|
|
|
|
|
def scatter(input, devices, streams=None):
    """Scatters tensor across multiple GPUs.

    Args:
        input: A ``torch.Tensor`` or an arbitrarily nested ``list`` of them.
        devices: Target device indices. ``[-1]`` selects the CPU-only path.
        streams: Optional per-device streams used for the copies; defaults
            to one ``None`` per device (i.e. the current stream).

    Returns:
        For a tensor: a contiguous copy on ``devices[0]`` (or, on the
        CPU-only path, the contiguous tensor with a new leading dimension of
        size 1). For a list: a list of the same length whose items are
        scattered recursively. Item ``k`` goes to
        ``devices[k // ceil(len(input) / len(devices))]``.

    Raises:
        Exception: if ``input`` is neither a list nor a tensor.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # ceil(len(input) / len(devices)) items per device, in order.
        per_device = -(-len(input) // len(devices))
        scattered = []
        for idx, item in enumerate(input):
            slot = idx // per_device
            scattered.append(scatter(item, [devices[slot]], [streams[slot]]))
        return scattered

    if not isinstance(input, torch.Tensor):
        raise Exception(f'Unknown type {type(input)}.')

    tensor = input.contiguous()
    # Empty tensors are copied without a dedicated stream.
    copy_stream = None if tensor.numel() <= 0 else streams[0]

    if devices == [-1]:
        # CPU-only path: no copy; only a leading dimension of size 1 is added.
        return tensor.unsqueeze(0)

    target = devices[0]
    with torch.cuda.device(target), torch.cuda.stream(copy_stream):
        return tensor.cuda(target, non_blocking=True)
|
|
|
|
|
def synchronize_stream(output, devices, streams):
    """Make the current CUDA stream wait for the copies issued by ``scatter``.

    Args:
        output: The (possibly nested) list of tensors, or a single tensor,
            returned by ``scatter(input, devices, streams)``.
        devices: The same device list that was passed to ``scatter``.
        streams: The same per-device stream list that was passed to
            ``scatter``.

    Raises:
        Exception: if a non-list, non-tensor item is encountered.
    """
    if isinstance(output, list):
        # Mirror scatter()'s placement exactly: ceil-sized contiguous chunks
        # per device. Floor division here used to skip trailing items (and
        # pair items with the wrong device/stream) whenever len(output) was
        # not a multiple of len(devices).
        chunk_size = (len(output) - 1) // len(devices) + 1
        for i, item in enumerate(output):
            slot = i // chunk_size
            synchronize_stream(item, [devices[slot]], [streams[slot]])
    elif isinstance(output, torch.Tensor):
        # Empty tensors were copied without a stream in scatter(); skip them.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                # Order the current stream after the copy stream.
                main_stream.wait_stream(streams[0])
                # Record that the tensor is used on main_stream, so the
                # caching allocator does not recycle its memory too early.
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
|
|
|
|
|
def get_input_device(input):
    """Return the device index of the first CUDA tensor found in ``input``.

    Args:
        input: A ``torch.Tensor`` or an arbitrarily nested ``list`` of them.

    Returns:
        ``input.get_device()`` for the first CUDA tensor met in depth-first
        order, or -1 if none is found. The search stops at the first match,
        so items after it are never inspected.

    Raises:
        Exception: if a visited item is neither a list nor a tensor.
    """
    if isinstance(input, list):
        # Both the generator and next() are lazy, so the walk short-circuits
        # on the first device that is not -1.
        candidates = (get_input_device(item) for item in input)
        return next((dev for dev in candidates if dev != -1), -1)
    if isinstance(input, torch.Tensor):
        if not input.is_cuda:
            return -1
        return input.get_device()
    raise Exception(f'Unknown type {type(input)}.')
|
|
|
|
|
class Scatter:
    """Scatter (nested lists of) tensors to target devices."""

    @staticmethod
    def _get_copy_stream(device):
        """Return PyTorch's background copy stream for device index ``device``.

        ``torch.nn.parallel._functions._get_stream`` historically took an
        integer device index. Newer PyTorch releases take a ``torch.device``
        instead, and passing an int there fails with ``AttributeError`` when
        it reads ``device.type``. Try the integer form first (so older
        releases behave exactly as before), then fall back to the
        ``torch.device`` form.
        """
        try:
            return _get_stream(device)
        except AttributeError:
            return _get_stream(torch.device('cuda', device))

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` to ``target_gpus``.

        Args:
            target_gpus: List of device indices; ``[-1]`` selects the
                CPU-only path.
            input: A tensor or (nested) list of tensors.

        Returns:
            ``tuple(...)`` of the value returned by ``scatter``.
        """
        input_device = get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            # No CUDA tensor in the input: run the host-to-device copies on
            # per-device background streams.
            streams = [
                Scatter._get_copy_stream(device) for device in target_gpus
            ]

        outputs = scatter(input, target_gpus, streams)

        # Make the current streams wait for those background copies.
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)
|
|