import os
import subprocess
import numpy as np
import multiprocessing as mp
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
from torch.nn import Module


class DistModule(Module):
    """Thin wrapper that broadcasts parameters from rank 0 at construction."""

    def __init__(self, module):
        super(DistModule, self).__init__()
        self.module = module
        broadcast_params(self.module)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        super(DistModule, self).train(mode)
        self.module.train(mode)


def average_gradients(model):
    """ average gradients (implemented as an all-reduce sum; divide by the
    world size afterwards if a true average is needed) """
    for param in model.parameters():
        if param.requires_grad:
            dist.all_reduce(param.grad.data)


def broadcast_params(model):
    """ broadcast model parameters """
    for p in model.state_dict().values():
        dist.broadcast(p, 0)


def dist_init(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, port=10086, **kwargs):
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        'scontrol show hostname {} | head -n1'.format(node_list))
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def gather_tensors(input_array):
    world_size = dist.get_world_size()
    ## gather shapes first
    myshape = input_array.shape
    mycount = input_array.size
    shape_tensor = torch.Tensor(np.array(myshape)).cuda()
    all_shape = [
        torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)
    ]
    dist.all_gather(all_shape, shape_tensor)
    ## compute largest shapes
    all_shape = [x.cpu().numpy() for x in all_shape]
    all_count = [int(x.prod()) for x in all_shape]
    all_shape = [list(map(int, x)) for x in all_shape]
    max_count = max(all_count)
    ## padding tensors and gather them
    output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)]
    padded_input_array = np.zeros(max_count)
    padded_input_array[:mycount] = input_array.reshape(-1)
    input_tensor = torch.Tensor(padded_input_array).cuda()
    dist.all_gather(output_tensors, input_tensor)
    ## unpadding gathered tensors
    padded_output = [x.cpu().numpy() for x in output_tensors]
    output = [
        x[:all_count[i]].reshape(all_shape[i])
        for i, x in enumerate(padded_output)
    ]
    return output
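
# Usage sketch for gather_tensors (illustrative only; it assumes dist_init()
# has already been called and every process owns a CUDA device). Each rank
# passes its own numpy array and gets back a list with one array per rank,
# with the original per-rank shapes restored even when ranks hold different
# numbers of rows. `extract_features` below is a hypothetical helper, not
# part of this module:
#
#   local_feats = extract_features(loader, model)   # numpy array, e.g. (N_rank, D)
#   all_feats = gather_tensors(local_feats)         # list of world_size arrays
#   merged = np.concatenate(all_feats, axis=0)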

def gather_tensors_batch(input_array, part_size=10):
    # gather the array part by part to avoid allocating one huge buffer at once
    rank = dist.get_rank()
    all_features = []
    part_num = int(math.ceil(input_array.shape[0] / float(part_size)))
    for i in range(part_num):
        part_feat = input_array[
            i * part_size:min((i + 1) * part_size, input_array.shape[0]), ...]
        assert part_feat.shape[0] > 0, \
            "rank: {}, length of part features should be > 0".format(rank)
        print("rank: {}, gather part: {}/{}, length: {}".format(
            rank, i, part_num, len(part_feat)))
        gather_part_feat = gather_tensors(part_feat)
        all_features.append(gather_part_feat)
    print("rank: {}, gather done.".format(rank))
    # first merge the parts belonging to each rank, then stack the ranks
    all_features = np.concatenate(
        [
            np.concatenate([all_features[i][j] for i in range(part_num)],
                           axis=0) for j in range(len(all_features[0]))
        ],
        axis=0)
    return all_features


def reduce_tensors(tensor):
    reduced_tensor = tensor.clone()
    dist.all_reduce(reduced_tensor)
    return reduced_tensor


class DistributedSequentialSampler(Sampler):
    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        assert len(self.dataset) >= self.world_size, '{} vs {}'.format(
            len(self.dataset), self.world_size)
        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        self.beg = sub_num * self.rank
        # self.end = min(self.beg + sub_num, len(self.dataset))
        self.end = self.beg + sub_num
        # pad with indices wrapped around from the beginning so that every
        # rank receives exactly sub_num samples
        self.padded_ind = list(range(len(self.dataset))) + list(
            range(sub_num * self.world_size - len(self.dataset)))

    def __iter__(self):
        indices = [self.padded_ind[i] for i in range(self.beg, self.end)]
        return iter(indices)

    def __len__(self):
        return self.end - self.beg


class GivenIterationSampler(Sampler):
    def __init__(self, dataset, total_iter, batch_size, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # shuffle the full index list with a fixed seed so that resumed runs
        # reproduce the same ordering
        np.random.seed(0)
        all_size = self.total_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]
        np.random.shuffle(indices)
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.total_size


class DistributedGivenIterationSampler(Sampler):
    def __init__(self, dataset, total_iter, batch_size, world_size=None,
                 rank=None, last_iter=-1):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError(
                "this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # each process shuffles the full index list with the same seed, then
        # picks its own contiguous piece according to its rank
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]
        np.random.shuffle(indices)
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]
        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        # note: we do not take last_iter into consideration here, since __len__
        # should only be used for displaying; the correct remaining size is
        # handled by the dataloader
        # return self.total_size - (self.last_iter + 1) * self.batch_size
        return self.total_size
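

# Minimal end-to-end sketch (an illustrative assumption, not part of the
# original module): wrap a toy model, all-reduce its gradients, and gather
# per-rank features. It is expected to be started by a launcher that sets
# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT, e.g.
# `torchrun --nproc_per_node=2 <this_file>.py`, on a machine with GPUs.
if __name__ == '__main__':
    import torch.nn as nn

    dist_init('pytorch', backend='nccl')
    model = DistModule(nn.Linear(8, 2).cuda())

    x = torch.randn(4, 8).cuda()
    model(x).sum().backward()
    average_gradients(model)   # gradients are summed across all ranks

    feats = gather_tensors(np.random.rand(3, 5).astype(np.float32))
    if dist.get_rank() == 0:
        print('gathered {} arrays with shapes {}'.format(
            len(feats), [f.shape for f in feats]))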