# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator

import torch
from mmengine.dataset import DefaultSampler

from mmpretrain.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class SequentialSampler(DefaultSampler):
"""Sequential sampler which supports different subsample policy. | |
Args: | |
dataset (Sized): The dataset. | |
round_up (bool): Whether to add extra samples to make the number of | |
samples evenly divisible by the world size. Defaults to True. | |
subsample_type (str): The method to subsample data on different rank. | |
Supported type: | |
- ``'default'``: Original torch behavior. Sample the examples one | |
by one for each GPU in terms. For instance, 8 examples on 2 GPUs, | |
GPU0: [0,2,4,8], GPU1: [1,3,5,7] | |
- ``'sequential'``: Subsample all examples to n chunk sequntially. | |
For instance, 8 examples on 2 GPUs, | |
GPU0: [0,1,2,3], GPU1: [4,5,6,7] | |
""" | |

    def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
        super().__init__(shuffle=False, **kwargs)
        if subsample_type not in ['default', 'sequential']:
            raise ValueError(f'Unsupported subsample type "{subsample_type}",'
                             ' please choose from ["default", "sequential"]')
        self.subsample_type = subsample_type

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # subsample according to the chosen policy
        if self.subsample_type == 'default':
            # interleaved split: rank r takes indices r, r + world_size, ...
            indices = indices[self.rank:self.total_size:self.world_size]
        elif self.subsample_type == 'sequential':
            # contiguous split: rank r takes its own consecutive chunk
            num_samples_per_rank = self.total_size // self.world_size
            indices = indices[self.rank *
                              num_samples_per_rank:(self.rank + 1) *
                              num_samples_per_rank]

        return iter(indices)
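

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# single, non-distributed process, where mmengine's DefaultSampler resolves
# rank=0 and world_size=1, so both policies simply yield all indices in order.
# ``_DummyDataset`` below is a hypothetical stand-in for a real Sized dataset.
# In an mmpretrain dataloader config, the registered sampler would typically
# be selected with
# ``sampler=dict(type='SequentialSampler', subsample_type='sequential')``.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _DummyDataset:
        """A tiny Sized object standing in for a real dataset."""

        def __len__(self) -> int:
            return 8

    for policy in ('default', 'sequential'):
        sampler = SequentialSampler(dataset=_DummyDataset(),
                                    subsample_type=policy)
        # Expected on a single process: [0, 1, 2, 3, 4, 5, 6, 7] for both.
        print(policy, list(sampler))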