import chainer
from chainer.iterators import MultiprocessIterator
from chainer.iterators import SerialIterator

try:
    from chainer.iterators import ShuffleOrderSampler
except ImportError:
    # NOTE(review): older Chainer releases (the ones the legacy branch in
    # ``_reshuffle_order`` targets) presumably do not ship
    # ShuffleOrderSampler. Tolerate its absence so that this module can still
    # be imported there; the name is only used on the non-legacy branch.
    ShuffleOrderSampler = None

from chainer.training.extension import Extension
import numpy as np


def _chainer_major_version():
    """Return the major version number of the installed Chainer as an int.

    The whole leading dot-separated component is parsed (not only the first
    character), so multi-digit major versions are handled correctly.
    """
    return int(chainer.__version__.split(".")[0])


def _reshuffle_order(iterator):
    """Draw a new random sample order and store it on ``iterator._order``.

    For Chainer major version <= 4 the order is a plain numpy permutation.
    For newer versions a fresh ``ShuffleOrderSampler`` is installed as
    ``iterator.order_sampler`` and used to generate the order.

    :param iterator: The iterator to update (must expose ``dataset``
        supporting ``len()``)
    """
    n_samples = len(iterator.dataset)
    if _chainer_major_version() <= 4:
        iterator._order = np.random.permutation(n_samples)
    else:
        iterator.order_sampler = ShuffleOrderSampler()
        iterator._order = iterator.order_sampler(np.arange(n_samples), 0)


class ShufflingEnabler(Extension):
    """An extension enabling shuffling on a list of iterators"""

    def __init__(self, iterators):
        """Inits the ShufflingEnabler

        :param list[Iterator] iterators: The iterators to enable shuffling on
            (each must provide a ``start_shuffle()`` method)
        """
        # Becomes True once shuffling has been started, so that later calls
        # are no-ops.
        self.set = False
        self.iterators = iterators

    def __call__(self, trainer):
        """Starts shuffling on every registered iterator (first call only)

        :param trainer: The trainer invoking this extension (unused)
        """
        if self.set:
            return
        for iterator in self.iterators:
            iterator.start_shuffle()
        self.set = True


class ToggleableShufflingSerialIterator(SerialIterator):
    """A SerialIterator whose shuffling can be activated during training"""

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        """Init the Iterator

        :param dataset: The dataset to take batches from
            (must support ``len()``)
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (allow multiple epochs)
        :param bool shuffle: Whether to shuffle the batches
        """
        super(ToggleableShufflingSerialIterator, self).__init__(
            dataset, batch_size, repeat, shuffle
        )

    def start_shuffle(self):
        """Starts shuffling (or reshuffles) the batches"""
        self._shuffle = True
        _reshuffle_order(self)


class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
    """A MultiprocessIterator whose shuffling can be activated during training"""

    def __init__(
        self,
        dataset,
        batch_size,
        repeat=True,
        shuffle=True,
        n_processes=None,
        n_prefetch=1,
        shared_mem=None,
        maxtasksperchild=20,
    ):
        """Init the iterator

        All arguments are forwarded unchanged to ``MultiprocessIterator``.

        :param dataset: The dataset to take batches from
            (must support ``len()``)
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat batches or not
            (enables multiple epochs)
        :param bool shuffle: Whether to shuffle the order of the batches
        :param int n_processes: How many processes to use
        :param int n_prefetch: The number of batches to prefetch
        :param int shared_mem: How much memory to share between processes
        :param int maxtasksperchild: Maximum number of tasks per child
        """
        super(ToggleableShufflingMultiprocessIterator, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            shared_mem=shared_mem,
            maxtasksperchild=maxtasksperchild,
        )

    def start_shuffle(self):
        """Starts shuffling (or reshuffles) the batches"""
        self.shuffle = True
        _reshuffle_order(self)
        # Called after the order is replaced, as in the original code;
        # presumably this pushes the new order to the prefetch machinery.
        self._set_prefetch_state()