import logging
import math
import random
from typing import Iterator, Optional, TypeVar

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

logger = logging.getLogger(__name__)


class DistributedSamplerChunkByNode(torch.utils.data.Sampler):
    """Distributed sampler that pins "chunked" datasets to individual nodes.

    ``all_datasets`` is split by the parallel ``chunk_or_not`` flags into:

    * *normal* datasets, whose examples are spread over every process
      (``num_replicas`` ways), and
    * *chunked* datasets, which are divided into ``node_number`` contiguous
      groups; each node only samples from its own group, spread over the
      ``process_num_per_node`` processes of that node.

    Index layout: the sampler yields integer indices into ``dataset``.  It
    treats ``[0, normal_dataset_size)`` as the normal examples and assigns
    the chunked datasets consecutive ranges after that, in the order they
    appear in ``all_datasets``.
    NOTE(review): this assumes ``dataset`` is laid out as "all normal
    datasets, then all chunked datasets" -- the sampler does not verify it.

    Every rank yields exactly ``num_samples`` indices per epoch.  If the
    strided share of the node-local range is too small, it is padded with
    indices drawn (with replacement) from that same range.

    Args:
        dataset: Object whose ``len()`` defines the total number of examples.
        all_datasets: The individual datasets (each must support ``len()``).
        chunk_or_not: One truthy/falsy flag per entry of ``all_datasets``;
            truthy marks the dataset as chunked.
        num_replicas: World size; defaults to ``dist.get_world_size()``.
        rank: Global rank; defaults to ``dist.get_rank()``.
        shuffle: Whether to shuffle indices (deterministically from ``seed``,
            the epoch and the rank).
        seed: Base random seed.
        drop_last: Same meaning as in ``torch.utils.data.DistributedSampler``
            when computing ``num_samples``.
        node_rank: Index of the node this process runs on.
        node_number: Number of nodes.
        process_num_per_node: Number of processes on each node.
        rank_within_local_node: Rank of this process inside its node.

    Raises:
        RuntimeError: If ``num_replicas``/``rank`` must be inferred but the
            distributed package is unavailable.
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.
    """

    def __init__(
        self,
        dataset,
        all_datasets,
        chunk_or_not,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        node_rank=0,
        node_number=1,
        process_num_per_node=1,
        rank_within_local_node=0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.node_number = node_number
        self.node_rank = node_rank
        self.chunk_or_not = chunk_or_not
        self.process_num_per_node = process_num_per_node
        self.rank_within_local_node = rank_within_local_node

        assert self.process_num_per_node * self.node_number == self.num_replicas, (
            "process_num_per_node * node_number must equal num_replicas"
        )

        # 1. Divide the datasets into two parts.
        normal_datasets = []
        chunked_datasets = []
        for dataset_i, chunk_i in zip(all_datasets, chunk_or_not):
            if chunk_i:
                chunked_datasets.append(dataset_i)
            else:
                normal_datasets.append(dataset_i)

        # 2. Size of the part that follows the conventional distributed
        #    sampler (shared by every process).
        self.normal_dataset_size = sum(len(d) for d in normal_datasets)

        # 3. Work out which index range [start, end) of the chunked part
        #    belongs to this node.
        self.current_node_start_range = -1
        self.current_node_end_range = -1
        assert len(chunked_datasets) >= self.node_number, (
            "need at least one chunked dataset per node"
        )
        chunk_size = len(chunked_datasets) // self.node_number
        first_chunk = self.node_rank * chunk_size
        last_chunk = (self.node_rank + 1) * chunk_size - 1
        # Chunked datasets whose position is >= node_number * chunk_size match
        # no node's [first_chunk, last_chunk] window and are never sampled.
        running_total = self.normal_dataset_size
        for index, chunked in enumerate(chunked_datasets):
            if index == first_chunk:
                self.current_node_start_range = running_total
            running_total += len(chunked)
            if index == last_chunk:
                self.current_node_end_range = running_total

        if self.current_node_end_range == -1:  # boundary
            self.current_node_end_range = running_total

        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then
        # there is no need to drop any data, since it will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible, so
            # each rank receives the same amount of data.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over this rank's ``num_samples`` indices.

        The result is the rank's strided share of the normal range followed
        by its share of the node-local chunked range (padded up to
        ``num_samples``).  When ``shuffle`` is true the combined list is
        reshuffled with a private RNG seeded from ``seed + epoch + 10 * rank``.
        """
        indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: distribute among all processes.
            process_num=self.num_replicas,
            rank=self.rank,
            generate_length=-1,
            valid_indices=list(range(self.normal_dataset_size)),
            shuffle=self.shuffle,
            prefix="Normal ",
        )

        addition_indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: very important arguments -- distribute among the
            # processes of the local node only.
            process_num=self.process_num_per_node,
            rank=self.rank_within_local_node,
            generate_length=self.num_samples - len(indices),
            valid_indices=list(
                range(self.current_node_start_range, self.current_node_end_range)
            ),
            shuffle=self.shuffle,
            prefix="Distribute ",
        )

        indices.extend(addition_indices)
        if self.shuffle:
            # A private Random instance gives the same permutation as seeding
            # the module-level RNG, without clobbering global random state.
            random.Random(self.seed + self.epoch + 10 * self.rank).shuffle(indices)
        assert len(indices) == self.num_samples, (
            "sampler produced {} indices, expected {}".format(
                len(indices), self.num_samples
            )
        )
        return iter(indices)

    def generate_indices_within_range_with_rank(
        self, seed, epoch, process_num, generate_length, valid_indices, rank=-1, shuffle=True, prefix=""
    ):
        """Pick this process' non-overlapping share of ``valid_indices``.

        Use scenario: we want to sample 2500 examples from 10000 examples
        without overlapping the examples taken by the other three processes.
        Modified from ``DistributedSampler``.

        The (optionally shuffled) indices are truncated to
        ``ceil((n - process_num) / process_num) * process_num`` entries and
        then strided by ``process_num`` starting at ``rank``.

        Args:
            seed: Base seed; the permutation is seeded with ``seed + epoch``.
            epoch: Current epoch number.
            process_num: Number of processes sharing ``valid_indices``.
            generate_length: Exact number of indices to return, or ``-1`` to
                return the strided share unchanged.  A longer share is
                truncated; a shorter one is padded with indices drawn with
                replacement from ``valid_indices``.
            valid_indices: Sequence of candidate indices.
            rank: Position of this process among ``process_num`` processes.
            shuffle: Whether to permute ``valid_indices`` before striding.
            prefix: Label prepended to the debug log lines.

        Returns:
            A list of indices taken from ``valid_indices``.

        Raises:
            ValueError: If padding is required but ``valid_indices`` is empty.
        """
        dataset_size = len(valid_indices)
        if shuffle:
            # Deterministically shuffle based on epoch and seed.
            g = torch.Generator()
            g.manual_seed(seed + epoch)
            order = torch.randperm(dataset_size, generator=g).tolist()  # type: ignore[arg-type]
        else:
            order = list(range(dataset_size))  # type: ignore[arg-type]

        indices = [valid_indices[i] for i in order]

        num_samples_normal = math.ceil((dataset_size - process_num) / process_num)  # type: ignore[arg-type]
        # Remove tail of data to make it evenly divisible.
        usable = num_samples_normal * process_num
        indices = indices[:usable]

        logger.debug(
            "%sGlobal Rank %s Local Rank %s generate_length %s valid_indices %s"
            " process_num %s indices_before_subsample %s %s",
            prefix, self.rank, rank, generate_length, len(valid_indices),
            process_num, len(indices), indices[:10],
        )

        # Subsample: every process_num-th entry, starting at this rank.
        indices = indices[rank:usable:process_num]

        logger.debug(
            "%sGlobal Rank %s Local Rank %s generate_length %s valid_indices %s"
            " process_num %s indices_after_subsample %s %s",
            prefix, self.rank, rank, generate_length, len(valid_indices),
            process_num, len(indices), indices[:10],
        )

        if generate_length != -1:
            if len(indices) > generate_length:
                indices = indices[:generate_length]
            else:
                shortfall = generate_length - len(indices)
                if shortfall > 0:
                    if dataset_size == 0:
                        raise ValueError("cannot pad from an empty index range")
                    # Seeded private RNG keeps the padding reproducible for a
                    # given (seed, epoch, global rank) and leaves the global
                    # RNGs untouched.
                    pad_rng = random.Random(seed + epoch + 10 * self.rank)
                    indices.extend(pad_rng.choices(valid_indices, k=shortfall))
        return indices

    def __len__(self) -> int:
        """Return the number of indices each rank yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch