# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1
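

# Example (illustrative sketch): the helpers above fall back to single-process
# defaults when torch.distributed has never been initialized, so the same code
# path also works for single-GPU debugging:
#
#     rank()            # -> 0 if init_process_group() was not called
#     world_size()      # -> 1
#     is_distributed()  # -> False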


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that the number of params is the same across all
    # workers, and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    # print('params[0].device ', params[0].device)
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If the workers do not all have the same number of params, this
        # inequality holds for at least one of them.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        # src = int(rank())  # added code
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
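

# Example usage (sketch, assuming a hypothetical `model` that is an nn.Module
# built identically on every worker before training starts):
#
#     broadcast_tensors(model.parameters())  # copy rank-0 weights to all workers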


def sync_buffer(buffers, average=True):
    """
    Sync buffers across workers. If average is False, broadcast from rank 0
    instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
                )
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True
                )
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # Divide the summed buffer by the number of workers to get the mean.
            buffer.data /= world_size()
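

# Example usage (sketch): averaging running statistics (e.g. BatchNorm buffers)
# across workers, with `model` a hypothetical nn.Module:
#
#     sync_buffer(model.buffers(), average=True)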


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            )
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()
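

# Example usage (sketch of a hypothetical training step, following the docstring
# above; `model`, `criterion`, `optimizer`, `batch` and `target` are assumptions):
#
#     loss = criterion(model(batch), target)
#     loss.backward()
#     sync_grad(model.parameters())  # average gradients across workers
#     optimizer.step()
#     optimizer.zero_grad()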


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
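

# Example usage (sketch): averaging per-worker metrics weighted by the number of
# samples each worker processed; `local_loss` and `num_samples` are hypothetical:
#
#     metrics = average_metrics({"loss": local_loss}, count=num_samples)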