tonic
Laion WhisperSpeech Demo
33d9042
raw
history blame
No virus
6.14 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
# %% auto 0
# Names exported by `from <module> import *`. `split_to_chunks` and
# `vad_dataset` are defined in this file but are not part of this list.
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
'readlines']
# %% ../nbs/D. Common dataset utilities.ipynb 1
import os
import torch
import torchaudio
from pathlib import Path
import webdataset as wds
from contextlib import contextmanager
import torch.nn.functional as F
# %% ../nbs/D. Common dataset utilities.ipynb 2
def shard_glob(input):
    """Expand a shard specification into a list of shard path/URL strings.

    Accepted forms of `input`:

    - a string or `Path` containing `{`: handed to
      `wds.shardlists.expand_urls` (brace notation) and its result returned,
    - a string or `Path` naming a directory: every `*.tar.gz` entry directly
      inside that directory,
    - any other string or `Path`: the last path component is used as a glob
      pattern inside its parent directory,
    - a list or tuple: returned as a new list with every item converted
      with `str`.

    Raises:
        TypeError: if `input` is none of the above.
    """
    if isinstance(input, (Path, str)):
        # Check the string form: `'{' in Path(...)` raises TypeError because
        # Path objects do not support the `in` operator.
        spec = str(input)
        if '{' in spec:
            return wds.shardlists.expand_urls(spec)
        path = Path(input)
        if path.is_dir():
            glob = '*.tar.gz'
        else:
            glob = path.name
            path = path.parent
        matches = path.glob(glob)
    elif isinstance(input, (list, tuple)):
        # An explicit shard list needs no expansion, only normalisation to str.
        matches = input
    else:
        raise TypeError("input should be either a list or a path with an optional glob specifier")
    return [str(x) for x in matches]
