File size: 2,482 Bytes
33d9042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb.

# %% auto 0
# Empty export list: `from <this module> import *` exposes no names, so the
# functions below have to be imported explicitly by name.
__all__ = []

# %% ../nbs/1B. Voice activity detection.ipynb 3
import os
import torch
import torchaudio

from pathlib import Path
from fastprogress import progress_bar
from fastcore.script import call_parse

import whisperx
import random
import numpy as np
import webdataset as wds

# %% ../nbs/1B. Voice activity detection.ipynb 5
# Some of the original (LibriLight) file names contain extra dots; webdataset
# does not handle those, so we rename the files (see below) before loading.
def fix_dots_in_names(name):
    """Replace every dot except the last one with an underscore.

    Intended as a webdataset ``rename_files`` hook, so that a member name such
    as ``dir/book.chapter.flac`` becomes ``dir/book_chapter.flac``.

    Args:
        name: the member file name (may include a directory prefix).

    Returns:
        The renamed file name. A name that contains no dot at all is returned
        unchanged; the previous ``rsplit`` unpacking raised ``ValueError`` on
        such names, which aborted iteration over the whole shard.
    """
    stem, sep, ext = name.rpartition('.')
    if not sep:
        # No dot anywhere: there is no extension to protect and nothing to fix.
        return name
    return ".".join((stem.replace('.', '_'), ext))

def load_dataset(url, decode=True, rename_files=None):
    """Open the WebDataset at ``url``.

    Args:
        url: shard URL/path forwarded to ``wds.WebDataset``.
        decode: if true, entries are decoded with ``wds.torch_audio``;
            if false, the raw (undecoded) dataset is returned.
        rename_files: optional callable forwarded to ``wds.WebDataset`` that
            rewrites member file names (e.g. ``fix_dots_in_names``).

    Returns:
        The (optionally decoding) webdataset pipeline.
    """
    dataset = wds.WebDataset(url, rename_files=rename_files)
    if decode:
        dataset = dataset.decode(wds.torch_audio)
    return dataset

# %% ../nbs/1B. Voice activity detection.ipynb 7
def extract_segments(vad_result, max_duration):
    """Convert a raw VAD result into a list of ``(start, end)`` intervals.

    The result is passed through ``whisperx.vad.Binarize`` (configured with
    ``max_duration``), and the start/end of every segment in the resulting
    timeline is collected.
    """
    binarizer = whisperx.vad.Binarize(max_duration=max_duration)
    timeline = binarizer(vad_result).get_timeline()
    return [(seg.start, seg.end) for seg in timeline]

def segment_audio(vad_model, audio, sr=16000, max_duration=30):
    """Run voice activity detection on ``audio`` and return speech segments.

    Args:
        vad_model: a VAD model such as the one returned by
            ``whisperx.vad.load_vad_model``; it is called with a
            ``{"waveform": ..., "sample_rate": ...}`` dict.
        audio: waveform handed to the model (presumably a torchaudio-style
            ``(channels, samples)`` tensor -- TODO confirm against the model).
        sr: sample rate of ``audio`` in Hz.
        max_duration: forwarded to ``extract_segments`` / ``Binarize``.
            Defaults to 30, the value that was previously hard-coded, so
            existing callers are unaffected.

    Returns:
        A list of ``(start, end)`` tuples as produced by ``extract_segments``.
    """
    vad_result = vad_model({"waveform": audio, "sample_rate": sr})
    return extract_segments(vad_result, max_duration)

# %% ../nbs/1B. Voice activity detection.ipynb 13
def flac_to_vad_name(input):
    """Derive the default VAD output file name from an input shard URL/path.

    Takes the last ``/``-separated component of ``input``. If ``input``
    contains ``-flac-``, every ``flac`` in that component is replaced with
    ``vad``; otherwise every ``raw`` is replaced with ``vad``. Finally ``.gz``
    is appended.

    Args:
        input: input shard URL or path; may also be a bare file name with no
            ``/`` in it.

    Returns:
        The output file name (without any directory component).
    """
    # [-1] instead of [1]: a bare local file name has no "/" and used to
    # raise IndexError here.
    basename = input.rsplit("/", 1)[-1]
    token = 'flac' if '-flac-' in input else 'raw'
    return basename.replace(token, 'vad') + ".gz"

@call_parse
def process_shard(
    input:str,           # input shard URL/path
    output:str=None,     # output shard URL/path
    fix_dots:bool=False, # fix dots in LibriLight filenames
):
    """Run VAD over every sample of the `input` shard and write the segments to `output`.

    Each output sample carries the input `__key__` and a `vad.npy` array of
    `(start, end)` rows from `segment_audio`. Samples with neither a `flac`
    nor a `wav` entry are skipped with a warning. The shard is first written to
    `output + ".tmp"` and renamed only after the writer closes, so an
    interrupted run never leaves a truncated file under the final name.
    """
    if output is None: output = flac_to_vad_name(input)

    rename = fix_dots_in_names if fix_dots else None
    ds = torch.utils.data.DataLoader(load_dataset(input, rename_files=rename), num_workers=2, batch_size=None)
    vad_model = whisperx.vad.load_vad_model('cuda')

    tmp = output+".tmp"
    # When `compress` is not given, webdataset's TarWriter enables gzip only if
    # the path it writes to ends in "gz". The temp path ends in ".tmp", so a
    # "*.tar.gz" output used to be written as a plain, uncompressed tar.
    # Decide from the final name instead.
    with wds.TarWriter(tmp, compress=output.endswith("gz")) as sink:
        for s in progress_bar(ds, total='noinfer'):
            audio, sr = s.get('flac', s.get('wav', (None, None)))
            if audio is None:
                print(f"warning: '{s['__key__']}' does not contain an audio file")
                continue
            segments = segment_audio(vad_model, audio, sr=sr)
            # NOTE(review): float16 keeps only 11 significant bits, so
            # timestamps (presumably seconds) above 2048 are quantized to
            # 2-unit steps and anything above 65504 overflows to inf. The
            # dtype is kept because downstream readers may rely on it.
            sink.write({
                "__key__": s['__key__'],
                "vad.npy": np.array(segments, dtype=np.float16)
            })
    os.rename(tmp, output)