RedSparkie committed on
Commit
fb4364b
verified
1 Parent(s): e408d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -34
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from TTS.api import TTS
@@ -12,6 +11,9 @@ from TTS.tts.models.xtts import Xtts
12
  # Aceptar los términos de COQUI
13
  os.environ["COQUI_TOS_AGREED"] = "1"
14
 
 
 
 
15
  # Definir el dispositivo como CPU
16
  device = "cpu"
17
 
@@ -20,28 +22,19 @@ model_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="model.pt
20
  config_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="config.json")
21
  vocab_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="vocab.json")
22
 
23
- # Función para limpiar la caché de GPU (no necesaria para CPU, pero la mantengo por si en el futuro usas GPU)
24
- def clear_gpu_cache():
25
- if torch.cuda.is_available():
26
- torch.cuda.empty_cache()
27
-
28
  # Cargar el modelo XTTS
29
  XTTS_MODEL = None
30
  def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
31
  global XTTS_MODEL
32
- clear_gpu_cache()
33
- if not xtts_checkpoint or not xtts_config or not xtts_vocab:
34
- return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
35
  config = XttsConfig()
36
  config.load_json(xtts_config)
 
 
37
  XTTS_MODEL = Xtts.init_from_config(config)
38
  print("Loading XTTS model!")
39
- XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
40
 
41
- # No mover a GPU ya que usamos CPU
42
- # if torch.cuda.is_available():
43
- # XTTS_MODEL.cuda()
44
-
45
  print("Model Loaded!")
46
 
47
  # Función para ejecutar TTS
@@ -49,24 +42,28 @@ def run_tts(lang, tts_text, speaker_audio_file):
49
  if XTTS_MODEL is None or not speaker_audio_file:
50
  return "You need to run the previous step to load the model !!", None, None
51
 
52
- gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
53
- audio_path=speaker_audio_file,
54
- gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
55
- max_ref_length=XTTS_MODEL.config.max_ref_len,
56
- sound_norm_refs=XTTS_MODEL.config.sound_norm_refs
57
- )
58
- out = XTTS_MODEL.inference(
59
- text=tts_text,
60
- language=lang,
61
- gpt_cond_latent=gpt_cond_latent,
62
- speaker_embedding=speaker_embedding,
63
- temperature=XTTS_MODEL.config.temperature,
64
- length_penalty=XTTS_MODEL.config.length_penalty,
65
- repetition_penalty=XTTS_MODEL.config.repetition_penalty,
66
- top_k=XTTS_MODEL.config.top_k,
67
- top_p=XTTS_MODEL.config.top_p,
68
- )
69
-
 
 
 
 
70
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
71
  out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
72
  out_path = fp.name
@@ -76,7 +73,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
76
  return out_path, speaker_audio_file
77
 
78
  # Definir la función para Gradio
79
- @spaces.GPU(enable_queue=True)
80
  def generate(text, audio):
81
  load_model(model_path, config_path, vocab_path)
82
  out_path, speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
@@ -85,7 +81,7 @@ def generate(text, audio):
85
  # Configurar la interfaz de Gradio
86
  demo = gr.Interface(
87
  fn=generate,
88
- inputs=[gr.Textbox(label='Frase a generar'), gr.Audio(type='filepath', label='Voz de referencia')],
89
  outputs=gr.Audio(type='filepath')
90
  )
91
 
 
 
1
  import gradio as gr
2
  import torch
3
  from TTS.api import TTS
 
11
  # Aceptar los términos de COQUI
12
  os.environ["COQUI_TOS_AGREED"] = "1"
13
 
14
+ # Establecer precisión reducida para acelerar en CPU
15
+ torch.set_default_dtype(torch.float16)
16
+
17
  # Definir el dispositivo como CPU
18
  device = "cpu"
19
 
 
22
  config_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="config.json")
23
  vocab_path = hf_hub_download(repo_id="RedSparkie/danielmula", filename="vocab.json")
24
 
 
 
 
 
 
25
  # Cargar el modelo XTTS
26
  XTTS_MODEL = None
27
  def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
28
  global XTTS_MODEL
 
 
 
29
  config = XttsConfig()
30
  config.load_json(xtts_config)
31
+
32
+ # Inicializar el modelo
33
  XTTS_MODEL = Xtts.init_from_config(config)
34
  print("Loading XTTS model!")
 
35
 
36
+ # Cargar el checkpoint del modelo con `weights_only=True` para evitar advertencias de seguridad
37
+ XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False, weights_only=True)
 
 
38
  print("Model Loaded!")
39
 
40
  # Función para ejecutar TTS
 
42
  if XTTS_MODEL is None or not speaker_audio_file:
43
  return "You need to run the previous step to load the model !!", None, None
44
 
45
+ # Usar inference_mode para mejorar el rendimiento
46
+ with torch.inference_mode():
47
+ gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
48
+ audio_path=speaker_audio_file,
49
+ gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
50
+ max_ref_length=XTTS_MODEL.config.max_ref_len,
51
+ sound_norm_refs=XTTS_MODEL.config.sound_norm_refs
52
+ )
53
+
54
+ out = XTTS_MODEL.inference(
55
+ text=tts_text,
56
+ language=lang,
57
+ gpt_cond_latent=gpt_cond_latent,
58
+ speaker_embedding=speaker_embedding,
59
+ temperature=XTTS_MODEL.config.temperature,
60
+ length_penalty=XTTS_MODEL.config.length_penalty,
61
+ repetition_penalty=XTTS_MODEL.config.repetition_penalty,
62
+ top_k=XTTS_MODEL.config.top_k,
63
+ top_p=XTTS_MODEL.config.top_p,
64
+ )
65
+
66
+ # Guardar el audio generado en un archivo temporal
67
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
68
  out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
69
  out_path = fp.name
 
73
  return out_path, speaker_audio_file
74
 
75
  # Definir la función para Gradio
 
76
  def generate(text, audio):
77
  load_model(model_path, config_path, vocab_path)
78
  out_path, speaker_audio_file = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
 
81
  # Configurar la interfaz de Gradio
82
  demo = gr.Interface(
83
  fn=generate,
84
+ inputs=[gr.Textbox(label='Texto:'), gr.Audio(type='filepath', label='Voz de referencia')],
85
  outputs=gr.Audio(type='filepath')
86
  )
87