# %% ../nbs/D. Common dataset utilities.ipynb 3
class join_datasets(torch.utils.data.IterableDataset):
    """Iterable dataset that interleaves samples drawn at random from several datasets.

    For every sample one source is picked with `torch.multinomial`, weighted
    by the source's `weight` attribute (1 when the attribute is missing).
    Iteration stops as soon as the picked source has no more samples.
    `len()` is the sum of the sources' `total_samples` attributes.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        weights = torch.tensor(
            [getattr(source, 'weight', 1) for source in self.datasets],
            dtype=torch.float,
        )
        iterators = [iter(source) for source in self.datasets]
        try:
            while True:
                # a one-element integer tensor is usable as a list index
                picked = torch.multinomial(weights, 1)
                yield next(iterators[picked])
        except StopIteration:
            # the picked source ran dry: end the whole stream
            return

    def __len__(self):
        return sum(source.total_samples for source in self.datasets)
# %% ../nbs/D. Common dataset utilities.ipynb 5
def resampler(newsr = 24000, key = 'samples_24k'):
    """Build a stream stage that stores audio at sample rate `newsr` under `key`.

    The returned generator function consumes an iterable of sample dicts that
    carry `samples` and `sample_rate` entries, sets `s[key]` on each dict in
    place and yields the same dict. When the sample is already at `newsr`,
    `s[key]` is the very same object as `s['samples']`; otherwise it is the
    output of `torchaudio.transforms.Resample(sample_rate, newsr)`.
    """
    def _resample(samples):
        # Cache state lives inside the generator: assigning to variables of
        # the enclosing function without `nonlocal` would create fresh locals
        # and rebuild the transform for every single sample.
        last_sr = None
        tform = None
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                if sr != last_sr:
                    # source rate changed: build (and remember) a new transform
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s
    return _resample
# %% ../nbs/D. Common dataset utilities.ipynb 6
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
    """Return, as a string, the path of the `kind` file derived from `input`.

    The file name of `input` gets every `-{base}-` replaced by `-{kind}-`
    and `suffix` appended. The result is placed in `dir` when `dir` is
    truthy, otherwise in the directory of `input`.
    """
    if dir:
        target_dir = Path(dir)
    else:
        target_dir = Path(input).parent
    renamed = Path(input).name.replace(f"-{base}-", f"-{kind}-")
    return str(target_dir / (renamed + suffix))
# %% ../nbs/D. Common dataset utilities.ipynb 7
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
    """Return a function that opens the derived `kind` dataset for a shard URL.

    The returned function maps the URL through `derived_name` (using the
    `kind`, `base`, `suffix` and `dir` given here), wraps the resulting single
    URL in a `wds.SimpleShardList`, opens it as a `wds.WebDataset` and applies
    `.decode(*decoders)`.
    """
    def deriver(url):
        derived_url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
        shard_list = wds.SimpleShardList([derived_url])
        dataset = wds.WebDataset(shard_list)
        return dataset.decode(*decoders)

    return deriver
# %% ../nbs/D. Common dataset utilities.ipynb 8
def merge_in(dataset_fun):
    """Merge a dataset into the current one returning samples with the union of keys.

    Pass in a function that takes the URL of a sample and returns a dataset for
    it (called every time the URL changes). Both datasets must yield their
    samples in the same order, which is validated sample by sample through
    `__key__`, so you have to use it before any sample shuffling. Shard
    shuffling is ok.

    When an entry exists in both samples, the value from the main stream wins.

    Raises (while the returned stage is iterated):
        AssertionError: if the `__key__` values of two paired samples differ.
        RuntimeError: if the dataset returned for a URL yields no samples.
    """
    def merge_loop(main_samples):
        #print("new merge loop:", dataset_fun)
        merged_samples = None
        cur_url = None
        for s in main_samples:
            url = s['__url__']
            if url != cur_url:
                # this will open a new file when we get the first sample with a new __url__
                merged_samples = iter(dataset_fun(url))
                cur_url = url
            try:
                merge_s = next(merged_samples)
            except StopIteration:
                # if the original shard got repeated we won't observe a __url__ change
                # in this case restart the dataset from the beginning
                merged_samples = iter(dataset_fun(url))
                try:
                    merge_s = next(merged_samples)
                except StopIteration:
                    # A StopIteration escaping a generator body is converted
                    # into an uninformative RuntimeError; report the cause.
                    raise RuntimeError(f"merged dataset for {url} yielded no samples") from None
            if merge_s['__key__'] != s['__key__']:
                # Raised explicitly (not via `assert`) so the validation is
                # still performed when Python runs with -O.
                raise AssertionError(f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}")
            news = {}
            news.update(merge_s)
            news.update(s)  # main-stream values override merged ones
            yield news
    return merge_loop
# %% ../nbs/D. Common dataset utilities.ipynb 9
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
    """Yield one chunk dict per `(start, end)` entry found in `s[ikey]`.

    Each sample `s` of `stream` must provide `s['audio']` as an
    `(audio, sample_rate)` pair, `s[ikey]` as a sequence of `(start, end)`
    pairs (multiplied by the sample rate to get sample offsets), plus
    `__key__` and `__url__`. Chunks are cut from index 0 along the first axis
    of `audio`.

    Args:
        stream: iterable of sample dicts.
        ikey: name of the entry holding the `(start, end)` pairs.
        metakeys: names of per-sample sequences; item `i` of each one is
            copied into chunk `i` under the same name.
        pad_to_seconds: target chunk duration; chunks are padded with
            `F.pad` to `int(pad_to_seconds * sample_rate)` samples.
            `None` disables padding (`lpad` and `rpad` are then 0).
        random_shift: when true, a random part of the padding is put in
            front of the audio instead of all of it after.

    Yields:
        dicts with `__key__` (source key plus `_{i:03d}`), `src_key`,
        `__url__`, `i`, `imax`, `tstart`, `tend`, `total_seconds`, `lpad`,
        `rpad`, `lpad_s`, `rpad_s`, `samples`, `sample_rate` and every
        requested meta key.
    """
    # Function-scope import: `random` is not among the module-level imports.
    import random

    for s in stream:
        audio, sr = s['audio']
        imax = len(s[ikey]) - 1
        for i,(ts,te) in enumerate(s[ikey]):
            samples = audio[0,int(ts*sr):int(te*sr)]
            # Defaults used when padding is disabled, so the bookkeeping
            # entries below are always defined.
            lpad = padding = 0
            if pad_to_seconds is not None:
                # int() keeps the pad size integral for fractional durations
                padding = int(pad_to_seconds*sr)-samples.shape[-1]
                # `padding` is negative for a segment longer than the target;
                # clamp the upper bound so randint never gets an empty range.
                lpad = random.randint(0, max(padding, 0)) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            subs = {"__key__": s['__key__'] + f"_{i:03d}",
                    "src_key": s['__key__'],
                    "__url__": s['__url__'],
                    "i": i, "imax": imax,
                    "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                    "lpad": lpad, "rpad": padding-lpad,
                    "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                    "samples": samples, "sample_rate": sr}
            for k in metakeys:
                subs[k] = s[k][i]
            yield subs
# %% ../nbs/D. Common dataset utilities.ipynb 10
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
    """Open `shards` as a WebDataset and compose the chunking pipeline onto it.

    Stages, in order: decode with `wds.torch_audio`; merge in the dataset
    produced by `derived_dataset(kind)`; keep only samples containing a
    `wav`, `flac`, `mp3` or `ogg` entry; rename that entry to `audio`; split
    every sample with `split_to_chunks` using `ikey`.
    """
    def has_audio(sample):
        return any(fmt in sample for fmt in ('wav', 'flac', 'mp3', 'ogg'))

    def chunker(samples):
        return split_to_chunks(samples, ikey=ikey)

    dataset = wds.WebDataset(shards)
    stages = [
        wds.decode(wds.torch_audio),
        merge_in(derived_dataset(kind)),
        wds.select(has_audio),  # skip samples without audio
        wds.rename(audio="flac;mp3;wav;ogg"),
        chunker,
    ]
    return dataset.compose(*stages)
# %% ../nbs/D. Common dataset utilities.ipynb 11
@contextmanager
def AtomicTarWriter(name, throwaway=False):
    """Context manager yielding a `wds.TarWriter` that writes to `name + ".tmp"`.

    When the `with` body completes without raising, the temporary file is
    moved onto `name`, unless `throwaway` is true. If the body raises, the
    move is skipped, so `name` never receives a partially written archive.
    The `.tmp` file is not removed in either of those two cases.

    Args:
        name: final path of the archive (string or path-like object).
            Compression is requested from `wds.TarWriter` when it ends
            with `gz`.
        throwaway: when true, never move the temporary file onto `name`.
    """
    # accept pathlib.Path as well as str
    name = os.fspath(name)
    tmp = name + ".tmp"
    with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
        yield sink
    if not throwaway:
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename fails on Windows when `name` already exists.
        os.replace(tmp, name)
# %% ../nbs/D. Common dataset utilities.ipynb 12
def readlines(fname):
    """Return the lines of the text file `fname`, each stripped of trailing whitespace."""
    with open(fname) as handle:
        # iterating the handle yields one line at a time, newline included
        return list(map(str.rstrip, handle))