# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.

# %% auto 0
__all__ = ['flac_to_s2a_name']

# %% ../nbs/4A. S2A dataset preparation.ipynb 2
import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

import whisper
from . import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds

# %% ../nbs/4A. S2A dataset preparation.ipynb 4
def flac_to_s2a_name(input):
    """Derive the S2A output shard name from an input shard path.

    The directory part of `input` is dropped, then every occurrence of
    'flac' (when the path contains '-flac-') or 'raw' (otherwise) in the
    file name is replaced with 's2a', and '.gz' is appended.
    """
    # [-1] instead of [1]: a bare file name without any "/" must not raise
    # IndexError; for paths that do contain "/" the result is unchanged.
    name = input.rsplit("/", 1)[-1]
    if '-flac-' in input:
        return name.replace('flac', 's2a') + ".gz"
    return name.replace('raw', 's2a') + ".gz"

# %% ../nbs/4A. S2A dataset preparation.ipynb 6
def resampler(newsr = 24000, key = 'samples_24k'):
    """Build a pipeline stage that adds a resampled copy of each sample's audio.

    The returned generator function reads `s['samples']` and
    `s['sample_rate']` from every sample dict it receives, stores the audio
    converted to `newsr` under `s[key]` (the original tensor is reused when
    the rate already matches) and yields the same dict.
    """
    def _resample(samples):
        # The cache lives inside the generator: assigning to variables of the
        # enclosing function without `nonlocal` would only create new locals,
        # so the transform would be rebuilt for every single sample.
        last_sr = None
        tform = None
        for s in samples:
            sr = s['sample_rate']
            if sr != newsr:
                if sr != last_sr:
                    # Only rebuild the transform when the source rate changes.
                    tform = torchaudio.transforms.Resample(sr, newsr)
                    last_sr = sr
                s[key] = tform(s['samples'])
            else:
                s[key] = s['samples']
            yield s
    return _resample

# %% ../nbs/4A.
# %% ../nbs/4A. S2A dataset preparation.ipynb 9
@call_parse
def prepare_s2a(
    input:str,  # FLAC webdataset file path (or - to read the names from stdin)
    proc_dataset_path:Path, # processed VAD files path
    output:str=None, # output file name
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
    n_samples:int=None, # process a limited amount of samples
    batch_size:int=1, # process several segments at once
    fix_dots:bool=False, # fix dots in file names
):
    """Encode audio shards into acoustic (`atoks`) and semantic (`stoks`) token arrays.

    Results are written to `output + ".tmp"`, which is renamed to `output`
    only when `n_samples` is not set. The speaker ids seen in the sample keys
    are written to `output + ".speakers.txt"`.
    """
    # "repo_id:filename" selects a remote checkpoint, anything else a local file.
    if ":" in vq_model:
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    amodel = extract_acoustic.load_model()
    amodel.set_target_bandwidth(3)

    if input == "-":
        input = [f.strip() for f in sys.stdin.readlines()]
        # With several input shards there is no single name to derive from.
        assert output, "please provide the output shard name"
    else:
        if output is None: output = flac_to_s2a_name(input)
        input = [input]

    # 'noinfer' is passed through to progress_bar when no limit was requested.
    total = n_samples//batch_size if n_samples else 'noinfer'

    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
        wds.decode(wds.torch_audio),
        wds.select(lambda x: 'wav' in x or 'flac' in x),
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
        lambda x: wh_transcribe.split_to_chunks(x),
        resampler(),
        resampler(16000, 'samples_16k'),
        wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
        wds.batched(64),
    )

    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)

    speakers = set()
    tmp = output+".tmp"
    with wds.TarWriter(tmp) as sink:
        for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
            with record_function('encodec'):
                atoks = amodel.encode(samples24k)[0][0]
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(samples)
            with record_function('from_cuda'):
                atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
            for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
                speakers.add(key.split('/')[1])
                # Trim the tokens that cover the right padding (rpad_s seconds).
                # NOTE(review): 75 and 25 are presumably the acoustic/semantic
                # token rates per second - confirm against the models.
                # `or None`: when the padding is shorter than one token the
                # negative index truncates to 0 and `[:0]` would drop *all*
                # tokens instead of none.
                atoks_end = int(-rpad_s * 75) or None
                stoks_end = int(-rpad_s * 25) or None
                sink.write({
                    "__key__": key,
                    "atoks.npy": _atoks[:,:atoks_end],
                    "stoks.npy": _stoks[:stoks_end],
                })
    # Sorted so that the speakers file is reproducible between runs.
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(sorted(speakers)))
    # A limited (n_samples) run leaves only the .tmp file behind.
    if not n_samples:
        os.rename(tmp, output)