# File viewer header: WhisperSpeech / whisperspeech / wh_transcribe.py
# tonic - Laion WhisperSpeech Demo - revision 33d9042 - 5.25 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.
# %% auto 0
__all__ = []
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
import os
import io
import time
import torch
import torchaudio
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import numpy as np
import random
import whisper
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from fastcore.script import *
from . import vad
import webdataset as wds
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
# let's make it a bit more conservative
# with full 30 second chunks it sometimes misses a small part of the transcript
def random_cutter(dur):
    """Randomized `should_cut` predicate intended for `chunk_merger`.

    A coin flip picks the duration limit. Half of the time the fixed 28 second
    limit is used. Otherwise the limit is 28 seconds scaled by a random factor in
    [0.05, 1.0), presumably to get a mix of shorter chunks as well.

    Returns True when `dur` exceeds the chosen limit.
    """
    max_len = 28
    use_fixed_limit = random.random() >= 0.5
    if use_fixed_limit:
        return dur > max_len
    scale = 0.05 + 0.95 * random.random()
    return dur > max_len * scale
def chunk_merger(segments, should_cut=lambda x: x > 28):
    """Greedily merge consecutive `(start, end)` segments into longer chunks.

    Segments keep being absorbed into the current chunk until `should_cut`
    reports that stretching the chunk (from its start) to the next segment's end
    would be too long. The chunk is then closed and a new one begins at that
    segment. A chunk is only closed when it is non-empty, so a single segment
    that is already too long still becomes a chunk of its own.

    Args:
        segments: sequence of `(start, end)` pairs, e.g. VAD output; assumed
            sorted by time (TODO confirm against the VAD producer).
        should_cut: predicate on a candidate chunk duration. It is called
            exactly once per segment, so randomized predicates such as
            `random_cutter` draw a fresh limit every time.

    Returns:
        `segments` itself when it is empty, otherwise a list of
        `(start, end)` tuples.
    """
    # `len()` rather than truthiness: `segments` may be a numpy array
    if len(segments) == 0:
        return segments
    merged = []
    start, end = segments[0][0], 0
    for seg_start, seg_end in segments:
        # evaluate the predicate first and unconditionally (matters for random predicates)
        over_limit = should_cut(seg_end - start)
        if over_limit and end - start > 0:
            merged.append((start, end))
            start = seg_start
        end = seg_end
    merged.append((start, end))
    return merged
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
def merge_in(*datasets):
    """Merge multiple datasets into the current one returning samples with the union of keys.

    It requires (and validates) all datasets to have the same ordering of keys so you have
    to use it before any sample shuffling. Shard shuffling is ok.

    Returns a pipeline stage: a generator function that takes the main sample
    stream and yields one merged dict per sample. The main-stream sample is
    applied first, so values from later datasets overwrite earlier ones on key
    collisions. Streams are combined with `zip`, so iteration stops as soon as
    any of them is exhausted.

    Raises:
        ValueError: when a sample's `__key__` differs from the key of the
            corresponding main-stream sample (i.e. the datasets are misaligned).
    """
    def merge_loop(main_samples):
        for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
            key = samples[0]['__key__']
            news = {}
            for s in samples:
                # explicit check instead of `assert` so validation still runs under `python -O`
                if s['__key__'] != key:
                    raise ValueError(
                        f"merge_in: sample key mismatch, expected {key!r} but got {s['__key__']!r} "
                        "(all datasets must share the same sample order)")
                news.update(s)
            yield news
    return merge_loop
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
import copy
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
# a workaround for https://github.com/webdataset/webdataset/issues/297
# should be possible to use ds.compose here
def wds_compose(ds, *args):
    """Return a copy of the pipeline `ds` with `args` appended as extra stages.

    The copy is shallow, but it receives its own copy of the `pipeline` stage
    list before the new stages are added. `ds`'s stage list therefore stays as
    it was, assuming `append` only extends `pipeline`.
    """
    clone = copy.copy(ds)
    # detach the stage list so appending to the clone does not mutate `ds`
    clone.pipeline = copy.copy(clone.pipeline)
    for stage in args:
        clone.append(stage)
    return clone
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
    """Split every audio sample into one output sample per VAD segment.

    The audio is read from the `'flac'` entry, falling back to `'wav'`, as an
    `(audio, sample_rate)` pair. It is cut at the `(start, end)` times in
    seconds listed in `'vad.npy'`, and only channel 0 of `audio` is used
    (`audio[0, ...]`). Samples without audio are skipped with a printed warning.

    Args:
        stream: iterable of sample dicts with `'__key__'`, `'__url__'`,
            `'vad.npy'` and a `'flac'` or `'wav'` entry.
        pad_to_seconds: if not None, each chunk is made exactly
            `pad_to_seconds * sr` samples long with `F.pad`. Shorter chunks are
            zero-padded. Longer chunks get a negative right pad, which `F.pad`
            treats as cropping. If None, chunks are yielded unpadded.
        random_shift: if True, a chunk that fits the window is placed at a
            random offset inside it instead of at the start.

    Yields:
        One dict per chunk, keyed `<key>_<i:03d>`, holding the chunk
        `samples`, its position (`i`, `imax`, `tstart`, `tend`), the padding
        applied (`lpad`/`rpad` in samples, `lpad_s`/`rpad_s` in seconds;
        all 0 when unpadded) and the audio metadata.
    """
    for s in stream:
        audio, sr = s.get('flac', s.get('wav', (None, None)))
        if audio is None:
            print(f"warning: '{s['__key__']}' does not contain an audio file")
            continue
        segments = s['vad.npy']
        imax = len(segments) - 1
        for i, (ts, te) in enumerate(segments):
            samples = audio[0, int(ts*sr):int(te*sr)]
            # defaults for the unpadded case; previously these names were left
            # unbound when pad_to_seconds was None and the yield below crashed
            lpad = padding = 0
            if pad_to_seconds is not None:
                padding = pad_to_seconds*sr - samples.shape[-1]
                # a chunk longer than the window (negative padding) cannot be shifted:
                # random.randint(0, <negative>) raises, so it is cropped at the end instead.
                # `>= 0` keeps the RNG draw for padding == 0 as before.
                if random_shift and padding >= 0:
                    lpad = random.randint(0, padding)
                samples = F.pad(samples, (lpad, padding-lpad))
            yield {"__key__": s['__key__'] + f"_{i:03d}",
                   "__url__": s['__url__'],
                   "i": i, "imax": imax,
                   "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
                   "lpad": lpad, "rpad": padding-lpad,
                   "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
                   "samples": samples, "sample_rate": sr}
