RedSparkie's picture
Update app.py
3408722 verified
raw
history blame
4.28 kB
import gradio as gr
import torch
from TTS.api import TTS
import os
import tempfile
import torchaudio
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Accept the Coqui terms of service non-interactively.
os.environ["COQUI_TOS_AGREED"] = "1"

# Make float16 the default floating-point dtype for tensors created from now on.
# NOTE(review): the original comment says this is meant to speed up CPU
# inference; half precision is not always faster (or supported) on CPU - confirm.
torch.set_default_dtype(torch.float16)

# Device the app is meant to run on.
device = "cpu"

# Fetch the checkpoint, config and vocabulary files from the Hugging Face Hub,
# in that order.
model_path, config_path, vocab_path = (
    hf_hub_download(repo_id="RedSparkie/danielmula", filename=hub_file)
    for hub_file in ("model.pth", "config.json", "vocab.json")
)
def preprocess_audio(audio_path, target_sr=24000):
    """Load an audio file, resample it to ``target_sr`` and convert it to int16.

    Args:
        audio_path: Path of the audio file, passed to ``torchaudio.load``.
        target_sr: Sample rate in Hz of the returned waveform.

    Returns:
        A ``(waveform, target_sr)`` tuple where ``waveform`` is a
        ``torch.int16`` tensor.
    """
    waveform, original_sr = torchaudio.load(audio_path)

    # Resample only when the file is not already at the target rate.
    if original_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # Scale to the 16-bit range, then clamp before casting: a full-scale
    # sample (+1.0) maps to 32768, which does not fit in int16, and the
    # resampler can overshoot slightly past full scale.
    # Assumes torchaudio.load returned normalized float samples (its default).
    waveform = torch.clamp(waveform * (2 ** 15), min=-(2 ** 15), max=2 ** 15 - 1)
    waveform = waveform.to(torch.int16)
    return waveform, target_sr
# Module-level handle to the loaded XTTS model; stays None until load_model succeeds.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Build an XTTS model from its config and checkpoint and store it in ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path of the model checkpoint file.
        xtts_config: Path of the JSON config file.
        xtts_vocab: Path of the vocabulary file.

    The global is assigned only after the checkpoint has loaded, so a failed
    load never leaves a half-initialised model behind a non-``None`` handle.
    """
    global XTTS_MODEL

    config = XttsConfig()
    config.load_json(xtts_config)

    # Build into a local first; publish to the global once loading succeeded.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    model.load_checkpoint(
        config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        use_deepspeed=False,
    )
    XTTS_MODEL = model
    print("Model Loaded!")
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file``.

    Args:
        lang: Language code passed to the model's ``inference`` call.
        tts_text: Text to synthesize.
        speaker_audio_file: Path of the reference voice recording.

    Returns:
        A ``(out_path, speaker_audio_file)`` tuple, where ``out_path`` is the
        path of the generated WAV file (saved at 24000 Hz).

    Raises:
        gr.Error: If the model is not loaded, no reference audio was given,
            or no conditioning latents could be computed from the audio.
    """
    # Failures are raised instead of returned: the success path returns a
    # 2-tuple, so a differently shaped "error tuple" would break callers
    # that unpack the result.
    if XTTS_MODEL is None:
        raise gr.Error("You need to run the previous step to load the model !!")
    if not speaker_audio_file:
        raise gr.Error("You need to provide a reference audio file.")

    # Preprocess the reference audio (resample to 24000 Hz, convert to 16 bit).
    waveform, sr = preprocess_audio(speaker_audio_file)

    # Reserve a temporary path for the processed reference; the handle is
    # closed before writing so the file can be reopened by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        processed_audio_path = fp.name

    try:
        torchaudio.save(processed_audio_path, waveform, sr)

        # No autograd tracking is needed for synthesis.
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
                audio_path=processed_audio_path,
                gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
                max_ref_length=XTTS_MODEL.config.max_ref_len,
                sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
            )
            if gpt_cond_latent is None or speaker_embedding is None:
                raise gr.Error("Failed to process the audio file.")

            out = XTTS_MODEL.inference(
                text=tts_text,
                language=lang,
                gpt_cond_latent=gpt_cond_latent,
                speaker_embedding=speaker_embedding,
                temperature=XTTS_MODEL.config.temperature,
                length_penalty=XTTS_MODEL.config.length_penalty,
                repetition_penalty=XTTS_MODEL.config.repetition_penalty,
                top_k=XTTS_MODEL.config.top_k,
                top_p=XTTS_MODEL.config.top_p,
            )
    finally:
        # Best-effort removal of the intermediate reference file, which was
        # previously left behind on every call.
        try:
            os.remove(processed_audio_path)
        except OSError:
            pass

    # Add a leading dimension before saving the generated audio.
    wav = torch.tensor(out["wav"]).unsqueeze(0)

    # The output file is kept on disk (delete=False) because its path is
    # returned to the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    torchaudio.save(out_path, wav, 24000)

    print("Speech generated!")
    return out_path, speaker_audio_file
def generate(text, audio):
    """Gradio callback: synthesize ``text`` in Spanish using ``audio`` as the voice.

    Args:
        text: Sentence to synthesize.
        audio: File path of the reference voice recording.

    Returns:
        Path of the generated audio file.

    Raises:
        gr.Error: If no reference audio was provided.
    """
    if not audio:
        raise gr.Error("You need to provide a reference audio file.")

    # Load the model on first use only; reloading the checkpoint from disk
    # on every request is slow and unnecessary.
    if XTTS_MODEL is None:
        load_model(model_path, config_path, vocab_path)

    out_path, _ = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
    return out_path
# Build the Gradio interface: a text box and a reference-voice upload as
# inputs, the generated audio as output.
text_input = gr.Textbox(label='Frase a generar')
voice_input = gr.Audio(type='filepath', label='Voz de referencia')
audio_output = gr.Audio(type='filepath')

demo = gr.Interface(
    fn=generate,
    inputs=[text_input, voice_input],
    outputs=audio_output,
)

# Launch the interface with a public share link.
demo.launch(share=True)