# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.

# %% auto 0
__all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in',
           'split_to_chunks', 'vad_dataset', 'AtomicTarWriter', 'readlines']

# %% ../nbs/D. Common dataset utilities.ipynb 1
import os
import random
import torch
import torchaudio
from pathlib import Path
import webdataset as wds
from contextlib import contextmanager

import torch.nn.functional as F

# %% ../nbs/D. Common dataset utilities.ipynb 2
def shard_glob(input):
    """Expand `input` into a list of shard file names.

    Accepts a brace pattern (e.g. `audio-{000000..000099}.tar.gz`), a directory
    (globbed with `*.tar.gz`), or a path with a glob in its last component."""
    if isinstance(input, str) and '{' in input:
        return wds.shardlists.expand_urls(input)
    if isinstance(input, (Path, str)):
        path = Path(input)
        if path.is_dir():
            glob = '*.tar.gz'
        else:
            glob = path.name
            path = path.parent
        input = path.glob(glob)
    else:
        raise ValueError("input should be a string or Path, optionally with a glob or brace pattern")
    return [str(x) for x in input]
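
# Illustrative usage of `shard_glob` (the paths below are made up for the example):
#
#   shard_glob('data/audio-{000000..000002}.tar.gz')
#   # -> ['data/audio-000000.tar.gz', 'data/audio-000001.tar.gz', 'data/audio-000002.tar.gz']
#   shard_glob('data/')                  # all *.tar.gz files in the directory
#   shard_glob('data/audio-*.tar.gz')    # explicit glob in the last path component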

# %% ../nbs/D. Common dataset utilities.ipynb 3
class join_datasets(torch.utils.data.IterableDataset):
    """Interleave several datasets, sampling each in proportion to its (optional)
    `weight` attribute. Iteration stops as soon as any dataset is exhausted.
    `__len__` requires every dataset to expose a `total_samples` attribute."""
    def __init__(self, datasets):
        self.datasets = datasets

    def __iter__(self):
        probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
        its = [iter(ds) for ds in self.datasets]
        while True:
            try:
                # pick a dataset at random, weighted by `probs`, and pull one sample from it
                yield next(its[torch.multinomial(probs, 1)])
            except StopIteration:
                return

    def __len__(self):
        return sum([ds.total_samples for ds in self.datasets])
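
# A minimal sketch of joining two datasets with a 3:1 sampling ratio (the shard
# paths and sample counts are hypothetical; any iterable datasets with optional
# `weight` and `total_samples` attributes would work):
#
#   ds_a = wds.WebDataset(shard_glob('data/audio-*.tar.gz'))
#   ds_a.weight, ds_a.total_samples = 3, 30000
#   ds_b = wds.WebDataset(shard_glob('other/audio-*.tar.gz'))
#   ds_b.weight, ds_b.total_samples = 1, 10000
#   for s in join_datasets([ds_a, ds_b]):
#       ...  # roughly 3 samples from ds_a for every sample from ds_b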

# %% ../nbs/D. Common dataset utilities.ipynb 5
def resampler(newsr = 24000, key = 'samples_24k'):
    """Return a pipeline stage that resamples `s['samples']` to `newsr` and stores
    the result under `key`. Caches the `Resample` transform between samples."""
    _last_sr = None
    tform = None

    def _resample(samples):
        nonlocal _last_sr, tform
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                # reuse the cached transform as long as the input sample rate stays the same
                if sr != _last_sr:
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    _last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s

    return _resample
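
# `resampler` expects samples with 'samples' and 'sample_rate' keys (as produced
# by `split_to_chunks` below), so it is composed after chunking; a sketch with an
# assumed shard path:
#
#   ds = vad_dataset(shard_glob('data/*-audio-*.tar')).compose(
#       resampler(16000, 'samples_16k'),
#   )
#   # every chunk now also carries a 'samples_16k' key with 16 kHz audio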

# %% ../nbs/D. Common dataset utilities.ipynb 6
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
    """Map a shard name to the name of a derived shard by replacing `-{base}-` with `-{kind}-`."""
    dir = Path(dir) if dir else Path(input).parent
    return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
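
# Example of the name mapping performed by `derived_name` (file name made up):
#
#   derived_name('data/librilight-audio-000123.tar', 'vad')
#   # -> 'data/librilight-vad-000123.tar.gz'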

# %% ../nbs/D. Common dataset utilities.ipynb 7
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
    """Return a function that maps a main shard URL to a `wds.WebDataset` over the
    corresponding derived shard (see `derived_name`)."""
    def deriver(url):
        url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
        return wds.WebDataset(
            wds.SimpleShardList([url])
        ).decode(*decoders)
    return deriver
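
# `derived_dataset` is meant to be passed to `merge_in` below; given the URL of
# a main shard it opens the matching derived shard, e.g. (path made up):
#
#   vad_for = derived_dataset('vad')
#   vad_ds = vad_for('data/librilight-audio-000123.tar')  # reads the -vad- shard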

# %% ../nbs/D. Common dataset utilities.ipynb 8
def merge_in(dataset_fun):
    """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
    that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
    
    It requires (and validates) that both datasets have the same ordering of keys so you have
    to use it before any sample shuffling. Shard shuffling is ok.
    """
    def merge_loop(main_samples):
        merged_samples = None
        cur_url = None
        i = None
        for s in main_samples:
            url = s['__url__']
            if url != cur_url:
                # this will open a new file when we get the first sample with a new __url__
                merged_samples = iter(dataset_fun(url))
                cur_url = url
            try:
                merge_s = next(merged_samples)
            except StopIteration:
                # if the original shard got repeated we won't observe a __url__ change
                # in this case restart the dataset from the beginning
                merged_samples = iter(dataset_fun(url))
                merge_s = next(merged_samples)
            assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
            news = {}
            news.update(merge_s)
            news.update(s)
            yield news
    return merge_loop
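
# A sketch of merging VAD results stored in derived shards into the audio
# pipeline (shard names assumed; this mirrors `vad_dataset` below):
#
#   ds = wds.WebDataset(shard_glob('data/*-audio-*.tar')).compose(
#       wds.decode(wds.torch_audio),
#       merge_in(derived_dataset('vad')),
#   )
#   # each sample now contains the keys from both the audio and the vad shards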

# %% ../nbs/D. Common dataset utilities.ipynb 9
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
    """Cut each audio sample into the (start, end) chunks listed in `s[ikey]`,
    optionally padding every chunk to `pad_to_seconds` seconds."""
    for s in stream:
        audio, sr = s['audio']
        imax = len(s[ikey]) - 1
        for i,(ts,te) in enumerate(s[ikey]):
            samples = audio[0,int(ts*sr):int(te*sr)]
            padding = lpad = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr - samples.shape[-1]
                # with random_shift the padding is split randomly between the left and right sides
                lpad = random.randint(0, padding) if random_shift else 0
                samples = F.pad(samples, (lpad, padding-lpad))
            subs = {"__key__": s['__key__'] + f"_{i:03d}",
                    "src_key": s['__key__'],
                    "__url__": s['__url__'],
                    "i": i, "imax": imax,
                    "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                    "lpad": lpad, "rpad": padding-lpad,
                    "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                    "samples": samples, "sample_rate": sr}
            for k in metakeys:
                subs[k] = s[k][i]
            yield subs

# %% ../nbs/D. Common dataset utilities.ipynb 10
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
    return wds.WebDataset(shards).compose(
        wds.decode(wds.torch_audio),
        merge_in(derived_dataset(kind)),
        wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
        wds.rename(audio="flac;mp3;wav;ogg"),
        lambda x: split_to_chunks(x, ikey=ikey),
    )
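
# Putting it together (shard path assumed for the example):
#
#   for chunk in vad_dataset(shard_glob('data/*-audio-*.tar')):
#       print(chunk['__key__'], chunk['samples'].shape, chunk['sample_rate'])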

# %% ../nbs/D. Common dataset utilities.ipynb 11
@contextmanager
def AtomicTarWriter(name, throwaway=False):
    """A `wds.TarWriter` that writes to a temporary file and renames it into place
    on success, so readers never see a partially written shard. With
    `throwaway=True` the temporary file is kept and never renamed."""
    tmp = name+".tmp"
    with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
        yield sink
    if not throwaway:
        os.rename(tmp, name)
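
# Typical usage (output name made up); the shard only appears under its final
# name after the `with` block exits successfully:
#
#   with AtomicTarWriter('data/librilight-vad-000123.tar.gz') as sink:
#       sink.write({'__key__': 'sample-000', 'vad.npy': chunks})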

# %% ../nbs/D. Common dataset utilities.ipynb 12
def readlines(fname):
    """Read a text file and return a list of its lines with trailing whitespace stripped."""
    with open(fname) as file:
        return [line.rstrip() for line in file]