|
import chainer |
|
from chainer.iterators import MultiprocessIterator |
|
from chainer.iterators import SerialIterator |
|
from chainer.iterators import ShuffleOrderSampler |
|
from chainer.training.extension import Extension |
|
|
|
import numpy as np |
|
|
|
|
|
class ShufflingEnabler(Extension):
    """Trainer extension that switches shuffling on for a set of iterators.

    The first invocation calls ``start_shuffle()`` on every wrapped iterator;
    every later invocation is a no-op.
    """

    def __init__(self, iterators):
        """Create the extension.

        :param list[Iterator] iterators: Iterators exposing a
            ``start_shuffle()`` method, on which shuffling will be enabled
        """
        self.iterators = iterators
        # Becomes True once start_shuffle() has been called on all iterators.
        self.set = False

    def __call__(self, trainer):
        """Enable shuffling on the wrapped iterators (first call only).

        :param trainer: The trainer invoking this extension (not used)
        """
        if self.set:
            return
        for it in self.iterators:
            it.start_shuffle()
        self.set = True
|
|
|
|
|
class ToggleableShufflingSerialIterator(SerialIterator):
    """A SerialIterator whose shuffling can be switched on during training.

    Calling :meth:`start_shuffle` enables shuffling and replaces the current
    ordering with a fresh random permutation of the dataset indices.
    """

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        """Initialize the iterator.

        :param dataset: The dataset to take examples from (any object that
            supports ``len()`` and indexing)
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (allow multiple epochs)
        :param bool shuffle: Whether to shuffle the example order from the start
        """
        super(ToggleableShufflingSerialIterator, self).__init__(
            dataset, batch_size, repeat, shuffle
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the examples.

        Sets the shuffle flag and installs a new random permutation of all
        dataset indices as the current order.
        """
        self._shuffle = True
        # Parse the whole major component: indexing only the first character
        # of the version string would misread e.g. "10.0.0" as major 1.
        chainer_major = int(chainer.__version__.split(".")[0])
        if chainer_major <= 4:
            # Chainer <= 4: store a plain NumPy permutation as the order.
            self._order = np.random.permutation(len(self.dataset))
        else:
            # Chainer >= 5: sample the order through a ShuffleOrderSampler,
            # which is also kept on the iterator as ``order_sampler``.
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
|
|
|
|
|
class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
    """A MultiprocessIterator whose shuffling can be switched on during training.

    Calling :meth:`start_shuffle` enables shuffling, installs a fresh random
    permutation of the dataset indices and then calls
    ``_set_prefetch_state()``.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        repeat=True,
        shuffle=True,
        n_processes=None,
        n_prefetch=1,
        shared_mem=None,
        maxtasksperchild=20,
    ):
        """Initialize the iterator.

        :param dataset: The dataset to take examples from (any object that
            supports ``len()`` and indexing)
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (enables multiple epochs)
        :param bool shuffle: Whether to shuffle the example order from the start
        :param int n_processes: How many worker processes to use
        :param int n_prefetch: How many batches to prefetch
        :param int shared_mem: Shared memory size, forwarded unchanged to
            MultiprocessIterator
        :param int maxtasksperchild: Maximum number of tasks per child process
        """
        super(ToggleableShufflingMultiprocessIterator, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            shared_mem=shared_mem,
            maxtasksperchild=maxtasksperchild,
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the examples.

        Sets the shuffle flag, installs a new random permutation of all
        dataset indices as the current order and re-syncs the prefetch state.
        """
        self.shuffle = True
        # Parse the whole major component: indexing only the first character
        # of the version string would misread e.g. "10.0.0" as major 1.
        chainer_major = int(chainer.__version__.split(".")[0])
        if chainer_major <= 4:
            # Chainer <= 4: store a plain NumPy permutation as the order.
            self._order = np.random.permutation(len(self.dataset))
        else:
            # Chainer >= 5: sample the order through a ShuffleOrderSampler,
            # which is also kept on the iterator as ``order_sampler``.
            # NOTE(review): if this iterator was built with shuffle=False,
            # confirm the prefetch machinery uses this new sampler at epoch
            # boundaries rather than the one captured at construction.
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
        # Presumably propagates the new order to the prefetching machinery.
        self._set_prefetch_state()
|
|