AlexK-PL commited on
Commit
9a5d905
1 Parent(s): 18a1e06

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import subprocess
3
+ import time
4
+
5
+ from typing import Optional
6
+ from AinaTheme import AinaGradioTheme
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import os
11
+ from TTS.utils.synthesizer import Synthesizer
12
+
13
+ from dotenv import load_dotenv
14
+
15
+ torch.manual_seed(0)
16
+ np.random.seed(0)
17
+
18
+ # CleanUnet Dependencies
19
+
20
+ import json
21
+ from copy import deepcopy
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ # from util import print_size, sampling
27
+
28
+ import torchaudio
29
+ import torchaudio.transforms as T
30
+
31
+ import random
32
+
33
+ random.seed(0)
34
+ torch.manual_seed(0)
35
+ np.random.seed(0)
36
+
37
+ SAMPLE_RATE = 8000
38
+
39
+ CONFIG = "configs/DNS-large-full.json"
40
+ # CHECKPOINT = "./exp/DNS-large-full/checkpoint/pretrained.pkl"
41
+
42
+ # Parse configs. Globals nicer in this case
43
+ with open(CONFIG) as f:
44
+ data = f.read()
45
+ config = json.loads(data)
46
+ gen_config = config["gen_config"]
47
+ global network_config
48
+ network_config = config["network_config"] # to define wavenet
49
+ global train_config
50
+ train_config = config["train_config"] # train config
51
+ global trainset_config
52
+ trainset_config = config["trainset_config"] # to read trainset configurations
53
+
54
+ # global use_denoise
55
+ # use_denoise = False
56
+
57
+ # setup local experiment path
58
+ exp_path = train_config["exp_path"]
59
+ print('exp_path:', exp_path)
60
+
61
+ # load data
62
+ loader_config = deepcopy(trainset_config)
63
+ loader_config["crop_length_sec"] = 0
64
+
65
+ #############################################################################################################
66
+
67
+ load_dotenv()
68
+
69
+ MAX_INPUT_TEXT_LEN = int(os.environ.get("MAX_INPUT_TEXT_LEN", default=500))
70
+
71
+ # Dynamically read model files, exclude 'speakers.pth'
72
+ model_files = [f for f in os.listdir(os.getcwd()) if f.endswith('.pth') and f != 'speakers.pth']
73
+ model_files.sort(key=lambda x: os.path.getmtime(os.path.join(os.getcwd(), x)), reverse=True)
74
+
75
+ speakers_path = "speakers.pth"
76
+ speakers_list = torch.load(speakers_path)
77
+ speakers_list = list(speakers_list.keys())
78
+ speakers_list = [speaker for speaker in speakers_list]
79
+
80
+ default_speaker_list = speakers_list #
81
+
82
+ # Filtered lists based on dataset
83
+ festcat_speakers = [s for s in speakers_list if len(s) == 3] #
84
+ google_speakers = [s for s in speakers_list if 3 < len(s) < 20] #
85
+ commonvoice_speakers = [s for s in speakers_list if len(s) > 20] #
86
+
87
+ DEFAULT_SPEAKER_ID = os.environ.get("DEFAULT_SPEAKER_ID", default="pau")
88
+ model_file = model_files[0] # change this!!
89
+
90
+ model_path = os.path.join(os.getcwd(), model_file)
91
+ config_path = "config.json"
92
+
93
+ vocoder_path = None
94
+ vocoder_config_path = None
95
+
96
+ synthesizer = Synthesizer(
97
+ model_path, config_path, speakers_path, None, vocoder_path, vocoder_config_path,
98
+ )
99
+
100
+
101
+ def get_phonetic_transcription(text: str):
102
+ try:
103
+ result = subprocess.run(
104
+ ['espeak-ng', '--ipa', '-v', 'ca', text],
105
+ stdout=subprocess.PIPE,
106
+ stderr=subprocess.PIPE,
107
+ text=True,
108
+ check=True
109
+ )
110
+ return result.stdout.strip()
111
+ except subprocess.CalledProcessError as e:
112
+ print(f"An error occurred: {e}")
113
+ return None
114
+
115
+
116
+ def tts_inference(text: str, speaker_idx: str = None, use_denoise: int = 0):
117
+ # synthesize
118
+ if synthesizer is None:
119
+ raise NameError("model not found")
120
+ t1 = time.time()
121
+ wavs = synthesizer.tts(text, speaker_idx)
122
+ print(type(wavs))
123
+ if use_denoise == 0:
124
+ wavs_den = torch.Tensor(wavs).unsqueeze(0) # one sample
125
+ # wavs_den = denoise(wavs_den).tolist()
126
+ else:
127
+ wavs_den = wavs
128
+
129
+ # return output
130
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
131
+ # wavs must be a list of integers
132
+ synthesizer.save_wav(wavs, fp)
133
+ t2 = time.time() - t1
134
+ print(round(t2, 2))
135
+ output_audio = fp.name
136
+
137
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
138
+ # wavs must be a list of integers
139
+ synthesizer.save_wav(wavs_den, fp)
140
+ output_audio_den = fp.name
141
+
142
+ return output_audio, output_audio_den
143
+
144
+
145
+ title = "🗣️ Catalan Multispeaker TTS Tester 🗣️"
146
+ description = """
147
+ 1️⃣ Enter the text to synthesize.
148
+ 2️⃣ Select a voice from the dropdown menu.
149
+ 3️⃣ Enjoy!
150
+ """
151
+
152
+
153
+ def submit_input(input_, speaker_id, use_dn):
154
+ output_audio = None
155
+ output_phonetic = None
156
+ if input_ is not None and len(input_) < MAX_INPUT_TEXT_LEN:
157
+ output_audio, output_audio_den = tts_inference(input_, speaker_id, use_dn)
158
+ output_phonetic = get_phonetic_transcription(input_)
159
+ else:
160
+ gr.Warning(f"Your text exceeds the {MAX_INPUT_TEXT_LEN}-character limit.")
161
+ return output_audio, output_audio_den, output_phonetic
162
+
163
+
164
+ def change_interactive(text):
165
+ input_state = text
166
+ if input_state.strip() != "":
167
+ return gr.update(interactive=True)
168
+ else:
169
+ return gr.update(interactive=False)
170
+
171
+
172
+ def clean():
173
+ return (
174
+ None,
175
+ None,
176
+ )
177
+
178
+
179
+ with gr.Blocks(**AinaGradioTheme().get_kwargs()) as app:
180
+ gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>")
181
+ gr.Markdown(description)
182
+
183
+ with gr.Row(equal_height=False):
184
+
185
+ with gr.Column(variant='panel'):
186
+ input_ = gr.Textbox(
187
+ label="Text",
188
+ value="Introdueix el text a sintetitzar.",
189
+ lines=4
190
+ )
191
+
192
+ dataset = gr.Radio(["All", "Festcat", "Google TTS", "CommonVoice"], label="Speakers Dataset",
193
+ value="All")
194
+
195
+
196
+ def update_speaker_list(dataset):
197
+ print("Updating speaker list based on dataset:", dataset)
198
+ if dataset == "Festcat":
199
+ current_speakers = festcat_speakers
200
+ elif dataset == "Google TTS":
201
+ current_speakers = google_speakers
202
+ elif dataset == "CommonVoice":
203
+ current_speakers = commonvoice_speakers
204
+ else:
205
+ current_speakers = speakers_list
206
+
207
+ return gr.update(choices=current_speakers, value=current_speakers[0])
208
+
209
+
210
+ speaker_id = gr.Dropdown(label="Select a voice", choices=speakers_list, value=DEFAULT_SPEAKER_ID,
211
+ interactive=True)
212
+ dataset.change(fn=update_speaker_list, inputs=dataset, outputs=speaker_id)
213
+
214
+ # model = gr.Dropdown(label="Select a model", choices=model_files, value=DEFAULT_MODEL_FILE_NAME)
215
+ with gr.Row():
216
+ clear_btn = gr.ClearButton(value='Clean', components=[input_])
217
+ # clear_btn = gr.Button(
218
+ # "Clean",
219
+ # )
220
+ submit_btn = gr.Button(
221
+ "Submit",
222
+ variant="primary",
223
+ )
224
+ use_denoise = gr.Radio(choices=[("Yes", 0), ("No", 1)], value=0)
225
+ with gr.Column(variant='panel'):
226
+ output_audio = gr.Audio(label="Output", type="filepath", autoplay=True, show_share_button=False)
227
+ output_audio_den = gr.Audio(label="Output denoised", type="filepath", autoplay=False,
228
+ show_share_button=False)
229
+
230
+ output_phonetic = gr.Textbox(label="Phonetic Transcription", readonly=True)
231
+
232
+ for button in [submit_btn]: # clear_btn
233
+ input_.change(fn=change_interactive, inputs=[input_], outputs=button)
234
+
235
+ # clear_btn.click(fn=clean, inputs=[], outputs=[input_, output_audio, output_phonetic], queue=False)
236
+ submit_btn.click(fn=submit_input, inputs=[input_, speaker_id, use_denoise], outputs=[output_audio,
237
+ output_audio_den,
238
+ output_phonetic])
239
+
240
+ app.queue(concurrency_count=1, api_open=False)
241
+ app.launch(show_api=False, server_name="0.0.0.0", server_port=7860)