# %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
def flac_to_txt_name(input, model_size):
    """Derive the transcript shard file name for an audio shard.

    Only the last `/`-separated component of `input` is kept, so the result
    is a bare file name for URLs, nested paths and plain names alike. Every
    `'flac'` in that name becomes `'<model_size>-txt'`, and `.gz` is
    appended.
    """
    # [-1] instead of [1]: a bare name with no "/" (e.g. "shard-flac-000.tar")
    # has only one element after rsplit and used to raise IndexError
    return input.rsplit("/", 1)[-1].replace('flac', f'{model_size}-txt') + ".gz"
@call_parse
def process_shard(
    input:str, # input shard URL/path
    output:str=None, # output shard URL/path
    bs:int=None, # batch size (16 uses around 11GB of VRAM)
    n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
    whisper_model:str="base.en" # Whisper model size
):
    """Transcribe every VAD chunk of one audio shard with Whisper and write the texts to a tar shard.

    Pipeline:
    1. Load the audio samples of `input` (`vad.load_dataset`) and join them with the
       matching VAD shard (`vad.flac_to_vad_name(input)`).
    2. Merge the VAD segments into chunks with the default `chunk_merger` limit
       (about 28 s; a single longer segment is kept whole).
    3. Cut the chunks out and pad them to 30 s (`split_to_chunks`), then batch them.
    4. Run Whisper's encoder and decoder with English decoding.
    5. Write one `{"__key__", "txt"}` record per chunk.

    Records are written to `<output>.tmp`. The file is renamed to `output` only
    after the tar writer has been closed, so an interrupted run leaves just the
    `.tmp` file behind.

    NOTE(review): `call_parse` (fastcore) builds the CLI from this signature; the
    trailing parameter comments are presumably used as argparse help text, so keep them.
    """
    # default output name is the basename of `input`, i.e. it is written to the current directory
    if output is None: output = flac_to_txt_name(input, whisper_model)
    if bs is None: bs = 16
    # n_samples is only used as the progress_bar `total` below, counted in *batches*.
    # 'noinfer' presumably tells fastprogress not to call len() on the DataLoader.
    # NOTE(review): limiting presumably relies on fastprogress stopping at `total`
    # (TODO confirm); also n_samples < bs gives a total of 0.
    if n_samples is None: n_samples = 'noinfer'
    else: n_samples = n_samples // bs

    ds = wds_compose(vad.load_dataset(input),
        # attach the per-sample VAD results; merge_in checks that __key__ orders match
        merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
        wds.map_dict(**{"vad.npy":chunk_merger}),
        split_to_chunks,
        # keep only what the loop below needs
        wds.to_tuple('__key__', 'samples'),
        wds.batched(bs),
    )

    # batch_size=None disables DataLoader's own batching; batches come from wds.batched above
    dl = DataLoader(ds, num_workers=2, batch_size=None)

    whmodel = whisper.load_model(whisper_model)
    decoding_options = whisper.DecodingOptions(language='en')

    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, samples in progress_bar(dl, total=n_samples):
            with torch.no_grad():
                # requires a CUDA device (.cuda() is hard-coded)
                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
                # decode() receives encoder outputs rather than mels; relies on whisper
                # accepting precomputed audio features here (TODO confirm for new versions)
                decs = whmodel.decode(embs, decoding_options)
                for key, dec in zip(keys, decs):
                    sink.write({
                        "__key__": key,
                        "txt": dec.text,
                    })
    # only reached after the writer is closed, so `output` never refers to a partial tar
    os.rename(tmp, output)