"""Gradio demo for Kokoro TTS with a selectable phonemizer, model, and voice."""

import functools
import os

import gradio as gr
import torch

# Import eSpeak TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)

# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)

# Not referenced in this module; kept in case the import has side effects
# that the pipelines rely on -- TODO confirm and drop if it does not.
from pretrained_models import Kokoro  # noqa: F401

# ---------------------------------------------------------------------
# Path to models and voicepacks
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = "pretrained_models/Kokoro/voices"

# Sample rate (Hz) reported to Gradio alongside the generated waveform.
# Kokoro-82M is documented as producing 24 kHz audio; labelling it 22050
# makes playback slow and low-pitched.
# NOTE(review): confirm against tts_cli / tts_cli_op if they resample.
SAMPLE_RATE = 24000


# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def _list_files(directory: str, suffix: str) -> list[str]:
    """Return the sorted file names in *directory* that end with *suffix*."""
    return sorted(f for f in os.listdir(directory) if f.endswith(suffix))


def get_models() -> list[str]:
    """Return the sorted ``.pth`` file names found in ``MODELS_DIR``.

    Raises:
        OSError: If ``MODELS_DIR`` cannot be listed (e.g. it does not exist).
    """
    return _list_files(MODELS_DIR, ".pth")


def get_voices() -> list[str]:
    """Return the sorted ``.pt`` file names found in ``VOICES_DIR``.

    Raises:
        OSError: If ``VOICES_DIR`` cannot be listed (e.g. it does not exist).
    """
    return _list_files(VOICES_DIR, ".pt")


# ---------------------------------------------------------------------
# We'll map engine selection -> (build_model_func, generate_func)
# ---------------------------------------------------------------------
ENGINES = {
    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
    "openphonemizer": (build_model_open, generate_long_form_tts_open),
}


# maxsize=1 keeps only the most recently used model resident, so memory use
# matches the previous load-per-request behaviour while repeated requests
# with the same selection skip the reload. A model file replaced on disk
# under the same name is only picked up after a different selection is
# loaded or the app is restarted.
@functools.lru_cache(maxsize=1)
def _load_model(engine: str, model_path: str, device: str):
    """Build the model for *engine* from *model_path* and switch it to eval mode."""
    build_fn, _ = ENGINES[engine]
    model = build_fn(model_path, device=device)

    # The builder returns a mapping of named parts; only those exposing
    # ``eval`` are switched to inference mode.
    for submodule in model.values():
        if hasattr(submodule, "eval"):
            submodule.eval()
    return model


# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """Synthesize *text* and return it in the form ``gr.Audio`` accepts.

    Args:
        text: Input string to speak.
        engine: Key of ``ENGINES`` ("espeak" or "openphonemizer").
        model_file: Selected ``.pth`` file name from ``MODELS_DIR``.
        voice_file: Selected ``.pt`` file name from ``VOICES_DIR``.
        speed: Speech speed, forwarded to the generation function.

    Returns:
        A ``(sample_rate, audio)`` tuple, where ``audio`` is whatever the
        selected pipeline's generate function returns as its first element.

    Raises:
        gr.Error: If the text is blank, the engine is unknown, or no model
            or voice is selected. Gradio shows the message to the user.
    """
    # Validate up front so the user sees an actionable message instead of a
    # TypeError from os.path.join(None) or a failure deep in the pipeline.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")
    if engine not in ENGINES:
        raise gr.Error(f"Unknown phonemizer: {engine!r}")
    if not model_file:
        raise gr.Error(f"No model selected. Add a .pth file to {MODELS_DIR}.")
    if not voice_file:
        raise gr.Error(f"No voice selected. Add a .pt file to {VOICES_DIR}.")

    # 1) Map engine to the correct generate_long_form_tts
    _, gen_fn = ENGINES[engine]

    # 2) Prepare paths
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)

    # 3) Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4) Load model (cached between calls for an unchanged selection)
    model = _load_model(engine, model_path, device)

    # 5) Load voicepack
    # NOTE(review): torch.load unpickles the file; only keep trusted files
    # in VOICES_DIR (consider weights_only=True if the torch version allows).
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # 6) Generate TTS; the second return value (phonemes) is not used here.
    audio, _phonemes = gen_fn(model, text, voicepack, speed=speed)

    return (SAMPLE_RATE, audio)  # Gradio expects (sample_rate, np_array)


# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    """Build and return the Gradio Blocks UI for the demo.

    The model and voice dropdowns are populated once, from the directory
    contents at the time this function is called.
    """
    model_list = get_models()
    voice_list = get_voices()

    css = """
    h4 {
        text-align: center;
        display:block;
    }
    h2 {
        text-align: center;
        display:block;
    }
    """

    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("## Kokoro TTS Demo: Choose engine, model, and voice")

        # Row 1: Text input
        text_input = gr.Textbox(
            label="Input Text",
            value=(
                "Hello, world! Testing both eSpeak and OpenPhonemizer. "
                "Can you believe that we live in 2024 and have access to "
                "advanced AI?"
            ),
            lines=3,
        )

        # Row 2: Engine selection (choices come from ENGINES so the UI and
        # the dispatch table cannot drift apart)
        engine_dropdown = gr.Dropdown(
            choices=list(ENGINES),
            value="openphonemizer",
            label="Phonemizer",
        )

        # Row 3: Model dropdown
        model_dropdown = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else None,
            label="Model (.pth)",
        )

        # Row 4: Voice dropdown
        voice_dropdown = gr.Dropdown(
            choices=voice_list,
            value=voice_list[0] if voice_list else None,
            label="Voice (.pt)",
        )

        # Row 5: Speed slider
        speed_slider = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
        )

        # Generate button + audio output
        generate_btn = gr.Button("Generate")
        tts_output = gr.Audio(label="TTS Output")

        # Connect the button to our inference function
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input,
                engine_dropdown,
                model_dropdown,
                voice_dropdown,
                speed_slider,
            ],
            outputs=tts_output,
        )

        gr.Markdown(
            "#### Kokoro TTS Demo based on [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
        )

    return demo


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()