import gradio as gr
from fish_speech import LM
import re
from rustymimi import Tokenizer
from huggingface_hub import snapshot_download, hf_hub_download
import numpy as np
import spaces

# Voice mapping dictionary:
# US voices
# heart (default) -> <|speaker:0|>
# bella -> <|speaker:1|>
# nova -> <|speaker:2|>
# sky -> <|speaker:3|>
# sarah -> <|speaker:4|>
# michael -> <|speaker:5|>
# fenrir -> <|speaker:6|>
# liam -> <|speaker:7|>
# British voices
# emma -> <|speaker:8|>
# isabella -> <|speaker:9|>
# fable -> <|speaker:10|>
voice_mapping = {
    "Heart (US)": "<|speaker:0|>",
    "Bella (US)": "<|speaker:1|>",
    "Nova (US)": "<|speaker:2|>",
    "Sky (US)": "<|speaker:3|>",
    "Sarah (US)": "<|speaker:4|>",
    "Michael (US)": "<|speaker:5|>",
    "Fenrir (US)": "<|speaker:6|>",
    "Liam (US)": "<|speaker:7|>",
    "Emma (UK)": "<|speaker:8|>",
    "Isabella (UK)": "<|speaker:9|>",
    "Fable (UK)": "<|speaker:10|>",
}
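# Each label maps to a speaker control token; synthesize_speech passes the selected
# token to the model as its system prompt (the `sysprompt` argument below).
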
# Initialize models
print("Downloading and initializing models...")


def get_mimi_path():
    """Get Mimi tokenizer weights from Hugging Face."""
    repo_id = "kyutai/moshiko-mlx-bf16"
    filename = "tokenizer-e351c8d8-checkpoint125.safetensors"
    return hf_hub_download(repo_id, filename)

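# Note: both snapshot_download and hf_hub_download resolve to the local Hugging Face
# cache and return filesystem paths, so model files are only fetched on the first run.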
checkpoint_dir = snapshot_download("jkeisling/smoltts_v0")
mimi_path = get_mimi_path()
# lm = LM(checkpoint_dir, dtype="bf16", device="cuda", version="dual_ar")
codec = Tokenizer(mimi_path)
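# `codec` is the Mimi neural audio codec (via the rustymimi bindings); decode() turns
# generated acoustic token frames back into 24 kHz mono PCM samples.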

# Naively split text into sentences
def split_sentences(text):
    sentences = re.split(r"(?<=[?.!])\s+", text)
    return [s.strip() for s in sentences if s.strip()]
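
# For example, split_sentences("Hi there. How are you?") returns
# ["Hi there.", "How are you?"]; text without terminal punctuation comes back
# as a single sentence.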


@spaces.GPU  # request GPU for this call when running on a ZeroGPU Space (no-op elsewhere)
def synthesize_speech(text, temperature, top_p, voice):
    """Generate speech from text using Fish Speech, processing each sentence separately."""
    # The model is instantiated per request; the module-level load above is commented out.
    lm = LM(checkpoint_dir, dtype="bf16", device="cuda", version="dual_ar")
    sysprompt = voice_mapping.get(voice, "<|speaker:0|>")
    sentences = split_sentences(text)
    pcm_list = []
    for sentence in sentences:
        # Generate audio for each sentence individually
        generated = lm([sentence], temp=temperature, top_p=top_p, sysprompt=sysprompt)
        pcm = codec.decode(generated)
        pcm_list.append(pcm.flatten())
    # Concatenate all PCM arrays into one
    final_pcm = np.concatenate(pcm_list)
    return (24_000, final_pcm)
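
# The (sample_rate, samples) tuple returned above is the format gr.Audio(type="numpy")
# expects, so the handler's output can be wired directly to the audio component.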

# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Default(
        font=[gr.themes.GoogleFont("IBM Plex Sans"), "Arial", "sans-serif"],
        font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.slate,
    )
) as demo:
    with gr.Row():
        gr.Markdown("""
# SmolTTS v0
SmolTTS v0 is an autoregressive 150M parameter character-level text-to-speech model pretrained with an [RQTransformer backbone](https://arxiv.org/abs/2203.01941) and paired with a pretrained [Mimi codec](https://arxiv.org/abs/2410.00037) vocoder. Designed for US and UK English, it was trained entirely on synthetic speech data generated using [Kokoro TTS](https://huggingface.co./hexgrad/Kokoro-82M). SmolTTS is Apache 2.0 licensed - enjoy!
""")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text", placeholder="Enter text to synthesize...", lines=3
            )
            voice_dropdown = gr.Dropdown(
                label="Voice",
                choices=list(voice_mapping.keys()),
                value="Heart (US)",
                info="Select a voice (sysprompt mapping)",
            )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.85, step=0.01, label="Top P"
                )
        with gr.Column():
            audio_output = gr.Audio(label="Generated Speech", type="numpy")
            generate_btn = gr.Button("Generate Speech", variant="primary")

    generate_btn.click(
        fn=synthesize_speech,
        inputs=[input_text, temperature, top_p, voice_dropdown],
        outputs=[audio_output],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)
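
# Hypothetical headless usage (bypassing the UI), e.g. for a quick local smoke test:
#   sr, pcm = synthesize_speech("Hello from SmolTTS.", 0.1, 0.85, "Heart (US)")
#   import soundfile as sf
#   sf.write("smoltts_sample.wav", pcm, sr)  # requires the `soundfile` package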