import os
import subprocess
import math
import multiprocessing as mp

import numpy as np

import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
from torch.nn import Module


class DistModule(Module):
    """Thin wrapper for distributed training: broadcasts the wrapped module's
    parameters from rank 0 at construction time and forwards all calls to it."""

    def __init__(self, module):
        super(DistModule, self).__init__()
        self.module = module
        broadcast_params(self.module)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def train(self, mode=True):
        super(DistModule, self).train(mode)
        self.module.train(mode)


def average_gradients(model):
    """ all-reduce (sum) gradients across ranks; scale the loss or gradients
    by 1 / world_size if a true average is required """
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data)


def broadcast_params(model):
    """ broadcast model parameters and buffers from rank 0 to all ranks """
    for p in model.state_dict().values():
        dist.broadcast(p, 0)


def dist_init(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # RANK is injected by the PyTorch launcher (torch.distributed.launch / torchrun)
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, port=10086, **kwargs):
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # use the first node of the allocation as the rendezvous address
    addr = subprocess.getoutput(
        'scontrol show hostname {} | head -n1'.format(node_list))
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


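# Typical wiring (a sketch, not part of this module's API): a training script
# calls dist_init once, wraps its model in DistModule, and calls
# average_gradients after each backward pass. Names such as `build_model`,
# `loss_fn`, `loader` and `optimizer` below are hypothetical placeholders.
#
#     dist_init('slurm', backend='nccl')
#     model = DistModule(build_model().cuda())
#     for batch, target in loader:
#         loss = loss_fn(model(batch), target)
#         optimizer.zero_grad()
#         loss.backward()
#         average_gradients(model)
#         optimizer.step()

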
def gather_tensors(input_array):
    """ all-gather a numpy array from every rank; returns a list of arrays,
    one per rank, each restored to its original shape """
    world_size = dist.get_world_size()

    # gather the per-rank shapes first so the padding size is known everywhere
    myshape = input_array.shape
    mycount = input_array.size
    shape_tensor = torch.Tensor(np.array(myshape)).cuda()
    all_shape = [torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)]
    dist.all_gather(all_shape, shape_tensor)

    all_shape = [x.cpu().numpy() for x in all_shape]
    all_count = [int(x.prod()) for x in all_shape]
    all_shape = [list(map(int, x)) for x in all_shape]
    max_count = max(all_count)

    # pad every rank's flattened data to the same length before all_gather
    output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)]
    padded_input_array = np.zeros(max_count)
    padded_input_array[:mycount] = input_array.reshape(-1)
    input_tensor = torch.Tensor(padded_input_array).cuda()
    dist.all_gather(output_tensors, input_tensor)

    # strip the padding and restore each rank's original shape
    padded_output = [x.cpu().numpy() for x in output_tensors]
    output = [x[:all_count[i]].reshape(all_shape[i])
              for i, x in enumerate(padded_output)]
    return output


def gather_tensors_batch(input_array, part_size=10):
    """ gather a large numpy array in chunks of `part_size` rows to bound
    peak GPU memory, then reassemble the per-rank results """
    rank = dist.get_rank()
    all_features = []
    part_num = int(math.ceil(input_array.shape[0] / float(part_size)))
    for i in range(part_num):
        part_feat = input_array[i * part_size:min((i + 1) * part_size, input_array.shape[0]), ...]
        assert part_feat.shape[0] > 0, "rank: {}, part features should not be empty".format(rank)
        print("rank: {}, gather part: {}/{}, length: {}".format(rank, i, part_num, len(part_feat)))
        gather_part_feat = gather_tensors(part_feat)
        all_features.append(gather_part_feat)
    print("rank: {}, gather done.".format(rank))
    # all_features[i][j] holds chunk i gathered from rank j; concatenate the
    # chunks for each rank, then concatenate across ranks
    all_features = np.concatenate(
        [np.concatenate([all_features[i][j] for i in range(part_num)], axis=0)
         for j in range(len(all_features[0]))],
        axis=0)
    return all_features


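# Example (sketch): collecting per-rank feature matrices into one array on
# every rank; `extract_features` and `local_feats` are hypothetical names for
# code that produces a (N_rank, D) numpy array on each rank.
#
#     local_feats = extract_features(model, loader)   # differs per rank
#     all_feats = gather_tensors_batch(local_feats, part_size=20)
#     # all_feats has shape (sum of N_rank over ranks, D) on every rank

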
def reduce_tensors(tensor):
    """ all-reduce (sum) a tensor across ranks and return the result """
    reduced_tensor = tensor.clone()
    dist.all_reduce(reduced_tensor)
    return reduced_tensor


class DistributedSequentialSampler(Sampler):
    """Deterministically splits a dataset into contiguous, equally sized
    shards, one per rank; the tail is padded by wrapping around to the first
    indices so every rank sees the same number of samples."""

    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size)
        sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
        self.beg = sub_num * self.rank
        self.end = self.beg + sub_num
        # pad with the leading indices so that sub_num * world_size indices exist
        self.padded_ind = list(range(len(self.dataset))) + \
            list(range(sub_num * self.world_size - len(self.dataset)))

    def __iter__(self):
        indices = [self.padded_ind[i] for i in range(self.beg, self.end)]
        return iter(indices)

    def __len__(self):
        return self.end - self.beg


class GivenIterationSampler(Sampler):
    """Samples exactly total_iter * batch_size indices and supports resuming
    from a given iteration via last_iter."""

    def __init__(self, dataset, total_iter, batch_size, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # skip the indices already consumed before a resume
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # fixed seed so the index sequence is reproducible across restarts
        np.random.seed(0)

        all_size = self.total_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # repeat the dataset indices until at least all_size are available
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]

        np.random.shuffle(indices)

        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.total_size


class DistributedGivenIterationSampler(Sampler):
    """Distributed variant of GivenIterationSampler: every rank builds the
    same shuffled index list (shared seed) and keeps its own contiguous
    slice of total_iter * batch_size indices."""

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
        if world_size is None:
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        assert rank < world_size
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size

        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # skip the indices already consumed before a resume
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # the seed is shared across ranks so every process shuffles identically
        # and the per-rank slices are disjoint
        np.random.seed(0)

        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # repeat the dataset indices until at least all_size are available
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]

        np.random.shuffle(indices)
        # take this rank's contiguous slice
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]

        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        return self.total_size
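

# Example (sketch): building a resumable distributed DataLoader. The names
# `dataset`, `cfg` and `start_iter` are hypothetical placeholders.
#
#     sampler = DistributedGivenIterationSampler(
#         dataset, total_iter=cfg.max_iter, batch_size=cfg.batch_size,
#         last_iter=start_iter - 1)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=cfg.batch_size, sampler=sampler,
#         num_workers=cfg.workers, pin_memory=True)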