File size: 3,568 Bytes
c52280c
 
 
 
 
16f5d6e
9ca2c45
a611372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c52280c
 
 
 
 
 
 
a611372
 
 
 
 
c52280c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bbd427
 
c52280c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torchaudio
import spaces
from typing import List
import soundfile as sf
import gradio as gr
import tempfile
import subprocess

def convert_to_16kHz_mono(input_file, output_file):
    """
    Resample an audio file to 16 kHz and downmix it to a single channel with ffmpeg.

    Parameters:
    input_file (str): Path to the audio file to read.
    output_file (str): Path the converted audio is written to. ffmpeg is
        invoked with -y, so an existing file at this path is overwritten.

    Returns:
    output_file when ffmpeg exits successfully, or None when ffmpeg exits
    with a non-zero status (the error is printed, not raised).
    """
    ffmpeg_cmd = [
        'ffmpeg', '-y',
        '-i', input_file,
        '-ar', '16000',  # output sample rate in Hz
        '-ac', '1',      # number of output channels
        output_file,
    ]

    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except subprocess.CalledProcessError as e:
        # check=True turns a non-zero exit status into this exception.
        print(f"An error occurred during conversion: {e}")
        return None

    print(f"Conversion complete: {output_file}")
    return output_file

def create_temp_wav_file():
    """
    Create an empty temporary file with a ".wav" suffix and return its path.

    The file is created with delete=False, so it remains on disk after this
    function returns; removing it is the caller's responsibility.

    Returns:
    str: Path of the newly created temporary file.
    """
    # Leaving the `with` block closes the handle immediately, so no open file
    # descriptor outlives this call; delete=False keeps the file itself.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        return temp_file.name

# Use the CUDA device when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the 'knn_vc' entry point from the bshall/knn-vc torch.hub repository
# once at import time; convert_voice reuses this object for every request.
knn_vc = torch.hub.load(
    'bshall/knn-vc',
    'knn_vc',
    prematched=True,
    trust_repo=True,
    pretrained=True,
    device=device,
)


def convert_voice(src_wav_path: str, ref_wav_paths, top_k: int):
    """
    Convert the voice of a source recording using a reference recording.

    Both inputs are first converted to 16 kHz mono WAV files with
    convert_to_16kHz_mono. The converted source is passed to
    knn_vc.get_features, the converted reference to knn_vc.get_matching_set,
    and both results to knn_vc.match, whose output is written to a new WAV file.

    Parameters:
    src_wav_path (str): Path to the source audio file.
    ref_wav_paths (str): Path to the reference audio file. Despite the plural
        name, a single path is expected; it is wrapped in a one-element list
        for knn_vc.get_matching_set.
    top_k (int): Passed to knn_vc.match as topk after coercion with int().

    Returns:
    str: Path to a temporary ".wav" file holding the result, written with a
    sample rate of 16000 and subtype "PCM_24". The file is created with
    delete=False and is not removed here, since the caller needs to read it.

    Raises:
    ValueError: If either input path is missing (None or empty).
    RuntimeError: If converting either input to 16 kHz mono fails.
    """
    import os  # function-scope import: os is not among this file's top-level imports

    if not src_wav_path or not ref_wav_paths:
        raise ValueError("Both a source and a reference audio file are required.")

    intermediate_paths = []
    try:
        tmp_src_wav_path = create_temp_wav_file()
        intermediate_paths.append(tmp_src_wav_path)
        tmp_ref_wav_path = create_temp_wav_file()
        intermediate_paths.append(tmp_ref_wav_path)

        src_16k_path = convert_to_16kHz_mono(src_wav_path, tmp_src_wav_path)
        ref_16k_path = convert_to_16kHz_mono(ref_wav_paths, tmp_ref_wav_path)
        # convert_to_16kHz_mono reports failure by returning None; stop here
        # with a clear message instead of handing None to the model.
        if src_16k_path is None or ref_16k_path is None:
            raise RuntimeError("Could not convert the input audio to 16 kHz mono with ffmpeg.")

        query_seq = knn_vc.get_features(src_16k_path)
        matching_set = knn_vc.get_matching_set([ref_16k_path])
        out_wav = knn_vc.match(query_seq, matching_set, topk=int(top_k))
    finally:
        # The intermediate 16 kHz files are only inputs to the calls above.
        # NOTE(review): assumes knn_vc reads them eagerly, so they are no
        # longer needed once match() has returned - confirm against knn-vc.
        for tmp_path in intermediate_paths:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the
                # real result or the real error.
                pass

    # Reserve the output path and close the handle before soundfile opens the
    # same path, so the file is never held open twice.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as converted_file:
        converted_path = converted_file.name
    sf.write(converted_path, out_wav, 16000, "PCM_24")

    return converted_path


# HTML heading shown at the top of the page (rendered through gr.Markdown).
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <div
        style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
    > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
        KNN Voice Conversion
    </h1> </div>
</div>
"""

# Short explanation of the method, shown under the heading.
description = """
Voice Conversion With Just k-Nearest Neighbors. The source and reference utterance(s) are encoded into self-supervised features using WavLM.
Each source feature is assigned to the mean of the k closest features from the reference.
The resulting feature sequence is then vocoded with HiFi-GAN to arrive at the converted waveform output.
"""

# Citation and credits, shown below the conversion form.
article = """
If the model contributes to your research please cite the following work: 

Baas, M., van Niekerk, B., & Kamper, H. (2023). Voice conversion with just nearest neighbors. arXiv preprint arXiv:2305.18975.

demo contributed by [@wetdog](https://github.com/wetdog)
"""
# Page layout: heading and description, the conversion form, then the citation.
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Interface(
        fn=convert_voice,
        # Inputs map positionally onto convert_voice(src_wav_path, ref_wav_paths, top_k).
        inputs=[
            gr.Audio(type='filepath'),
            gr.Audio(type='filepath'),
            #gr.File(file_count="multiple", type="filepath", label="Reference Audio Files"),
            gr.Slider(
                3,
                10,
                value=4,
                step=1,
                label="Top-k",
                info="These default settings provide pretty good results, but feel free to modify the kNN topk",
            ),
        ],
        outputs=[gr.Audio(type='filepath')],
        # Flagging is disabled; the documented values are the strings
        # "never" / "auto" / "manual" rather than a boolean.
        allow_flagging="never",
    )
    gr.Markdown(article)

# Enable request queuing with max_size=10, then serve on 0.0.0.0:7860.
demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)