AiOS / util /distribute_utils.py
ttxskk
update
d7e58f0
import mmcv
import os
import os.path as osp
import pickle
import shutil
import tempfile
import time
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
import random
import numpy as np
import subprocess
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.set_deterministic(True)
def time_synchronized():
torch.cuda.synchronize() if torch.cuda.is_available() else None
return time.time()
def setup_for_distributed(is_master):
"""This function disables printing when not in master process."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def init_distributed_mode(port=None, master_port=29500):
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
dist_backend = 'nccl'
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = str(master_port)
# use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=dist_backend)
distributed = True
gpu_idx = proc_id % num_gpus
return distributed, gpu_idx
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def get_process_groups():
world_size = int(os.environ['WORLD_SIZE'])
ranks = list(range(world_size))
num_gpus = torch.cuda.device_count()
num_nodes = world_size // num_gpus
if world_size % num_gpus != 0:
raise NotImplementedError('Not implemented for node not fully used.')
groups = []
for node_idx in range(num_nodes):
groups.append(ranks[node_idx * num_gpus:(node_idx + 1) * num_gpus])
process_groups = [torch.distributed.new_group(group) for group in groups]
return process_groups
def get_group_idx():
num_gpus = torch.cuda.device_count()
proc_id = get_rank()
group_idx = proc_id // num_gpus
return group_idx
def is_main_process():
return get_rank() == 0
def cleanup():
dist.destroy_process_group()
def collect_results(result_part, size, tmpdir=None):
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
tmpdir = tempfile.mkdtemp()
tmpdir = torch.tensor(bytearray(tmpdir.encode()),
dtype=torch.uint8,
device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl')
part_list.append(mmcv.load(part_file))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
return ordered_results
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data:
Any picklable object
Returns:
data_list(list):
List of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to('cuda')
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device='cuda')
size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(
torch.empty((max_size, ), dtype=torch.uint8, device='cuda'))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size, ),
dtype=torch.uint8,
device='cuda')
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list