import os
from glob import glob

import gradio as gr
import torch
import torch.nn.functional as F
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from loguru import logger

# from s3prl_vc.upstream.interface import get_upstream
# from s3prl.nn import Featurizer
# import s3prl_vc.models
# from s3prl_vc.utils import read_hdf5
# from s3prl_vc.vocoder import Vocoder

# ---------- Settings ----------
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cuda' if GPU_ID != '-1' else 'cpu'

SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'

FS = 16000  # all models expect 16 kHz input
resamplers = {}  # cache of torchaudio resamplers, keyed by "<orig_sr>-<target_sr>"
MIN_REQUIRED_WAV_LENGTH = 1040  # shorter inputs are zero-padded to this many samples

# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))

# TRGSPKS = ["TEF1", "TEF2", "TEM1", "TEM2"]
# ref_samples = {
#     trgspk: sorted(glob(os.path.join("./ref_samples", trgspk, '*.wav')))
#     for trgspk in TRGSPKS
# }

# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Download models ----------
logger.info('============================= Download models ===========================')
model_paths = {
    "SSL-MOS, all training sets": {
        "ckpt": hf_hub_download(
            repo_id="unilight/sheet-models",
            filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl",
        ),
        "config": hf_hub_download(
            repo_id="unilight/sheet-models",
            filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml",
        ),
    }
}

# ---------- Model ----------
models = {}
configs = {}  # keep each model's config so predict() does not rely on a stale global
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS
        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        ).to(DEVICE)
    else:
        raise NotImplementedError(f'Unsupported model type: {config["model_type"]}')

    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
    configs[name] = config


def read_wav(wav_path):
    # read waveform; [T, 1] -> [T]
    waveform, sample_rate = torchaudio.load(wav_path, channels_first=False)
    waveform = waveform.squeeze(-1)

    # resample if needed (Resample operates on the last, i.e. time, dimension,
    # so this must happen after the channel dimension is squeezed away)
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)

    # always pad to a minimum length (splitting odd pad amounts between the
    # two sides so the result is never one sample short)
    if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH:
        total_pad = MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]
        left_pad = total_pad // 2
        waveform = F.pad(waveform, (left_pad, total_pad - left_pad), "constant", 0)

    return waveform, FS


def predict(model_name, wav_file):
    config = configs[model_name]
    x, _ = read_wav(wav_file)
    logger.info('wav file loaded')

    # set up model input: [T] -> [1, T], plus the length tensor the model expects
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        config["model_input"]: model_input,
        config["model_input"] + "_lengths": model_lengths,
    }

    with torch.no_grad():
        # model forward
        if config["inference_mode"] == "mean_listener":
            outputs = models[model_name].mean_listener_inference(inputs)
        elif config["inference_mode"] == "mean_net":
            outputs = models[model_name].mean_net_inference(inputs)
        else:
            raise ValueError(f'Unknown inference mode: {config["inference_mode"]}')

    pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]
    return pred_mean_scores
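
# Illustrative usage (a hedged sketch, not part of the app): predict() can also
# be called without the Gradio UI, e.g. for batch scoring. The path
# "sample.wav" is hypothetical; read_wav() resamples and pads the input, so
# most mono WAV files should work.
#
#     score = predict("SSL-MOS, all training sets", "sample.wav")
#     print(f"Predicted MOS: {score:.2f}")
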
"mean_listener": outputs = models[model_name].mean_listener_inference(inputs) elif config["inference_mode"] == "mean_net": outputs = models[model_name].mean_net_inference(inputs) pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0] return pred_mean_scores with gr.Blocks(title="S3PRL-VC: Any-to-one voice conversion demo on VCC2020") as demo: gr.Markdown( """ # S3PRL-VC: Any-to-one voice conversion demo on VCC2020 ### [[Paper (ICASSP2023)]](https://arxiv.org/abs/2110.06280) [[Paper(JSTSP)]](https://arxiv.org/abs/2207.04356) [[Code]](https://github.com/unilight/s3prl-vc) **S3PRL-VC** is a voice conversion (VC) toolkit for benchmarking self-supervised speech representations (S3Rs). The term **any-to-one** means that the system can convert from any unseen speaker to a pre-defined speaker given in training. In this demo, you can record your voice, and the model will convert your voice to one of the four pre-defined speakers. These four speakers come from the **voice conversion challenge (VCC) 2020**. You can listen to the samples to get a sense of what these speakers sound like. The **RTF** of the system is around **1.5~2.5**, i.e. if you recorded a 5 second long audio, it will take 5 * (1.5~2.5) = 7.5~12.5 seconds to generate the output. """ ) with gr.Row(): with gr.Column(): gr.Markdown("## Record your speech here!") input_wav = gr.Audio(label="Input speech", source='microphone', type='filepath') gr.Markdown("## Select a model!") model_name = gr.Radio(label="Model", choices=list(model_paths.keys())) evaluate_btn = gr.Button(value="Evaluate!") # gr.Markdown("### You can use these examples if using a microphone is too troublesome!") # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.") # gr.Examples( # examples=en_examples, # inputs=input_wav, # label="English examples" # ) # gr.Examples( # examples=jp_examples, # inputs=input_wav, # label="Japanese examples" # ) # gr.Examples( # examples=zh_examples, # inputs=input_wav, # label="Mandarin examples" # ) with gr.Column(): gr.Markdown("## The predicted scores is here:") output_score = gr.Textbox(label="Prediction", interactive=False) evaluate_btn.click(predict, [model_name, input_wav], output_score) if __name__ == '__main__': try: demo.launch(debug=True, enable_queue=True, ) except KeyboardInterrupt as e: print(e) finally: demo.close()