Files changed (2)
  1. app.py +45 -28
  2. requirements.txt +2 -1
app.py CHANGED
@@ -19,23 +19,48 @@ from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
 
 from transformers import SeamlessM4TFeatureExtractor
 
-# import whisperx
+import whisper
+import langid
 
 processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# whisper_model = whisperx.load_model("small", "cuda", compute_type="int8")
+whisper_model = whisper.load_model("turbo")
 
-# @torch.no_grad()
-# def get_prompt_text(speech_16k):
-#     asr_result = whisper_model.transcribe(speech_16k)
-#     print("asr_result:", asr_result)
-#     language = asr_result["language"]
-#     # text = asr_result["text"]  # whisper asr result
-#     text = asr_result["segments"][0]["text"]
-#     print("prompt text:", text)
-#     return text, language
+def detect_speech_language(speech_file):
+    # load audio and pad/trim it to fit 30 seconds
+    audio = whisper.load_audio(speech_file)
+    audio = whisper.pad_or_trim(audio)
+
+    # make a log-Mel spectrogram and move it to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
+
+    # detect the spoken language
+    _, probs = whisper_model.detect_language(mel)
+    return max(probs, key=probs.get)
+
+
+def detect_text_language(text):
+    return langid.classify(text)[0]
+
+
+@torch.no_grad()
+def get_prompt_text(speech_16k, language):
+    full_prompt_text = ""
+    short_prompt_text = ""
+    short_prompt_end_ts = 0.0
+
+    asr_result = whisper_model.transcribe(speech_16k, language=language)
+    full_prompt_text = asr_result["text"]  # whisper asr result
+    # text = asr_result["segments"][0]["text"]  # whisperx asr result
+    short_prompt_text = ""
+    short_prompt_end_ts = 0.0
+    for segment in asr_result["segments"]:
+        short_prompt_text = short_prompt_text + segment["text"]
+        short_prompt_end_ts = segment["end"]
+        if short_prompt_end_ts >= 4:
+            break
+    return full_prompt_text, short_prompt_text, short_prompt_end_ts
 
 
 def g2p_(text, language):
@@ -279,10 +304,7 @@ def load_models():
 @torch.no_grad()
 def maskgct_inference(
     prompt_speech_path,
-    prompt_text,
     target_text,
-    language="en",
-    target_language="en",
     target_len=None,
     n_timesteps=25,
     cfg=2.5,
@@ -295,14 +317,18 @@
     speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
     speech = librosa.load(prompt_speech_path, sr=24000)[0]
 
-    # if prompt_text is None:
-    #     prompt_text, language = get_prompt_text(prompt_speech_path)
-
+    prompt_language = detect_speech_language(prompt_speech_path)
+    full_prompt_text, short_prompt_text, short_prompt_end_ts = get_prompt_text(prompt_speech_path,
+                                                                               prompt_language)
+    # use only the first 4+ seconds of the wav as the prompt in case the prompt wav is too long
+    speech = speech[0: int(short_prompt_end_ts * 24000)]
+    speech_16k = speech_16k[0: int(short_prompt_end_ts * 16000)]
+    target_language = detect_text_language(target_text)
     combine_semantic_code, _ = text2semantic(
         device,
         speech_16k,
-        prompt_text,
-        language,
+        short_prompt_text,
+        prompt_language,
         target_text,
         target_language,
         target_len,
@@ -326,21 +352,15 @@
 @spaces.GPU
 def inference(
     prompt_wav,
-    prompt_text,
     target_text,
     target_len,
     n_timesteps,
-    language,
-    target_language,
 ):
     save_path = "./output/output.wav"
     os.makedirs("./output", exist_ok=True)
     recovered_audio = maskgct_inference(
         prompt_wav,
-        prompt_text,
         target_text,
-        language,
-        target_language,
         target_len=target_len,
         n_timesteps=int(n_timesteps),
         device=device,
@@ -369,7 +389,6 @@ iface = gr.Interface(
     fn=inference,
     inputs=[
         gr.Audio(label="Upload Prompt Wav", type="filepath"),
-        gr.Textbox(label="Prompt Text"),
        gr.Textbox(label="Target Text"),
        gr.Number(
            label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.", value=-1
@@ -377,8 +396,6 @@
        gr.Slider(
            label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
        ),
-        gr.Dropdown(label="Language", choices=language_list, value="en"),
-        gr.Dropdown(label="Target Language", choices=language_list, value="en"),
     ],
     outputs=gr.Audio(label="Generated Audio"),
     title="MaskGCT TTS Demo",
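For reference, a minimal standalone sketch of the short-prompt selection that get_prompt_text and maskgct_inference now perform: accumulate ASR segments until at least 4 seconds of speech are covered, then cut the prompt waveforms at that timestamp. The segments data below is made up; in app.py it comes from whisper_model.transcribe.

# Hypothetical segment dicts mimicking the shape of whisper's transcribe() output.
segments = [
    {"text": " Hello there,", "end": 2.1},
    {"text": " this is the speech prompt.", "end": 4.6},
    {"text": " anything after this is dropped.", "end": 9.0},
]

short_prompt_text = ""
short_prompt_end_ts = 0.0
for segment in segments:
    short_prompt_text += segment["text"]
    short_prompt_end_ts = segment["end"]
    if short_prompt_end_ts >= 4:  # stop once roughly 4 s of audio is covered
        break

# The same timestamp then slices both resamplings of the prompt, e.g.
# speech[: int(short_prompt_end_ts * 24000)] at 24 kHz and
# speech_16k[: int(short_prompt_end_ts * 16000)] at 16 kHz.
print(short_prompt_text)    # " Hello there, this is the speech prompt."
print(short_prompt_end_ts)  # 4.6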
requirements.txt CHANGED
@@ -30,4 +30,5 @@ LangSegment
 onnxruntime
 pyopenjtalk
 pykakasi
-openai-whisper
+openai-whisper
+langid
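As a quick sanity check that the two new dependencies resolve after pip install -r requirements.txt, something like the following should run (the test sentence is arbitrary; "turbo" is the checkpoint app.py loads and it downloads on first use):

import langid
import whisper  # provided by the openai-whisper package

# langid.classify returns a (language_code, score) pair; app.py keeps only the code.
print(langid.classify("This is a short test sentence."))

# Loading the model verifies the openai-whisper install; expect a download the first time.
model = whisper.load_model("turbo")
print(model.device)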