# Source: PolyAI-pheme / data / sampler.py (checkpoint commit 96ee597)
"""Original sampling logic of MQTTS.
Copyright PolyAI Limited.
"""
import math
import random
import numpy as np
from torch.utils import data
def StandardSampler(dataset, shuffle, distributed=False,
                    world_size=None, rank=None):
    """Build a plain (non-bucketed) sampler for ``dataset``.

    Returns a ``DistributedSampler`` when ``distributed`` is set; otherwise
    a ``RandomSampler`` if ``shuffle`` else a ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=world_size, rank=rank)
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def RandomBucketSampler(
        nbuckets, length, batch_size, drop_last, distributed=False,
        world_size=None, rank=None):
    """Dispatch to the single-process or distributed bucket sampler."""
    if not distributed:
        return SingleRandomBucketSampler(
            nbuckets, length, batch_size, drop_last)
    return DistributedRandomBucketSampler(
        nbuckets, length, batch_size, drop_last, world_size, rank)
class SingleRandomBucketSampler(data.Sampler):
    """Length-bucketed batch sampler for a single process.

    Indices are sorted by decreasing sample length and sliced into
    ``nbuckets`` buckets, so each bucket holds samples of similar length.
    ``__iter__`` yields shuffled batches whose cost — the longest sample in
    the batch times the batch's item count — stays within ``batch_size``.
    """

    def __init__(self, nbuckets, length, batch_size, drop_last):
        # length: per-sample lengths.  batch_size is a cost budget
        # (max_len * n_items), not a fixed item count.
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Longest-first sort, then equal slices; the remainder (when
        # len(length) is not divisible by nbuckets) becomes a tail bucket.
        indices = np.argsort([-x for x in length])
        split = len(indices) // nbuckets
        self.indices = [indices[i * split:(i + 1) * split]
                        for i in range(nbuckets)]
        if nbuckets * split < len(length):
            self.indices.append(indices[nbuckets * split:])

    def __iter__(self):
        # Shuffle bucket order and each bucket's contents, then flatten.
        random.shuffle(self.indices)
        for bucket in self.indices:
            random.shuffle(bucket)
        idxs = [i for bucket in self.indices for i in bucket]
        batches, batch, max_len = [], [], 0
        for idx in idxs:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                # Budget exceeded: flush everything before the current item.
                # Bug fix: when a single sample alone exceeds the budget,
                # batch[:-1] is empty and the original appended an empty
                # batch ([]), which breaks downstream DataLoader iteration.
                if batch[:-1]:
                    batches.append(batch[:-1])
                batch, max_len = [batch[-1]], self.length[idx]
        if batch and not self.drop_last:
            batches.append(batch)
        random.shuffle(batches)
        return iter(batches)
class DistributedRandomBucketSampler(data.Sampler):
    """Length-bucketed batch sampler with per-rank subsampling for DDP.

    Bucketing and batching mirror ``SingleRandomBucketSampler`` (indices
    sorted ascending here), but the shuffles are seeded from
    ``seed + epoch`` so every rank builds the identical batch list, which
    is then split into disjoint contiguous chunks, one per rank.
    Call ``set_epoch`` each epoch to reshuffle deterministically.
    """

    def __init__(self, nbuckets, length, batch_size,
                 drop_last, num_replicas, rank, seed=1234):
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        # Ascending length sort, sliced into nbuckets (+ remainder bucket).
        indices = np.argsort(length)
        split = len(indices) // nbuckets
        self.length = length
        self.batch_size = batch_size  # cost budget: max_len * n_items
        self.drop_last = drop_last
        self.indices = [indices[i * split:(i + 1) * split]
                        for i in range(nbuckets)]
        if nbuckets * split < len(length):
            self.indices.append(indices[nbuckets * split:])
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed

    def __iter__(self):
        # Deterministic shuffling: identical on every rank for a given epoch.
        random.Random(self.epoch + self.seed).shuffle(self.indices)
        for offset, bucket in enumerate(self.indices):
            random.Random(self.epoch + self.seed + offset * 5).shuffle(bucket)
        indices = [i for bucket in self.indices for i in bucket]
        # Batching under the cost budget (max length in batch * batch size).
        batches, batch, max_len = [], [], 0
        for idx in indices:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                # Bug fix: skip the flush when a single sample alone exceeds
                # the budget — the original appended an empty batch here.
                if batch[:-1]:
                    batches.append(batch[:-1])
                batch, max_len = [batch[-1]], self.length[idx]
        # Subsample: trim to a multiple of num_replicas (dropping up to
        # num_replicas trailing batches) so every rank yields the same
        # number of batches, then take this rank's contiguous chunk.
        num_samples = math.ceil(
            (len(batches) - self.num_replicas) / self.num_replicas)
        total_size = num_samples * self.num_replicas
        batches = batches[:total_size]
        batches = batches[self.rank*num_samples: (self.rank+1)*num_samples]
        assert len(batches) == num_samples
        # Stochastic shuffling (per-rank only; batch contents stay aligned).
        random.shuffle(batches)
        return iter(batches)

    def set_epoch(self, epoch):
        # Advance the deterministic-shuffle seed; call once per epoch.
        self.epoch = epoch