File size: 4,225 Bytes
16f2323
cff6a85
 
 
 
 
 
16f2323
cff6a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16f2323
cff6a85
 
 
 
 
 
 
 
 
 
 
 
 
9b1745e
cff6a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1745e
cff6a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from fish_speech import LM
import re
from rustymimi import Tokenizer
from huggingface_hub import snapshot_download, hf_hub_download
import numpy as np
import spaces

# Voice display name -> speaker-token system prompt handed to the LM.
# Speaker IDs by accent:
#   US voices:      Heart (default) 0, Bella 1, Nova 2, Sky 3,
#                   Sarah 4, Michael 5, Fenrir 6, Liam 7
#   British voices: Emma 8, Isabella 9, Fable 10
# Insertion order matters: it is the order of the UI voice dropdown.
voice_mapping = dict(
    [
        ("Heart (US)", "<|speaker:0|>"),
        ("Bella (US)", "<|speaker:1|>"),
        ("Nova (US)", "<|speaker:2|>"),
        ("Sky (US)", "<|speaker:3|>"),
        ("Sarah (US)", "<|speaker:4|>"),
        ("Michael (US)", "<|speaker:5|>"),
        ("Fenrir (US)", "<|speaker:6|>"),
        ("Liam (US)", "<|speaker:7|>"),
        ("Emma (UK)", "<|speaker:8|>"),
        ("Isabella (UK)", "<|speaker:9|>"),
        ("Fable (UK)", "<|speaker:10|>"),
    ]
)

# Model setup: announce the (potentially slow) download/load phase.
print("Downloading and initializing models...")


def get_mimi_path():
    """Return a local path to the Mimi tokenizer (codec) checkpoint.

    Delegates to ``hf_hub_download`` for the ``kyutai/moshiko-mlx-bf16``
    repository and returns whatever local file path it yields.
    """
    mimi_repo = "kyutai/moshiko-mlx-bf16"
    mimi_weights = "tokenizer-e351c8d8-checkpoint125.safetensors"
    local_path = hf_hub_download(mimi_repo, mimi_weights)
    return local_path


# Fetch the SmolTTS v0 checkpoint and the Mimi codec weights, then load both.
# The checkpoint directory is bound to ``model_dir`` (not ``dir``) so the
# builtin ``dir()`` is not shadowed at module scope.
model_dir = snapshot_download("jkeisling/smoltts_v0")
mimi_path = get_mimi_path()
lm = LM(model_dir, dtype="bf16", device="cuda", version="dual_ar")
codec = Tokenizer(mimi_path)


# Sentence splitting: a simple punctuation-based heuristic.
def split_sentences(text):
    """Break ``text`` into sentence-sized chunks for per-sentence synthesis.

    Splits on runs of whitespace that directly follow ``?``, ``.`` or ``!``.
    The punctuation stays attached to the preceding chunk. Each chunk is
    stripped, and chunks that end up empty are dropped. This is naive:
    abbreviations such as "Mr." also trigger a split.
    """
    chunks = re.split(r"(?<=[?.!])\s+", text)
    stripped = (chunk.strip() for chunk in chunks)
    return [piece for piece in stripped if piece]


@spaces.GPU
def synthesize_speech(text, temperature, top_p, voice):
    """Generate speech from text using Fish Speech, one sentence at a time.

    Args:
        text: Input text. It is split with ``split_sentences``, and each
            sentence is generated and decoded separately.
        temperature: Sampling temperature, forwarded to ``lm`` as ``temp``.
        top_p: Top-p value, forwarded to ``lm`` unchanged.
        voice: Display name looked up in ``voice_mapping``. Unknown names
            fall back to the ``"<|speaker:0|>"`` system prompt.

    Returns:
        A ``(sample_rate, pcm)`` tuple for ``gr.Audio(type="numpy")``. The
        sample rate is 24000. ``pcm`` is every sentence's decoded, flattened
        audio concatenated in order.

    Raises:
        gr.Error: If ``text`` is ``None``, empty, or whitespace-only.
            Otherwise ``np.concatenate`` would be handed an empty list and
            fail with an opaque ``ValueError``.
    """
    sysprompt = voice_mapping.get(voice, "<|speaker:0|>")
    # Guard against None so re.split does not raise TypeError.
    sentences = split_sentences(text or "")
    if not sentences:
        raise gr.Error("Please enter some text to synthesize.")

    pcm_list = []
    for sentence in sentences:
        # Generate audio for each sentence individually
        generated = lm([sentence], temp=temperature, top_p=top_p, sysprompt=sysprompt)
        pcm = codec.decode(generated)
        pcm_list.append(pcm.flatten())

    # Concatenate all per-sentence PCM arrays into one clip.
    final_pcm = np.concatenate(pcm_list)
    return (24_000, final_pcm)


# Create the Gradio interface: a header, input controls (text, voice,
# sampling sliders) on the left, and the generated audio on the right.
with gr.Blocks(
    theme=gr.themes.Default(
        font=[gr.themes.GoogleFont("IBM Plex Sans"), "Arial", "sans-serif"],
        font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.slate,
    )
) as demo:
    with gr.Row():
        # Header / model description. The Kokoro link previously had a stray
        # dot after the TLD ("huggingface.co./..."), which is now fixed.
        gr.Markdown("""
            # SmolTTS v0

            SmolTTS v0 is an autoregressive 150M parameter character-level text-to-speech model pretrained with an [RQTransformer backbone](https://arxiv.org/abs/2203.01941) and paired with a pretrained [Mimi codec](https://arxiv.org/abs/2410.00037) vocoder. Designed for US and UK English, it was trained entirely on synthetic speech data generated using [Kokoro TTS](https://huggingface.co/hexgrad/Kokoro-82M). SmolTTS is Apache 2.0 licensed - enjoy!
            """)

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text", placeholder="Enter text to synthesize...", lines=3
            )
            # Dropdown choices come straight from voice_mapping, in its order.
            voice_dropdown = gr.Dropdown(
                label="Voice",
                choices=list(voice_mapping.keys()),
                value="Heart (US)",
                info="Select a voice (sysprompt mapping)",
            )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.85, step=0.01, label="Top P"
                )
        with gr.Column():
            # type="numpy" matches the (sample_rate, array) tuple returned
            # by synthesize_speech.
            audio_output = gr.Audio(label="Generated Speech", type="numpy")

    generate_btn = gr.Button("Generate Speech", variant="primary")
    # Input order must match synthesize_speech(text, temperature, top_p, voice).
    generate_btn.click(
        fn=synthesize_speech,
        inputs=[input_text, temperature, top_p, voice_dropdown],
        outputs=[audio_output],
    )

if __name__ == "__main__":
    # Listen on all interfaces without creating a public share link.
    launch_options = {"server_name": "0.0.0.0", "share": False}
    demo.launch(**launch_options)