wetdog's picture
Fix tempfile import
9ca2c45
raw
history blame
No virus
2.34 kB
import torch
import torchaudio
import spaces
from typing import List
import soundfile as sf
import gradio as gr
import tempfile
# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the kNN-VC model once at import time via torch.hub so every request
# handled by this process reuses the same instance.
knn_vc = torch.hub.load(
    'bshall/knn-vc',
    'knn_vc',
    prematched=True,
    trust_repo=True,  # skip torch.hub's interactive trust prompt
    pretrained=True,
    device=device,
)
def convert_voice(src_wav_path: str, ref_wav_paths, top_k: int):
    """Convert the source utterance using the reference audio and kNN-VC.

    Args:
        src_wav_path: Path to the source audio file whose features are
            extracted with ``knn_vc.get_features``.
        ref_wav_paths: Path to a single reference audio file. It is wrapped
            in a one-element list before being handed to
            ``knn_vc.get_matching_set``.
        top_k: Number of nearest neighbours used by ``knn_vc.match``; it is
            cast to ``int`` before being passed as ``topk``.

    Returns:
        Path of a newly created ``.wav`` file containing the converted audio,
        written at 16000 Hz with the ``PCM_24`` subtype. The file is created
        with ``delete=False`` and is not removed by this function.

    Raises:
        gr.Error: If the source or the reference audio is missing.
    """
    # Gradio passes None for an audio component that was left empty; fail
    # early with a readable message instead of an error deep in the model.
    if not src_wav_path or not ref_wav_paths:
        raise gr.Error("Please provide both a source and a reference audio clip.")

    query_seq = knn_vc.get_features(src_wav_path)
    matching_set = knn_vc.get_matching_set([ref_wav_paths])
    out_wav = knn_vc.match(query_seq, matching_set, topk=int(top_k))

    # Reserve a unique path, then close our handle before soundfile opens the
    # same path, so the file is never written through two handles at once.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as converted_file:
        converted_path = converted_file.name
    sf.write(converted_path, out_wav, 16000, "PCM_24")
    return converted_path
# HTML snippet with a centered <h1> heading; passed to gr.Markdown(title)
# when the page is assembled.
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
KNN Voice Conversion
</h1> </div>
</div>
"""
# Short explanation of the method; passed to gr.Markdown(description) when
# the page is assembled.
description = """
Voice Conversion With Just k-Nearest Neighbors. The source and reference utterance(s) are encoded into self-supervised features using WavLM.
Each source feature is assigned to the mean of the k closest features from the reference.
The resulting feature sequence is then vocoded with HiFi-GAN to arrive at the converted waveform output.
"""
# Citation and credits text; passed to gr.Markdown(article) when the page
# is assembled.
article = """
If the model contributes to your research please cite the following work:
Baas, M., van Niekerk, B., & Kamper, H. (2023). Voice conversion with just nearest neighbors. arXiv preprint arXiv:2305.18975.
demo contributed by [@wetdog](https://github.com/wetdog)
"""
# Assemble the page: heading, method description, the conversion form, and
# finally the citation text.
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Interface(
        fn=convert_voice,
        # Input order matches convert_voice(src_wav_path, ref_wav_paths, top_k).
        # Labels make the two otherwise identical audio inputs distinguishable.
        inputs=[
            gr.Audio(type="filepath", label="Source audio"),
            gr.Audio(type="filepath", label="Reference audio"),
            gr.Slider(
                3,
                10,
                value=4,
                step=1,
                label="Top-k",
                info="These default settings provide pretty good results, but feel free to modify the kNN topk",
            ),
        ],
        outputs=[gr.Audio(type="filepath", label="Converted audio")],
        # String form of the former boolean False, which Gradio deprecated.
        allow_flagging="never",
    )
    gr.Markdown(article)

# Allow at most 10 requests to wait in the queue.
demo.queue(max_size=10)
# Listen on all interfaces on port 7860 and hide the API page.
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)