AiOS/util/distribute_utils.py

import os
import os.path as osp
import pickle
import random
import shutil
import subprocess
import tempfile
import time

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def set_seed(seed):
    """Fix the random seed for Python, NumPy and PyTorch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.set_deterministic(True)


def time_synchronized():
    # CUDA kernels run asynchronously; synchronize before reading the clock
    # so the timestamp reflects completed GPU work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
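

# A minimal usage sketch for time_synchronized: wrap the timed region with two
# calls so queued CUDA kernels are included in the measurement. `model` and
# `inputs` below are hypothetical placeholders, not part of this module.
#
#   t0 = time_synchronized()
#   outputs = model(inputs)
#   elapsed = time_synchronized() - t0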


def setup_for_distributed(is_master):
    """Disable printing on non-master processes.

    Pass ``force=True`` to ``print`` to print from any process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
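

# A minimal usage sketch for setup_for_distributed, assuming it is called once
# right after the distributed backend is initialized:
#
#   setup_for_distributed(is_main_process())
#   print('only rank 0 sees this')
#   print('every rank sees this', force=True)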


def init_distributed_mode(port=None, master_port=29500):
    """Initialize a Slurm distributed training environment.

    If argument ``port`` is not specified, the master port is taken from the
    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not
    set either, the default port ``29500`` is used.

    Args:
        port (int, optional): Master port. Defaults to None.
        master_port (int): Fallback master port used when neither ``port``
            nor ``MASTER_PORT`` is given. Defaults to 29500.

    Returns:
        tuple: ``(distributed, gpu_idx)``, where ``distributed`` is always
        True and ``gpu_idx`` is the local GPU index of this process.
    """
    dist_backend = 'nccl'
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # resolve the first node in the Slurm node list as the master address
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT from the environment
    else:
        # 29500 is the torch.distributed default port
        os.environ['MASTER_PORT'] = str(master_port)
    # use MASTER_ADDR from the environment if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=dist_backend)
    distributed = True
    gpu_idx = proc_id % num_gpus
    return distributed, gpu_idx
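

# A minimal usage sketch for init_distributed_mode, assuming the script is
# launched through Slurm (e.g. `srun --ntasks-per-node=8 --gres=gpu:8
# python train.py`); the port value below is illustrative:
#
#   distributed, gpu_idx = init_distributed_mode(port=29510)
#   setup_for_distributed(is_main_process())
#   device = torch.device('cuda', gpu_idx)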


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_process_groups():
    """Create one process group per node.

    Assumes every node uses all of its visible GPUs, i.e. the world size is
    an exact multiple of the per-node GPU count.
    """
    world_size = int(os.environ['WORLD_SIZE'])
    ranks = list(range(world_size))
    num_gpus = torch.cuda.device_count()
    num_nodes = world_size // num_gpus
    if world_size % num_gpus != 0:
        raise NotImplementedError('Not implemented for node not fully used.')
    groups = []
    for node_idx in range(num_nodes):
        groups.append(ranks[node_idx * num_gpus:(node_idx + 1) * num_gpus])
    # new_group() must be called collectively by all processes,
    # including ranks that are not members of a given group
    process_groups = [torch.distributed.new_group(group) for group in groups]
    return process_groups


def get_group_idx():
    num_gpus = torch.cuda.device_count()
    proc_id = get_rank()
    group_idx = proc_id // num_gpus
    return group_idx
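

# A minimal usage sketch for per-node process groups, assuming global ranks
# are assigned node by node (as get_group_idx expects) and that
# get_process_groups is called collectively on every rank; `tensor` is a
# hypothetical CUDA tensor:
#
#   process_groups = get_process_groups()
#   node_group = process_groups[get_group_idx()]
#   dist.all_reduce(tensor, group=node_group)  # reduce within this node only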


def is_main_process():
    return get_rank() == 0


def cleanup():
    dist.destroy_process_group()


def collect_results(result_part, size, tmpdir=None):
    """Collect results from all ranks on rank 0 via a shared temp directory.

    Args:
        result_part (list): Results produced by the current rank.
        size (int): Total number of samples; used to drop padded duplicates.
        tmpdir (str, optional): Directory shared by all ranks. If None, a
            temporary directory is created on rank 0 and broadcast.

    Returns:
        list | None: The ordered results on rank 0, None on other ranks.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(bytearray(tmpdir.encode()),
                                  dtype=torch.uint8,
                                  device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from the tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # interleave the parts to restore the original sample order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove the tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
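

# A minimal usage sketch for collect_results, assuming each rank evaluated a
# shard of the dataset with a distributed sampler; `results`, `dataset` and
# `evaluate` are hypothetical placeholders:
#
#   ordered = collect_results(results, size=len(dataset))
#   if is_main_process():
#       evaluate(ordered)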


def all_gather(data):
    """Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: Any picklable object.

    Returns:
        list: List of data gathered from each rank.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a byte tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # obtain the tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device='cuda')
    size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving tensors from all ranks: pad the local tensor because torch
    # all_gather does not support gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(
            torch.empty((max_size, ), dtype=torch.uint8, device='cuda'))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size, ),
                              dtype=torch.uint8,
                              device='cuda')
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # strip the padding and unpickle each rank's payload
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
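

# A minimal usage sketch for all_gather, assuming it is called on every rank
# with an arbitrary picklable object (here a small dict); the field names are
# illustrative:
#
#   stats = {'rank': get_rank(), 'num_samples': 123}
#   all_stats = all_gather(stats)  # list with one dict per rank
#   if is_main_process():
#       total = sum(s['num_samples'] for s in all_stats)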