Hev832 commited on
Commit
c2301e9
·
verified ·
1 Parent(s): a2ee446

Create infer-web.py

Browse files
Files changed (1) hide show
  1. infer-web.py +137 -0
infer-web.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ from mega import Mega
5
+ import gradio as gr
6
+
7
+ def download_from_url(url, model):
8
+ try:
9
+ model = model.replace('.pth', '').replace('.index', '').replace('.zip', '')
10
+ url = url.replace('/blob/main/', '/resolve/main/').strip()
11
+
12
+ for directory in ["downloads", "unzips", "zip"]:
13
+ os.makedirs(directory, exist_ok=True)
14
+
15
+ if url.endswith('.pth'):
16
+ subprocess.run(["wget", url, "-O", f'assets/weights/{model}.pth'])
17
+ elif url.endswith('.index'):
18
+ os.makedirs(f'logs/{model}', exist_ok=True)
19
+ subprocess.run(["wget", url, "-O", f'logs/{model}/added_{model}.index'])
20
+ elif url.endswith('.zip'):
21
+ subprocess.run(["wget", url, "-O", f'downloads/{model}.zip'])
22
+ else:
23
+ if "drive.google.com" in url:
24
+ url = url.split('/')[0]
25
+ subprocess.run(["gdown", url, "--fuzzy", "-O", f'downloads/{model}'])
26
+ elif "mega.nz" in url:
27
+ Mega().download_url(url, 'downloads')
28
+ else:
29
+ subprocess.run(["wget", url, "-O", f'downloads/{model}'])
30
+
31
+ downloaded_file = next((f for f in os.listdir("downloads")), None)
32
+ if downloaded_file:
33
+ if downloaded_file.endswith(".zip"):
34
+ shutil.unpack_archive(f'downloads/{downloaded_file}', "unzips", 'zip')
35
+ for root, _, files in os.walk('unzips'):
36
+ for file in files:
37
+ file_path = os.path.join(root, file)
38
+ if file.endswith(".index"):
39
+ os.makedirs(f'logs/{model}', exist_ok=True)
40
+ shutil.copy2(file_path, f'logs/{model}')
41
+ elif file.endswith(".pth") and "G_" not in file and "D_" not in file:
42
+ shutil.copy(file_path, f'assets/weights/{model}.pth')
43
+ elif downloaded_file.endswith(".pth"):
44
+ shutil.copy(f'downloads/{downloaded_file}', f'assets/weights/{model}.pth')
45
+ elif downloaded_file.endswith(".index"):
46
+ os.makedirs(f'logs/{model}', exist_ok=True)
47
+ shutil.copy(f'downloads/{downloaded_file}', f'logs/{model}/added_{model}.index')
48
+ else:
49
+ return "Failed to download file"
50
+ return f"Successfully downloaded {model} voice models"
51
+ except Exception as e:
52
+ return f"Error: {str(e)}"
53
+ finally:
54
+ shutil.rmtree("downloads", ignore_errors=True)
55
+ shutil.rmtree("unzips", ignore_errors=True)
56
+ shutil.rmtree("zip", ignore_errors=True)
57
+
58
+ def listen_to_model(model_path, index_path, pitch, input_path, f0_method, save_as, index_rate, volume_normalization, consonant_protection):
59
+ if not os.path.exists(model_path):
60
+ return "Model path not found"
61
+ if not os.path.exists(index_path):
62
+ return f"{index_path} was not found"
63
+ if not os.path.exists(input_path):
64
+ return f"{input_path} was not found"
65
+
66
+ os.environ['index_root'] = os.path.dirname(index_path)
67
+ index_path = os.path.basename(index_path)
68
+ model_name = os.path.basename(model_path)
69
+ os.environ['weight_root'] = os.path.dirname(model_path)
70
+
71
+ try:
72
+ command = [
73
+ "python", "tools/infer_cli.py",
74
+ "--f0up_key", str(pitch),
75
+ "--input_path", input_path,
76
+ "--index_path", index_path,
77
+ "--f0method", f0_method,
78
+ "--opt_path", save_as,
79
+ "--model_name", model_name,
80
+ "--index_rate", str(index_rate),
81
+ "--device", "cuda:0",
82
+ "--is_half", "True",
83
+ "--filter_radius", "3",
84
+ "--resample_sr", "0",
85
+ "--rms_mix_rate", str(volume_normalization),
86
+ "--protect", str(consonant_protection)
87
+ ]
88
+ subprocess.run(command, check=True)
89
+ return save_as
90
+ except subprocess.CalledProcessError as e:
91
+ return f"Error: {str(e)}"
92
+
93
+ with gr.Blocks() as demo:
94
+ gr.Markdown("# RVC V2 Web UI")
95
+
96
+ with gr.Tabs():
97
+ with gr.TabItem("Download Model"):
98
+ gr.Markdown("### Download RVC Model")
99
+ url_input = gr.Textbox(label="Model URL", placeholder="Enter the model URL here")
100
+ model_input = gr.Textbox(label="Model Name", placeholder="Enter the model name here")
101
+ download_button = gr.Button("Download")
102
+ download_output = gr.Textbox(label="Download Status")
103
+ download_button.click(download_from_url, inputs=[url_input, model_input], outputs=download_output)
104
+
105
+ with gr.TabItem("Listen to Model"):
106
+ gr.Markdown("### Listen to Your Model")
107
+ model_path_input = gr.Textbox(label="Model Path", value="/content/RVC/assets/weights/Sonic.pth")
108
+ index_path_input = gr.Textbox(label="Index Path", value="/content/RVC/logs/Sonic/added_IVF905_Flat_nprobe_1_Sonic_v2.index")
109
+ input_path_input = gr.Textbox(label="Input Audio Path", value="/content/RVC/audios/astronauts.mp3")
110
+ save_as_input = gr.Textbox(label="Save Output As", value="/content/RVC/audios/cli_output.wav")
111
+ f0_method_input = gr.Radio(label="F0 Method", choices=["rmvpe", "pm", "harvest"], value="rmvpe")
112
+
113
+ with gr.Row():
114
+ pitch_input = gr.Slider(label="Pitch", minimum=-12, maximum=12, step=1, value=0)
115
+ index_rate_input = gr.Slider(label="Index Rate", minimum=0, maximum=1, step=0.01, value=0.5)
116
+ volume_normalization_input = gr.Slider(label="Volume Normalization", minimum=0, maximum=1, step=0.01, value=0)
117
+ consonant_protection_input = gr.Slider(label="Consonant Protection", minimum=0, maximum=1, step=0.01, value=0.5)
118
+
119
+ listen_button = gr.Button("Generate and Listen")
120
+ audio_output = gr.Audio(label="Output Audio")
121
+ listen_button.click(
122
+ listen_to_model,
123
+ inputs=[
124
+ model_path_input,
125
+ index_path_input,
126
+ pitch_input,
127
+ input_path_input,
128
+ f0_method_input,
129
+ save_as_input,
130
+ index_rate_input,
131
+ volume_normalization_input,
132
+ consonant_protection_input
133
+ ],
134
+ outputs=audio_output
135
+ )
136
+
137
+ demo.launch()