"""Gradio demo for SHEET (Speech Human Evaluation Estimation Toolkit).

Downloads pretrained SSQA checkpoints from the Hugging Face Hub, builds the
models, and serves a web UI in which a user records or uploads speech and gets
back the model's predicted quality score.
"""

import os

import gradio as gr
import torch
import torch.nn.functional as F
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from loguru import logger

# ---------- Settings ----------
# '-1' hides every GPU and forces CPU inference; set e.g. '0' to use a GPU.
GPU_ID = '-1'
# Only takes effect if set before CUDA is first initialized in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cuda' if GPU_ID != '-1' else 'cpu'
# NOTE(review): SERVER_PORT, SERVER_NAME and SSL_DIR are not passed to
# demo.launch() below, so they currently have no effect.
SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'
# Sampling rate (Hz) every input waveform is converted to before scoring.
FS = 16000
# Cache of Resample transforms keyed by "<orig_sr>-<FS>", built on first use.
resamplers = {}
# Waveforms shorter than this many samples are zero-padded up to it.
MIN_REQUIRED_WAV_LENGTH = 1040
# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))

# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Download models ----------
logger.info('============================= Download models ===========================')
model_paths = {
    "SSL-MOS, all training sets": {
        "ckpt": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl"),
        "config": hf_hub_download(repo_id="unilight/sheet-models", filename="bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml"),
    }
}

# ---------- Model ----------
models = {}
# Config of each loaded model, keyed like `models`, so predict() uses the
# config belonging to the selected model instead of the last one loaded.
configs = {}
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    # NOTE(review): yaml.Loader and torch.load (which unpickles) can execute
    # arbitrary code; acceptable only because both files come from a trusted
    # Hugging Face repo. Prefer yaml.safe_load if the config permits it.
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS
        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        )
    else:
        # Without this, `model` would be undefined (or a stale previous model).
        raise ValueError(f"Unsupported model_type {config['model_type']!r} for model {name!r}.")

    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
    configs[name] = config


def read_wav(wav_path):
    """Load an audio file as a mono waveform sampled at FS Hz.

    Multi-channel audio is downmixed by averaging its channels, audio at any
    other sampling rate is resampled to FS, and clips shorter than
    MIN_REQUIRED_WAV_LENGTH samples are zero-padded (roughly centered) to
    exactly that length.

    Args:
        wav_path: Path to an audio file readable by torchaudio.load.

    Returns:
        (waveform, sample_rate): a 1-D tensor of shape [T] and the sampling
        rate of that tensor, which is always FS.
    """
    # Default channels_first=True gives [C, T]. Resample works on the last
    # axis, so time must be last when resampling.
    waveform, sample_rate = torchaudio.load(wav_path)
    # downmix to mono: [C, T] -> [T]
    waveform = waveform.mean(dim=0)

    # resample if needed
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)

    # always pad to a minimum length; split the padding unevenly when the
    # deficit is odd so the result reaches the minimum exactly
    deficit = MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]
    if deficit > 0:
        left = deficit // 2
        waveform = F.pad(waveform, (left, deficit - left), "constant", 0)
    return waveform, FS


def predict(model_name, wav_file):
    """Predict the quality score of a single utterance.

    Args:
        model_name: Key into `models` / `configs`, as selected in the UI.
        wav_file: Path to the recorded or uploaded audio file.

    Returns:
        The first element of the model's "scores" output for the utterance.

    Raises:
        gr.Error: If no model or no audio was provided (shown in the UI).
        ValueError: If the model config names an unknown inference_mode.
    """
    if not model_name:
        raise gr.Error("Please select a model.")
    if wav_file is None:
        raise gr.Error("Please record or upload a speech file.")

    model = models[model_name]
    config = configs[model_name]

    x, _ = read_wav(wav_file)
    logger.info('wav file loaded')

    # set up model input: a batch of one utterance, shape [1, T]
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        config["model_input"]: model_input,
        config["model_input"] + "_lengths": model_lengths,
    }

    inference_mode = config["inference_mode"]
    with torch.no_grad():
        # model forward
        if inference_mode == "mean_listener":
            outputs = model.mean_listener_inference(inputs)
        elif inference_mode == "mean_net":
            outputs = model.mean_net_inference(inputs)
        else:
            raise ValueError(f"Unsupported inference_mode {inference_mode!r} for model {model_name!r}.")

    pred_mean_scores = outputs["scores"].cpu().numpy()[0]
    return pred_mean_scores


with gr.Blocks(title="SHEET: Speech Human Evaluation Estimation Toolkit demo") as demo:
    gr.Markdown(
        """
# Demo for SHEET: Speech Human Evaluation Estimation Toolkit
### [Paper (To be uploaded)] [[Code]](https://github.com/unilight/sheet)

**SHEET** is a subjective speech quality assessment (SSQA) toolkit designed to conduct SSQA research. It was specifically designed to interact with MOS-Bench, a collection of datasets to benchmark SSQA models.

In this demo, you can record your own voice or upload speech files to assess the quality.
"""
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", type='filepath')
            gr.Markdown("## Select a model!")
            model_name = gr.Radio(
                label="Model",
                choices=list(model_paths.keys()),
                # preselect the first model so "Evaluate!" works immediately
                value=next(iter(model_paths)),
            )
            evaluate_btn = gr.Button(value="Evaluate!")
            # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
            # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.")
            # gr.Examples(
            #     examples=en_examples,
            #     inputs=input_wav,
            #     label="English examples"
            # )
            # gr.Examples(
            #     examples=jp_examples,
            #     inputs=input_wav,
            #     label="Japanese examples"
            # )
            # gr.Examples(
            #     examples=zh_examples,
            #     inputs=input_wav,
            #     label="Mandarin examples"
            # )

        with gr.Column():
            gr.Markdown("## The predicted score is here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)

    evaluate_btn.click(predict, [model_name, input_wav], output_score)


if __name__ == '__main__':
    try:
        demo.launch(debug=True)
    except KeyboardInterrupt as e:
        print(e)
    finally:
        demo.close()