File size: 4,282 Bytes
1945c48
1785140
 
 
b865d16
 
a627d55
b865d16
 
 
a627d55
1785140
1945c48
fb4364b
 
 
a627d55
3cc5048
1785140
a627d55
 
 
 
b865d16
3408722
 
 
 
 
 
 
 
 
 
 
 
 
 
a627d55
b865d16
 
 
 
 
fb4364b
 
b865d16
a627d55
 
3408722
0dcc709
b865d16
 
a627d55
33b51a6
b865d16
 
 
3408722
 
 
 
 
 
 
 
fb4364b
 
 
3408722
fb4364b
 
 
 
3408722
 
 
fb4364b
 
 
 
 
 
 
 
 
 
 
 
 
 
b865d16
 
 
 
a627d55
a7161ed
 
b865d16
a627d55
b865d16
a7161ed
 
 
1945c48
a627d55
1945c48
b865d16
0dcc709
6f17cca
1945c48
 
3408722
0dcc709
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import torch
from TTS.api import TTS
import os
import tempfile
import torchaudio
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Accept the Coqui terms of service up front via COQUI_TOS_AGREED=1
# (presumably lets the TTS library skip its interactive agreement prompt).
os.environ["COQUI_TOS_AGREED"] = "1"

# Keep PyTorch's default dtype at float32. A global float16 default does not
# speed up inference on most CPUs (they lack a fast fp16 path). It can also
# make model construction, CPU inference or torchaudio.save fail, because
# several CPU kernels and audio backends do not support half precision.
torch.set_default_dtype(torch.float32)

# Target device for inference.
# NOTE(review): not referenced by the code visible here; the model is never
# moved with .to(device).
device = "cpu"

# Download the XTTS checkpoint, its JSON config and its vocabulary file from
# the Hugging Face Hub. The downloads run in the same order as the names they
# are unpacked into.
_HF_REPO_ID = "RedSparkie/danielmula"
model_path, config_path, vocab_path = (
    hf_hub_download(repo_id=_HF_REPO_ID, filename=_name)
    for _name in ("model.pth", "config.json", "vocab.json")
)

# Funci贸n para resamplear el audio a 24000 Hz y convertirlo a 16 bits
def preprocess_audio(audio_path, target_sr=24000):
    waveform, original_sr = torchaudio.load(audio_path)
    
    # Resamplear si la frecuencia de muestreo es diferente
    if original_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
        waveform = resampler(waveform)
    
    # Convertir a 16 bits
    waveform = waveform * (2**15)  # Escalar para 16 bits
    waveform = waveform.to(torch.int16)  # Convertir a formato de 16 bits
    return waveform, target_sr

# Global handle to the loaded XTTS model; stays None until load_model() succeeds.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Build the XTTS model from its config and load the checkpoint weights.

    The model is published to the module-level ``XTTS_MODEL`` only after the
    checkpoint has loaded successfully. A failed load therefore never leaves
    a half-initialized model behind that passes ``XTTS_MODEL is None`` checks.

    Args:
        xtts_checkpoint: Path to the model checkpoint (e.g. ``model.pth``).
        xtts_config: Path to the JSON model config.
        xtts_vocab: Path to the vocabulary JSON (e.g. ``vocab.json``).
    """
    global XTTS_MODEL
    config = XttsConfig()
    config.load_json(xtts_config)

    # Build into a local first; only expose it once the weights are in.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
    XTTS_MODEL = model
    print("Model Loaded!")

# Voice-cloning TTS using the globally loaded XTTS model.
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file``.

    Args:
        lang: Language code passed to ``XTTS_MODEL.inference`` (e.g. ``'es'``).
        tts_text: Text to synthesize.
        speaker_audio_file: Path to the reference voice recording.

    Returns:
        ``(out_path, speaker_audio_file)``. ``out_path`` is a new temporary
        WAV file holding the generated speech at 24000 Hz; it is left on disk
        for the caller to use. ``speaker_audio_file`` is returned unchanged.

    Raises:
        gr.Error: if the model is not loaded, no reference audio was given,
            or the conditioning latents could not be computed.
    """
    # Raise rather than return an error tuple: the success path returns two
    # values, so a differently sized error tuple broke callers that unpack it.
    if XTTS_MODEL is None:
        raise gr.Error("You need to run the previous step to load the model !!")
    if not speaker_audio_file:
        raise gr.Error("You need to provide a reference audio file.")

    # Resample the reference to 24 kHz int16 and write it to a temporary WAV
    # for get_conditioning_latents, which takes a file path.
    waveform, sr = preprocess_audio(speaker_audio_file)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        processed_audio_path = fp.name
    # Write only after the handle is closed. Some platforms (Windows) cannot
    # reopen a NamedTemporaryFile that is still open.
    try:
        torchaudio.save(processed_audio_path, waveform, sr)

        # inference_mode disables autograd tracking for speed.
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
                audio_path=processed_audio_path,
                gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
                max_ref_length=XTTS_MODEL.config.max_ref_len,
                sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
            )

            if gpt_cond_latent is None or speaker_embedding is None:
                raise gr.Error("Failed to process the audio file.")

            out = XTTS_MODEL.inference(
                text=tts_text,
                language=lang,
                gpt_cond_latent=gpt_cond_latent,
                speaker_embedding=speaker_embedding,
                temperature=XTTS_MODEL.config.temperature,
                length_penalty=XTTS_MODEL.config.length_penalty,
                repetition_penalty=XTTS_MODEL.config.repetition_penalty,
                top_k=XTTS_MODEL.config.top_k,
                top_p=XTTS_MODEL.config.top_p,
            )
    finally:
        # The processed reference is only needed for conditioning. Remove it
        # so each request does not leak a temporary file.
        os.remove(processed_audio_path)

    # Convert to a (1, samples) float32 tensor without mutating `out`:
    # - unsqueeze assumes out["wav"] is 1-D (TODO confirm);
    # - as_tensor accepts both arrays and tensors;
    # - float32 because torchaudio.save may reject e.g. float16 output.
    wav = torch.as_tensor(out["wav"], dtype=torch.float32).unsqueeze(0)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    # 24000 Hz is hard-coded; assumed to match the model's output rate.
    torchaudio.save(out_path, wav, 24000)
    print("Speech generated!")

    return out_path, speaker_audio_file

# Gradio callback.
def generate(text, audio):
    """Synthesize ``text`` in Spanish (``lang='es'``) with the voice in ``audio``.

    The XTTS model is loaded on the first call only and reused afterwards,
    instead of re-reading the checkpoint on every request.

    Args:
        text: Sentence to synthesize.
        audio: File path of the reference voice recording (presumably from
            a ``gr.Audio(type='filepath')`` input).

    Returns:
        Path of the generated WAV file.
    """
    global XTTS_MODEL
    if XTTS_MODEL is None:
        try:
            load_model(model_path, config_path, vocab_path)
        except Exception:
            # Do not keep a partially initialized model around; the next
            # request should retry the load instead of using it.
            XTTS_MODEL = None
            raise
    out_path, _ = run_tts(lang='es', tts_text=text, speaker_audio_file=audio)
    return out_path

# Assemble the Gradio UI. Inputs are a sentence and a reference voice clip;
# the output is the synthesized speech, handed over as a file path.
_text_box = gr.Textbox(label='Frase a generar')
_reference_voice = gr.Audio(type='filepath', label='Voz de referencia')
demo = gr.Interface(
    fn=generate,
    inputs=[_text_box, _reference_voice],
    outputs=gr.Audio(type='filepath'),
)

# Start the web server and also request a public share link.
demo.launch(share=True)