thanhtl committed on
Commit
19812c5
·
verified ·
1 Parent(s): 7a5796b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +297 -0
  2. audio.wav +0 -0
  3. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This demo is adapted from https://github.com/coqui-ai/TTS/blob/dev/TTS/demos/xtts_ft_demo/xtts_demo.py
2
+ # With some modifications to fit the viXTTS model
3
+ import argparse
4
+ import hashlib
5
+ import logging
6
+ import os
7
+ import string
8
+ import subprocess
9
+ import sys
10
+ import tempfile
11
+ from datetime import datetime
12
+
13
+ import gradio as gr
14
+ import torch
15
+ import torchaudio
16
+ from huggingface_hub import hf_hub_download, snapshot_download
17
+ from underthesea import sent_tokenize
18
+ from unidecode import unidecode
19
+
20
+ from TTS.tts.configs.xtts_config import XttsConfig
21
+ from TTS.tts.models.xtts import Xtts
22
+
23
# Module-level handle for the loaded model; stays None until load_model() runs.
XTTS_MODEL = None

# All working directories are resolved relative to this script's location.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR, OUTPUT_DIR = (
    os.path.join(SCRIPT_DIR, sub_dir) for sub_dir in ("model", "output")
)

# Generated audio is written here, so make sure it exists up front.
os.makedirs(OUTPUT_DIR, exist_ok=True)
28
+
29
+
30
def clear_gpu_cache():
    """Ask PyTorch to release its cached CUDA memory; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
33
+
34
+
35
def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
    """Download the checkpoint if needed and load it into ``XTTS_MODEL``.

    Args:
        checkpoint_dir: Directory that holds (or will receive) the model files.
        repo_id: Hugging Face model repository to download from when any
            required file is missing from ``checkpoint_dir``.
        use_deepspeed: Forwarded to ``Xtts.load_checkpoint``.

    Returns:
        The loaded model, which is also stored in the module-level
        ``XTTS_MODEL``.
    """
    global XTTS_MODEL
    clear_gpu_cache()
    os.makedirs(checkpoint_dir, exist_ok=True)

    required_files = {"model.pth", "config.json", "vocab.json", "speakers_xtts.pth"}
    if not required_files.issubset(os.listdir(checkpoint_dir)):
        print(f"Missing model files! Downloading from {repo_id}...")
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=checkpoint_dir,
        )
        # The speaker file is fetched separately from the base XTTS-v2 repo.
        hf_hub_download(
            repo_id="coqui/XTTS-v2",
            filename="speakers_xtts.pth",
            local_dir=checkpoint_dir,
        )
        print("Model download finished...")

    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))

    XTTS_MODEL = Xtts.init_from_config(config)
    print("Loading model...")
    # Honor the caller's choice; this used to be hard-coded to False, which
    # silently ignored the ``use_deepspeed`` argument.
    XTTS_MODEL.load_checkpoint(
        config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
    )
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()
    else:
        print("use cpu")
        XTTS_MODEL.cpu()

    print("Model Loaded!")

    return XTTS_MODEL
75
def generate_hash(data):
    """Return the hexadecimal MD5 digest of the bytes-like ``data``."""
    return hashlib.md5(data).hexdigest()
79
+
80
+
81
def get_file_name(text, max_char=50):
    """Build a timestamped, filesystem-safe file stem from ``text``.

    Args:
        text: Source text; only its first ``max_char`` characters are used.
        max_char: Maximum number of leading characters of ``text`` to keep.

    Returns:
        ``"<MMDDHHMMSS>_<slug>"`` where the slug contains only lowercase ASCII
        letters, digits and underscores.
    """
    # Transliterate BEFORE filtering: unidecode can itself emit punctuation
    # (for example a path separator for fraction characters), so filtering
    # first would let unsafe characters back into the file name.
    ascii_text = unidecode(text[:max_char]).lower()

    # Every whitespace character (not just the plain space) becomes "_", and
    # anything that is not alphanumeric or "_" is dropped.
    slug = "".join(
        "_" if ch.isspace() else ch
        for ch in ascii_text
        if ch.isspace() or ch == "_" or ch.isalnum()
    )

    timestamp = datetime.now().strftime("%m%d%H%M%S")
    return f"{timestamp}_{slug}"
92
+
93
+
94
+
95
+
96
def normalize_vietnamese_text(text):
    """Apply a fixed sequence of literal substitutions to ``text``.

    The substitutions tidy doubled or misplaced punctuation, remove quote
    characters, and spell out "AI" / "A.I" phonetically. They run in the
    listed order, each over the result of the previous one.
    """
    substitutions = (
        ("..", "."),
        ("!.", "!"),
        ("?.", "?"),
        (" .", "."),
        (" ,", ","),
        ('"', ""),
        ("'", ""),
        ("AI", "Ây Ai"),
        ("A.I", "Ây Ai"),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text
110
+
111
+
112
def calculate_keep_len(text, lang):
    """Simple hack for short sentences: return how many samples to keep.

    Args:
        text: The sentence that was synthesized.
        lang: Language code of the sentence.

    Returns:
        A sample count derived from the number of whitespace-separated words
        and punctuation marks when the sentence has fewer than 10 words;
        ``-1`` otherwise, and always ``-1`` for Japanese and Chinese.
    """
    # Word counting relies on whitespace, which Japanese and Chinese text does
    # not use. "zh" is included alongside "zh-cn" because that is the code the
    # UI dropdown offers; without it a whole Chinese sentence would count as
    # one word and be cut to 15000 samples.
    if lang in ("ja", "zh", "zh-cn"):
        return -1

    word_count = len(text.split())
    num_punct = sum(text.count(mark) for mark in ".!?,")

    if word_count < 5:
        samples_per_word = 15000
    elif word_count < 10:
        samples_per_word = 13000
    else:
        return -1
    return samples_per_word * word_count + 2000 * num_punct
125
+
126
+
127
def run_tts(lang, tts_text, speaker_audio_file, normalize_text):
    """Synthesize ``tts_text`` using ``speaker_audio_file`` as the reference voice.

    Args:
        lang: Language code passed to the model's ``inference`` call.
        tts_text: Text to speak.
        speaker_audio_file: Path to the reference audio used to compute the
            conditioning latents.
        normalize_text: When true and ``lang`` is ``"vi"``, the text is first
            passed through ``normalize_vietnamese_text``.

    Returns:
        A ``(status_message, output_path)`` pair. ``output_path`` is ``None``
        when nothing was generated. Every path returns exactly two values
        because the handler is wired to two Gradio output components.
    """
    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model !!", None

    if not speaker_audio_file:
        return "You need to provide reference audio!!!", None

    tts_text = tts_text or ""
    if normalize_text and lang == "vi":
        tts_text = normalize_vietnamese_text(tts_text)

    # Split text by sentence. Japanese/Chinese use the ideographic full stop;
    # "zh" is what the UI dropdown sends, "zh-cn" is kept for other callers.
    if lang in ("ja", "zh", "zh-cn"):
        sentences = tts_text.split("。")
    else:
        sentences = sent_tokenize(tts_text)

    from pprint import pprint

    pprint(sentences)

    sentences = [sentence for sentence in sentences if sentence.strip() != ""]
    if not sentences:
        # torch.cat() raises on an empty list, so report blank input instead.
        return "You need to provide some text to synthesize!!!", None

    print("Computing conditioning latents...")
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
        max_ref_length=XTTS_MODEL.config.max_ref_len,
        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
    )

    wav_chunks = []
    for sentence in sentences:
        wav_chunk = XTTS_MODEL.inference(
            text=sentence,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            # The following values are carefully chosen for viXTTS
            temperature=0.3,
            length_penalty=1.0,
            repetition_penalty=10.0,
            top_k=30,
            top_p=0.85,
            enable_text_splitting=True,
        )

        # A keep_len of -1 means "no truncation hack"; as a slice bound it
        # only drops the final sample, which matches the previous behavior.
        keep_len = calculate_keep_len(sentence, lang)
        wav_chunks.append(torch.tensor(wav_chunk["wav"][:keep_len]))

    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
    # The parent directory name of the reference audio is appended to the
    # file name to tell outputs for different uploads apart.
    gr_audio_id = os.path.basename(os.path.dirname(speaker_audio_file))
    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}_{gr_audio_id}.wav")
    print("Saving output to ", out_path)
    torchaudio.save(out_path, out_wav, 24000)

    return "Speech generated !", out_path
187
+
188
+
189
+ # Define a logger that mirrors everything written to stdout into a log file
190
class Logger:
    """File-like object that writes every message to stdout and to a log file.

    Args:
        filename: Path of the log file; it is truncated when the logger is
            created.
    """

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Capture the stream that is stdout at construction time.
        self.terminal = sys.stdout
        # Explicit UTF-8: the app prints Vietnamese text, which would raise
        # UnicodeEncodeError under a non-UTF-8 platform default encoding.
        self.log = open(self.log_file, "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        return False
206
+
207
+
208
# Send both stdout and stderr through one shared Logger instance so that
# everything printed is also captured in the log file.
sys.stdout = sys.stderr = Logger()


# Only errors are reported through the logging module; they go to the
# redirected stdout configured above.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.ERROR,
)
218
+
219
+
220
def read_logs():
    """Flush the redirected stdout and return the full text of its log file.

    Expects ``sys.stdout`` to expose a ``log_file`` attribute holding the log
    path.
    """
    stream = sys.stdout
    stream.flush()
    # Decode with the encoding the log handle was opened with (when it is
    # exposed) instead of re-guessing the platform default, and never let an
    # undecodable byte crash the log viewer.
    log_handle = getattr(stream, "log", None)
    encoding = getattr(log_handle, "encoding", None)
    with open(stream.log_file, "r", encoding=encoding, errors="replace") as log:
        return log.read()
224
+
225
+
226
# Script entry point: load the model once, then build and serve the Gradio UI.
if __name__ == "__main__":

    REFERENCE_AUDIO = os.path.join(SCRIPT_DIR, "audio.wav")

    print("start loading model")
    # Use the script-relative MODEL_DIR so the checkpoint location does not
    # depend on the directory the app happens to be launched from.
    XTTS_MODEL = load_model(checkpoint_dir=MODEL_DIR)

    with gr.Blocks() as demo:
        intro = (
            "\n"
            "# Fake giọng Demo\n"
            "Customize from HuggingFace: [viXTTS](https://huggingface.co/capleaf/viXTTS)\n"
        )
        gr.Markdown(intro)
        with gr.Row():
            # Left column: all inputs.
            with gr.Column():
                speaker_reference_audio = gr.Audio(
                    label="Giọng đọc mẫu:",
                    value=REFERENCE_AUDIO,
                    type="filepath",
                )

                tts_language = gr.Dropdown(
                    label="Language",
                    value="vi",
                    choices=[
                        "vi",
                        "en",
                        "es",
                        "fr",
                        "de",
                        "it",
                        "pt",
                        "pl",
                        "tr",
                        "ru",
                        "nl",
                        "cs",
                        "ar",
                        "zh",
                        "hu",
                        "ko",
                        "ja",
                    ],
                )

                normalize_text = gr.Checkbox(
                    label="Normalize Input Text",
                    value=True,
                )

                tts_text = gr.Textbox(
                    label="Input Text.",
                    value="Xin chào, tôi là một công cụ chuyển đổi văn bản thành giọng nói tiếng Việt được phát triển bởi nhóm Nón lá.",
                )
                tts_btn = gr.Button(value="Inference", variant="primary")

            # Right column: status label and generated audio.
            with gr.Column():
                progress_gen = gr.Label(label="Progress:")
                tts_output_audio = gr.Audio(label="Kết quả.")

        # run_tts receives the inputs in this order and fills the two outputs.
        tts_btn.click(
            fn=run_tts,
            inputs=[
                tts_language,
                tts_text,
                speaker_reference_audio,
                normalize_text,
            ],
            outputs=[progress_gen, tts_output_audio],
        )

    demo.launch()
audio.wav ADDED
Binary file (459 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TTS @ git+https://github.com/thinhlpg/TTS.git@ff217b3f27b294de194cc59c5119d1e08b06413c
2
+
3
+ gradio
4
+ deepfilternet==0.5.6
5
+ vinorm==2.0.7
6
+ underthesea==6.8.0
7
+ deepspeed
8
+ cutlet
9
+ unidic
10
+
11
+ huggingface-hub~=0.27.0
12
+
13
+ torch~=2.2.2
14
+ torchaudio~=2.2.2
15
+ Unidecode~=1.3.8