# sheet-demo / app.py
# Last commit: "fix sslmos" (b7ad4c6) by unilight
import os
import torch.nn.functional as F
import torchaudio
from loguru import logger
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import yaml
# from s3prl_vc.upstream.interface import get_upstream
# from s3prl.nn import Featurizer
# import s3prl_vc.models
# from s3prl_vc.utils import read_hdf5
# from s3prl_vc.vocoder import Vocoder
# ---------- Settings ----------
# A GPU_ID of '-1' hides every CUDA device and selects the CPU device below.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
if GPU_ID == '-1':
    DEVICE = 'cpu'
else:
    DEVICE = 'cuda'
# NOTE(review): SERVER_PORT, SERVER_NAME and SSL_DIR are not referenced
# anywhere else in this file as shown; demo.launch() uses Gradio's defaults.
SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'
# Target sample rate (Hz) that read_wav() resamples every input to.
FS = 16000
# Cache of torchaudio resamplers, keyed by "<source rate>-<target rate>".
resamplers = {}
# Waveforms shorter than this many samples are zero-padded by read_wav().
MIN_REQUIRED_WAV_LENGTH = 1040
# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))
# TRGSPKS = ["TEF1", "TEF2", "TEM1", "TEM2"]
# ref_samples = {
# trgspk: sorted(glob(os.path.join("./ref_samples", trgspk, '*.wav')))
# for trgspk in TRGSPKS
# }
# ---------- Logging ----------
# Append to app.log instead of overwriting it.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')
# ---------- Download models ----------
logger.info('============================= Download models ===========================')
# UI-visible model name -> {"ckpt": local checkpoint path, "config": local YAML path}.
# Files are fetched in order (checkpoint first, then config) from the
# "unilight/sheet-models" Hub repository.
model_paths = {
    "SSL-MOS, all training sets": {
        file_kind: hf_hub_download(repo_id="unilight/sheet-models", filename=remote_name)
        for file_kind, remote_name in (
            ("ckpt", "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl"),
            ("config", "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml"),
        )
    }
}
# ---------- Model ----------
# Build, load and register one network per entry of model_paths.
# NOTE: the module-level name `config` ends up holding the config of the last
# model loaded here.
models = {}
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    with open(config_path, encoding="utf-8") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects and
        # this file is downloaded from the Hub; prefer yaml.safe_load if the
        # config only contains plain YAML.
        config = yaml.load(f, Loader=yaml.Loader)
    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS
        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        )
    else:
        # Without this guard `model` would silently be the previous entry's
        # network (or undefined for the first entry).
        raise ValueError(
            f"Unsupported model_type {config['model_type']!r} for model {name!r}"
        )
    # Weights are loaded onto the CPU first, then the whole module is moved
    # to DEVICE once.
    # NOTE(review): torch.load unpickles the downloaded checkpoint; only load
    # checkpoints from trusted repositories.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
def read_wav(wav_path):
    """Load an audio file as a mono waveform at ``FS`` Hz.

    Multi-channel audio is down-mixed by averaging the channels. The result is
    resampled to ``FS`` when needed. It is then zero-padded (split between
    both ends) to at least ``MIN_REQUIRED_WAV_LENGTH`` samples.

    Args:
        wav_path: path of an audio file readable by ``torchaudio.load``.

    Returns:
        ``(waveform, sample_rate)``. ``waveform`` is a 1-D tensor. ``sample_rate``
        is the rate of that tensor, i.e. ``FS`` (not the file's original rate).
    """
    # Channels-first layout: waveform is [C, T].
    waveform, sample_rate = torchaudio.load(wav_path)
    # Down-mix to mono -> [T]. For a single-channel file this equals a squeeze.
    waveform = waveform.mean(dim=0)
    # Resample if needed. torchaudio's Resample operates on the LAST dimension,
    # so it must see the time axis (a [T, 1] tensor would be "resampled" along
    # its channel axis and keep the original rate).
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)
    # Always pad to the minimum length. Give the odd sample to the right side
    # so the result reaches MIN_REQUIRED_WAV_LENGTH exactly.
    missing = MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]
    if missing > 0:
        left = missing // 2
        waveform = F.pad(waveform, (left, missing - left), "constant", 0)
    return waveform, FS
