Spaces:
Running
Running
import functools
import os

import gradio as gr
import torch

# Import eSpeak TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)
# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)
from pretrained_models import Kokoro
# ---------------------------------------------------------------------
# Locations of the Kokoro checkpoints (.pth) and voicepacks (.pt)
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
# Voicepacks live in the "voices" sub-folder of the model directory.
VOICES_DIR = MODELS_DIR + "/voices"
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def _list_files(directory, extension):
    """Return the sorted names of regular files in ``directory`` ending in ``extension``.

    Sub-directories are skipped even if their name matches, since they cannot
    be loaded as a checkpoint or voicepack.

    Raises:
        FileNotFoundError: if ``directory`` does not exist.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(extension) and os.path.isfile(os.path.join(directory, name))
    )


def get_models(directory=None):
    """Return the sorted ``.pth`` model checkpoint names.

    Args:
        directory: Folder to scan; defaults to ``MODELS_DIR`` when None.
    """
    return _list_files(MODELS_DIR if directory is None else directory, ".pth")


def get_voices(directory=None):
    """Return the sorted ``.pt`` voicepack names.

    Args:
        directory: Folder to scan; defaults to ``VOICES_DIR`` when None.
    """
    return _list_files(VOICES_DIR if directory is None else directory, ".pt")
# ---------------------------------------------------------------------
# Phonemizer engine name -> (build_model, generate_long_form_tts)
# ---------------------------------------------------------------------
_ESPEAK_PIPELINE = (build_model_espeak, generate_long_form_tts_espeak)
_OPENPHONEMIZER_PIPELINE = (build_model_open, generate_long_form_tts_open)

ENGINES = {
    "espeak": _ESPEAK_PIPELINE,
    "openphonemizer": _OPENPHONEMIZER_PIPELINE,
}
# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
# Sample rate reported to Gradio. Kokoro-82M generates 24 kHz audio (see the
# model card's usage example); labelling it 22.05 kHz made playback slower
# and lower-pitched than intended.
SAMPLE_RATE = 24000


@functools.lru_cache(maxsize=2)
def _load_model(engine, model_path, device):
    """Build the ``engine`` model from ``model_path`` on ``device``, memoised.

    Repeated clicks reuse an already-built network instead of rebuilding it
    from disk; only the two most recently used (engine, checkpoint, device)
    combinations are kept alive.
    """
    build_fn, _ = ENGINES[engine]
    model = build_fn(model_path, device=device)
    # Switch every submodule that supports it to evaluation mode.
    for submodule in model.values():
        if hasattr(submodule, "eval"):
            submodule.eval()
    return model


def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """Synthesize ``text`` and return it in the form ``gr.Audio`` expects.

    Args:
        text: Input string to speak.
        engine: Key into ``ENGINES`` ("espeak" or "openphonemizer").
        model_file: Name of a ``.pth`` checkpoint inside ``MODELS_DIR``.
        voice_file: Name of a ``.pt`` voicepack inside ``VOICES_DIR``.
        speed: Speech speed, forwarded to the selected generator.

    Returns:
        tuple: ``(SAMPLE_RATE, audio)``, where ``audio`` is the first value
        returned by the selected ``generate_long_form_tts``.

    Raises:
        gr.Error: if no model or no voice is selected.
    """
    # The dropdowns hold None when their folder is empty; report that clearly
    # instead of failing with a TypeError inside os.path.join.
    if not model_file or not voice_file:
        raise gr.Error("Please select both a model (.pth) and a voice (.pt).")

    _, gen_fn = ENGINES[engine]
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = _load_model(engine, model_path, device)

    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        audio, _phonemes = gen_fn(model, text, voicepack, speed=speed)
    return SAMPLE_RATE, audio
# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    """Assemble the Kokoro TTS demo UI.

    The model and voice dropdowns are filled once, when the app is built, from
    ``get_models()`` and ``get_voices()``. The phonemizer choices are the keys
    of ``ENGINES``, so the UI always matches what ``tts_inference`` can
    dispatch to. Clicking "Generate" calls ``tts_inference`` with the text,
    engine, model file, voice file and speed.

    Returns:
        gr.Blocks: the assembled app; the caller launches it.
    """
    model_list = get_models()
    voice_list = get_voices()
    # Center the page title (h2) and the footer line (h4).
    css = """
    h2, h4 {
        text-align: center;
        display: block;
    }
    """
    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## Kokoro TTS Demo: Choose phonemizer, model, and voice")

        # Text to synthesize
        text_input = gr.Textbox(
            label="Input Text",
            value=(
                "Hello, world! Testing both eSpeak and OpenPhonemizer. "
                "Can you believe that we live in 2025 and have access to advanced AI?"
            ),
            lines=3,
        )
        # Phonemizer engine; choices mirror the ENGINES dispatch table.
        engine_dropdown = gr.Dropdown(
            choices=list(ENGINES),
            value="openphonemizer",
            label="Phonemizer",
        )
        # Model checkpoint; no default when the models folder is empty.
        model_dropdown = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else None,
            label="Model (.pth)",
        )
        # Voicepack; no default when the voices folder is empty.
        voice_dropdown = gr.Dropdown(
            choices=voice_list,
            value=voice_list[0] if voice_list else None,
            label="Voice (.pt)",
        )
        # Speed multiplier passed straight through to the generator.
        speed_slider = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
        )

        # Generate button + audio output
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")

        # Input order must match tts_inference's positional parameters.
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                engine_dropdown,
                model_dropdown,
                voice_dropdown,
                speed_slider,
            ],
            outputs=tts_output,
        )

        gr.Markdown(
            "#### Kokoro TTS Demo based on "
            "[Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
        )
    return demo
# ---------------------------------------------------------------------
# Main: build the UI and serve it when run as a script
# ---------------------------------------------------------------------
if __name__ == "__main__":
    create_gradio_app().launch()