MrSimple07 committed · Commit 1913b79 (verified) · 1 Parent(s): b258b64

Create app.py

Files changed (1):
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+import os
+import torch
+import gradio as gr
+from huggingface_hub import hf_hub_download
+import langid
+from openvoice.api import BaseSpeakerTTS, ToneColorConverter
+import openvoice.se_extractor as se_extractor
+
+# Constants
+CKPT_BASE_PATH = "checkpoints"
+EN_SUFFIX = f"{CKPT_BASE_PATH}/base_speakers/EN"
+CONVERTER_SUFFIX = f"{CKPT_BASE_PATH}/converter"
+OUTPUT_DIR = "outputs/"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+# Download necessary files
+def download_from_hf_hub(filename, local_dir="./"):
+    os.makedirs(local_dir, exist_ok=True)
+    hf_hub_download(repo_id="myshell-ai/OpenVoice", filename=filename, local_dir=local_dir)
+
+for file in [f"{CONVERTER_SUFFIX}/checkpoint.pth", f"{CONVERTER_SUFFIX}/config.json",
+             f"{EN_SUFFIX}/checkpoint.pth", f"{EN_SUFFIX}/config.json",
+             f"{EN_SUFFIX}/en_default_se.pth", f"{EN_SUFFIX}/en_style_se.pth"]:
+    download_from_hf_hub(file)
+
+# Initialize models
+pt_device = "cpu"
+en_base_speaker_tts = BaseSpeakerTTS(f"{EN_SUFFIX}/config.json", device=pt_device)
+en_base_speaker_tts.load_ckpt(f"{EN_SUFFIX}/checkpoint.pth")
+
+tone_color_converter = ToneColorConverter(f"{CONVERTER_SUFFIX}/config.json", device=pt_device)
+tone_color_converter.load_ckpt(f"{CONVERTER_SUFFIX}/checkpoint.pth")
+
+en_source_default_se = torch.load(f"{EN_SUFFIX}/en_default_se.pth")
+en_source_style_se = torch.load(f"{EN_SUFFIX}/en_style_se.pth")
+
+# Main prediction function
+def predict(prompt, style, audio_file_pth, tau):
+    if len(prompt) < 2 or len(prompt) > 200:
+        return "Text should be between 2 and 200 characters.", None
+
+    try:
+        target_se, _ = se_extractor.get_se(audio_file_pth, tone_color_converter, target_dir=OUTPUT_DIR, vad=True)
+    except Exception as e:
+        return f"Error getting target tone color: {str(e)}", None
+
+    src_path = f"{OUTPUT_DIR}/tmp.wav"
+    en_base_speaker_tts.tts(prompt, src_path, speaker=style, language="English")
+
+    save_path = f"{OUTPUT_DIR}/output.wav"
+    tone_color_converter.convert(
+        audio_src_path=src_path,
+        src_se=en_source_style_se if style != "default" else en_source_default_se,
+        tgt_se=target_se,
+        output_path=save_path,
+        tau=tau
+    )
+
+    return "Voice cloning completed successfully.", save_path
+
+# Gradio interface
+def create_demo():
+    with gr.Blocks() as demo:
+        gr.Markdown("# OpenVoice: Instant Voice Cloning with fine-tuning")
+
+        with gr.Row():
+            input_text = gr.Textbox(label="Text to speak", placeholder="Enter text here (2-200 characters)")
+            style = gr.Dropdown(
+                label="Style",
+                choices=["default", "whispering", "cheerful", "terrified", "angry", "sad", "friendly"],
+                value="default"
+            )
+
+        with gr.Row():
+            reference_audio = gr.Audio(label="Reference Audio", type="filepath")
+            tau_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Tau (Voice similarity)", info="Higher values make the output more similar to the reference voice")
+
+        submit_button = gr.Button("Generate Voice")
+
+        output_text = gr.Textbox(label="Status")
+        output_audio = gr.Audio(label="Generated Audio")
+
+        submit_button.click(
+            predict,
+            inputs=[input_text, style, reference_audio, tau_slider],
+            outputs=[output_text, output_audio]
+        )
+
+    return demo
+
+# Launch the demo
+if __name__ == "__main__":
+    demo = create_demo()
+    demo.launch()
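
For a quick local check outside the Gradio UI, the predict function defined above can be called directly, e.g. from a Python session in which app.py has been imported. The snippet below is a minimal sketch, assuming the checkpoints have already been fetched by the download loop and that a short reference recording exists at reference.wav; the reference path and prompt text are hypothetical placeholders, not part of this commit.

    # Minimal smoke test; reference.wav and the prompt are illustrative placeholders.
    status, wav_path = predict(
        prompt="Hello from OpenVoice!",
        style="cheerful",
        audio_file_pth="reference.wav",
        tau=0.7,
    )
    print(status, wav_path)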