# Per-model parsed YAML configs, filled lazily by _get_model_config().
_model_configs = {}


def _get_model_config(model_name):
    """Return (and cache) the parsed YAML config belonging to ``model_name``."""
    if model_name not in _model_configs:
        with open(model_paths[model_name]["config"], encoding="utf-8") as f:
            # NOTE(review): same loader as the start-up code; prefer
            # yaml.safe_load if the configs only contain plain YAML.
            _model_configs[model_name] = yaml.load(f, Loader=yaml.Loader)
    return _model_configs[model_name]


def predict(model_name, wav_file):
    """Predict the score of one recording with the selected model.

    Each model is driven by its own config (``model_input``, ``inference_mode``).
    It does not use whichever config happened to be loaded last at start-up.

    Args:
        model_name: a key of ``models`` (the label chosen in the UI radio).
        wav_file: path of the recorded audio file; may be ``None`` when the
            user has not recorded anything.

    Returns:
        The first element of the model's ``outputs["scores"]`` as a NumPy value.

    Raises:
        gr.Error: if no model is selected or no audio was provided.
        ValueError: if the model config names an unsupported ``inference_mode``.
    """
    if model_name not in models:
        raise gr.Error("Please select a model.")
    if wav_file is None:
        raise gr.Error("Please record some speech first.")
    model = models[model_name]
    model_config = _get_model_config(model_name)

    x, _ = read_wav(wav_file)
    logger.info('wav file loaded')
    # Set up model input: a batch containing one utterance plus its length.
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    input_key = model_config["model_input"]
    inputs = {
        input_key: model_input,
        input_key + "_lengths": model_lengths,
    }
    inference_mode = model_config["inference_mode"]
    with torch.no_grad():
        # model forward
        if inference_mode == "mean_listener":
            outputs = model.mean_listener_inference(inputs)
        elif inference_mode == "mean_net":
            outputs = model.mean_net_inference(inputs)
        else:
            raise ValueError(
                f"Unsupported inference_mode {inference_mode!r} for model {model_name!r}"
            )
    pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]
    return pred_mean_scores
# Gradio UI: record speech, pick a model, and show the predicted score.
with gr.Blocks(title="SHEET: speech quality (MOS) prediction demo") as demo:
    gr.Markdown(
        """
# SHEET: speech quality (MOS) prediction demo
### [[Models]](https://huggingface.co/unilight/sheet-models)
Record a short utterance, select a model, and press **Evaluate!** to get the
model's predicted mean opinion score (MOS) for your recording.
Each model name states the architecture and the training data it uses; for example,
**SSL-MOS, all training sets** uses the checkpoint trained on
BVCC + NISQA + PSTN + SingMOS + SOMOS + Tencent + TMHINT-QI.
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", source='microphone', type='filepath')
            gr.Markdown("## Select a model!")
            model_choices = list(model_paths.keys())
            # Pre-select the first model so predict() does not receive None
            # when the user forgets to pick one.
            model_name = gr.Radio(
                label="Model",
                choices=model_choices,
                value=model_choices[0] if model_choices else None,
            )
            evaluate_btn = gr.Button(value="Evaluate!")
        with gr.Column():
            gr.Markdown("## The predicted score is here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)
    # Event handlers must be registered inside the Blocks context.
    evaluate_btn.click(predict, [model_name, input_wav], output_score)
if __name__ == '__main__':
    # Serve the UI; the finally clause closes the demo however launch() ends.
    try:
        demo.launch(enable_queue=True, debug=True)
    except KeyboardInterrupt as interrupt:
        # Ctrl-C: report the interrupt and fall through to the cleanup below.
        print(interrupt)
    finally:
        demo.close()