Spaces:
Running
Running
#!/usr/bin/env python3 | |
import gradio as gr | |
import numpy as np | |
import torch | |
import json | |
import io | |
import soundfile as sf | |
from PIL import Image | |
import matplotlib | |
import joblib | |
from sklearn.decomposition import PCA | |
from collections import OrderedDict | |
import nltk | |
matplotlib.use("Agg") # Use non-interactive backend | |
import matplotlib.pyplot as plt | |
# ------------------------------------------------------------------- | |
# IMPORT OR DEFINE YOUR TEXT-TO-SPEECH FUNCTIONS | |
# (Adjust these imports to match your local TTS code) | |
# ------------------------------------------------------------------- | |
from text2speech import tts_randomized, parse_speed, tts_with_style_vector | |
# Constants and paths
VOICES_JSON_PATH = "voices.json"
PCA_MODEL_PATH = "pca_model.pkl"
ANNOTATED_FEATURES_PATH = "annotated_features.npy"
VECTOR_DIMENSION = 256  # number of floats in one stored style vector

# Each annotated feature paired with its descriptive "pole | pole" label.
# The two public lists below are derived from this single table so that the
# i-th name and the i-th label can never drift apart.
_ANNOTATED_FEATURES = (
    ("Gender", "Male | Female"),
    ("Tone", "High | Low"),
    ("Quality", "Noisy | Clean"),
    ("Enunciation", "Clear | Unclear"),
    ("Pace", "Rapid | Slow"),
    ("Style", "Colloquial | Formal"),
)
ANNOTATED_FEATURES_NAMES = [name for name, _ in _ANNOTATED_FEATURES]
ANNOTATED_FEATURES_INFO = [label for _, label in _ANNOTATED_FEATURES]
# Make sure the NLTK "punkt_tab" tokenizer data is available.
# nltk.download() contacts the NLTK index server on every call, even when the
# data is already cached, so only download when the local lookup fails; this
# lets the app restart without network access once the data is installed.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab")
############################################################################## | |
# DEVICE CONFIGURATION | |
############################################################################## | |
# Run on a CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
############################################################################## | |
# LOAD PCA MODEL AND ANNOTATED FEATURES | |
############################################################################## | |
def _load_optional_artifact(loader, path, label):
    """
    Load an optional on-disk artifact, returning None instead of raising.

    Args:
        loader: Callable that takes ``path`` and returns the loaded object
            (``joblib.load`` / ``np.load`` below).
        path: File to load.
        label: Human-readable name used in the log messages.

    Returns:
        The loaded object, or None when the file is missing or cannot be
        loaded (a message is printed in both cases).
    """
    try:
        obj = loader(path)
    except FileNotFoundError:
        print(f"Error: {label} file '{path}' not found.")
        return None
    except Exception as exc:  # corrupt file, pickle from an incompatible version, ...
        # update_sliders already handles `pca is None`, so keep the app running
        # in a degraded mode instead of crashing the whole module at import.
        print(f"Error: could not load {label} from '{path}': {exc}")
        return None
    print(f"{label} loaded successfully.")
    return obj


# NOTE(security): joblib.load unpickles arbitrary objects; only point
# PCA_MODEL_PATH at a file you trust.
pca = _load_optional_artifact(joblib.load, PCA_MODEL_PATH, "PCA model")
annotated_features = _load_optional_artifact(
    np.load, ANNOTATED_FEATURES_PATH, "Annotated features"
)
############################################################################## | |
# UTILITY FUNCTIONS | |
############################################################################## | |
def load_voices_json():
    """
    Load the saved voices from VOICES_JSON_PATH.

    Returns:
        An OrderedDict mapping voice name -> stored style vector, in file
        order. An empty OrderedDict is returned (after printing a warning)
        when the file is missing, cannot be decoded as UTF-8 JSON, or its top
        level is not a JSON object. This function never writes the file.
    """
    try:
        # Explicit UTF-8 so the result does not depend on the platform's locale
        # encoding (save_voices_json writes ASCII-only JSON, which is UTF-8 too).
        with open(VOICES_JSON_PATH, "r", encoding="utf-8") as f:
            data = json.load(f, object_pairs_hook=OrderedDict)
    except FileNotFoundError:
        print(f"Warning: {VOICES_JSON_PATH} not found. Creating a new one.")
        return OrderedDict()
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Warning: {VOICES_JSON_PATH} is not valid JSON.")
        return OrderedDict()
    # Callers use .keys(), membership tests and item access, which only make
    # sense for a JSON object; a top-level array or scalar would crash them.
    if not isinstance(data, dict):
        print(f"Warning: {VOICES_JSON_PATH} does not contain a JSON object of voices.")
        return OrderedDict()
    return data
def save_voices_json(data, path=None):
    """
    Write ``data`` as indented JSON to the voices file.

    The JSON is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``, so an error while serialising (e.g. a value that is
    not JSON-serialisable) leaves the previous file intact instead of a
    truncated, half-written one.

    Args:
        data: JSON-serialisable mapping of voice name -> style vector.
        path: Destination file. Defaults to VOICES_JSON_PATH, resolved at
            call time.

    Raises:
        Whatever ``open``/``json.dump``/``os.replace`` raise; the temporary
        file is removed before the error propagates.
    """
    import os

    if path is None:
        path = VOICES_JSON_PATH
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # The destination was never opened for writing; just drop the partial
        # temp file and re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"Voices saved to '{path}'.")
def update_sliders(voice_name):
    """
    Map a saved voice onto the Voice Studio sliders.

    The voice's stored style vector is projected with ``pca.transform`` (the
    forward PCA transform) and the component values are returned as a list.
    Whenever that is not possible (no voice selected, unknown voice, PCA model
    not loaded, transform error) a list with one 0.0 per entry of
    ANNOTATED_FEATURES_NAMES is returned instead.
    """
    neutral = [0.0] * len(ANNOTATED_FEATURES_NAMES)
    if not voice_name:
        return neutral

    voices = load_voices_json()
    if voice_name not in voices:
        print(f"Voice '{voice_name}' not found in {VOICES_JSON_PATH}.")
        return neutral

    # A single (1, N) row, which is the 2-D input sklearn's transform takes.
    style_row = np.array(voices[voice_name], dtype=np.float32).reshape(1, -1)
    if pca is None:
        print("PCA model is not loaded.")
        return neutral

    try:
        return pca.transform(style_row)[0].tolist()
    except Exception as e:
        print(f"Error transforming style vector to PCA components: {e}")
        return neutral
def generate_audio_with_voice(text, voice_key, speed_val):
    """
    Synthesize ``text`` with the stored style vector of a saved voice.

    Args:
        text: Text to synthesize.
        voice_key: Name of the voice in VOICES_JSON_PATH.
        speed_val: Speed multiplier forwarded to tts_with_style_vector.

    Returns:
        ``((24000, audio), style_vector_as_nested_list)`` on success, or
        ``(None, error_message)`` on failure. Exceptions are caught, printed
        and reported as a generic message.
    """

    def _fail(message):
        # Log and report the same text.
        print(message)
        return None, message

    try:
        voices = load_voices_json()
        if voice_key not in voices:
            return _fail(f"Voice '{voice_key}' not found in {VOICES_JSON_PATH}.")

        style_row = np.array(voices[voice_key], dtype=np.float32).reshape(1, -1)
        print(f"Selected Voice: {voice_key}")
        print(f"Style Vector (First 6): {style_row[0][:6]}")

        audio = tts_with_style_vector(
            text,
            style_vec=torch.from_numpy(style_row).float().to(device),
            speed=speed_val,
            alpha=0.3,
            beta=0.7,
            diffusion_steps=7,
            embedding_scale=1.0,
        )
        if audio is None:
            return _fail("Audio generation failed.")

        return (24000, audio), style_row.tolist()
    except Exception as e:
        print(f"Error in generate_audio_with_voice: {e}")
        return None, "An error occurred during audio generation."
def build_modified_vector(voice_key, top6_values):
    """
    Build a style vector from the Voice Studio slider values via inverse PCA.

    The base voice ``voice_key`` is only used as a sanity check here. It must
    exist in the voices file and hold a VECTOR_DIMENSION-long vector. The
    returned vector is reconstructed from the slider values alone
    (``pca.inverse_transform``), so nothing from the base vector itself is
    blended into the result.

    Args:
        voice_key: Name of the selected base voice.
        top6_values: Sequence of slider values, one per PCA component.

    Returns:
        The reconstructed 1-D numpy array, or None if the voice is missing or
        malformed, the PCA model is not loaded, or reconstruction fails
        (the reason is printed).
    """
    voices_data = load_voices_json()
    if voice_key not in voices_data:
        print(f"Voice '{voice_key}' not found in {VOICES_JSON_PATH}.")
        return None

    arr = np.array(voices_data[voice_key], dtype=np.float32).squeeze()
    if arr.ndim != 1 or arr.shape[0] != VECTOR_DIMENSION:
        print(
            f"Voice '{voice_key}' has invalid shape {arr.shape}. "
            f"Expected ({VECTOR_DIMENSION},)."
        )
        return None

    # Same guard as update_sliders. Without it, a missing model only shows up
    # as an opaque "'NoneType' object has no attribute ..." error below.
    if pca is None:
        print("PCA model is not loaded.")
        return None

    try:
        # float dtype rejects None/non-numeric slider values up front.
        pca_components = np.asarray(top6_values, dtype=float).reshape(1, -1)
        return pca.inverse_transform(pca_components)[0]
    except Exception as e:
        print(f"Error reconstructing style vector: {e}")
        return None
def generate_custom_audio(text, voice_key, randomize, speed_val, *slider_values):
    """
    Synthesize ``text`` with either a random style vector or one rebuilt from
    the Voice Studio PCA sliders.

    Args:
        text: Text to synthesize.
        voice_key: Base voice name. It is only used when ``randomize`` is
            false, where it is passed to build_modified_vector.
        randomize: If true, ``tts_randomized`` picks the style vector.
            Otherwise it is reconstructed from ``slider_values``.
        speed_val: Speed multiplier (the Voice Studio UI passes percent / 100).
        *slider_values: One value per PCA slider.

    Returns:
        ``((24000, audio), style_vector_as_list)`` on success, otherwise
        ``(None, None)``. Exceptions are caught and printed.
    """
    try:
        if randomize:
            audio_np, random_style_vec = tts_randomized(text, speed=speed_val)
            if random_style_vec is None:
                print("Failed to generate randomized style vector.")
                return None, None
            if isinstance(random_style_vec, torch.Tensor):
                # .numpy() raises on tensors that require grad. detach() makes
                # the conversion safe however tts_randomized built the tensor,
                # instead of failing silently in the except below.
                final_vec = random_style_vec.detach().cpu().numpy().flatten()
            else:
                final_vec = np.array(random_style_vec).flatten()
            print("Randomized Style Vector (First 6):", final_vec[:6])
        else:
            reconstructed_vec = build_modified_vector(voice_key, slider_values)
            if reconstructed_vec is None:
                print("No reconstructed vector. Skipping audio generation.")
                return None, None
            # Add a leading dimension: shape (1, N).
            style_vec_torch = (
                torch.from_numpy(reconstructed_vec).float().unsqueeze(0).to(device)
            )
            audio_np = tts_with_style_vector(
                text,
                style_vec=style_vec_torch,
                speed=speed_val,
                alpha=0.3,
                beta=0.7,
                diffusion_steps=7,
                embedding_scale=1.0,
            )
            final_vec = reconstructed_vec
            print("Reconstructed Style Vector (First 6):", final_vec[:6])

        if audio_np is None:
            print("Audio generation failed.")
            return None, None

        sr = 24000
        audio_tuple = (sr, audio_np)
        return audio_tuple, final_vec.tolist()
    except Exception as e:
        print(f"Error generating audio and style: {e}")
        return None, None
def save_style_to_json(style_data, style_name):
    """
    Store ``style_data`` in the voices file under ``style_name``.

    The name is stripped of surrounding whitespace once, up front, so the
    emptiness check, the duplicate check and the stored key all use the same
    string.

    Args:
        style_data: Style vector as a sequence of floats. It must have
            exactly VECTOR_DIMENSION entries.
        style_name: Name for the new voice. None is treated as empty.

    Returns:
        A human-readable status message: success, or why nothing was saved.
    """
    style_name = (style_name or "").strip()
    if not style_name:
        return "Please enter a new style name before saving."

    voices_data = load_voices_json()
    if style_name in voices_data:
        return (
            f"Style name '{style_name}' already exists. Please choose a different name."
        )

    # No vector yet (e.g. nothing generated) counts as length 0, not a crash.
    got = 0 if style_data is None else len(style_data)
    if got != VECTOR_DIMENSION:
        return f"Style vector length mismatch. Expected {VECTOR_DIMENSION}, got {got}."

    voices_data[style_name] = style_data
    save_voices_json(voices_data)
    return f"Saved style as '{style_name}' in {VOICES_JSON_PATH}."
def rearrange_voices(new_order):
    """
    Reorder the voices in voices.json according to a comma-separated list.

    Names listed in ``new_order`` move to the front in the given order. Blank
    entries (e.g. from a trailing comma) and repeated names are ignored.
    Voices that are not listed keep their existing relative order after the
    listed ones, so reordering never removes a voice from the file.

    Args:
        new_order: Comma-separated voice names. None is treated as empty.

    Returns:
        (status_msg, updated_list_of_voice_names). If any listed name is
        unknown, nothing is written and the current names are returned.
    """
    voices_data = load_voices_json()
    # dict.fromkeys de-duplicates while keeping first-seen order.
    requested = [
        name
        for name in dict.fromkeys(part.strip() for part in (new_order or "").split(","))
        if name
    ]
    if any(name not in voices_data for name in requested):
        return "Error: New order contains invalid voice names.", list(
            voices_data.keys()
        )

    ordered_data = OrderedDict((name, voices_data[name]) for name in requested)
    # Append everything the user did not mention so no saved voice is lost.
    for name, vector in voices_data.items():
        if name not in ordered_data:
            ordered_data[name] = vector

    save_voices_json(ordered_data)
    print(f"Voices rearranged: {list(ordered_data.keys())}")
    return "Voices rearranged successfully.", list(ordered_data.keys())
def delete_voice(selected):
    """
    Remove the named voices from voices.json.

    Args:
        selected: Iterable of voice names. Names that are not present are
            skipped. When it is empty/falsy, nothing is written.

    Returns:
        (status_msg, remaining_voice_names). Whenever ``selected`` is
        non-empty, the file is rewritten, even if no name matched.
    """
    voices_data = load_voices_json()
    if not selected:
        return "No voices selected for deletion.", list(voices_data)

    for doomed in selected:
        if doomed not in voices_data:
            continue
        del voices_data[doomed]
        print(f"Voice '{doomed}' deleted.")

    save_voices_json(voices_data)
    return "Deleted selected voices successfully.", list(voices_data)
def _read_uploaded_json(uploaded_file):
    """Parse JSON from a path (str / PathLike, including str subclasses) or a readable file object."""
    import os

    if not isinstance(uploaded_file, (str, os.PathLike)):
        # An open file object can be parsed directly. Otherwise (e.g. a closed
        # temp-file wrapper), fall back to its ``.name`` as a path.
        if hasattr(uploaded_file, "read") and not getattr(uploaded_file, "closed", False):
            return json.load(uploaded_file)
        uploaded_file = uploaded_file.name
    with open(uploaded_file, "r", encoding="utf-8") as f:
        return json.load(f)


def upload_new_voices(uploaded_file):
    """
    Merge voices from an uploaded JSON file into voices.json.

    Args:
        uploaded_file: A file path (str / PathLike) or a file object whose
            content is a JSON object mapping voice name -> style vector.
            None means nothing was uploaded.

    Returns:
        (status_msg, updated_list_of_voice_names). voices.json is written only
        when the whole upload is valid. Every vector must contain exactly
        VECTOR_DIMENSION numbers, matching the check in save_style_to_json.
        Uploaded names overwrite existing voices of the same name.
    """
    if uploaded_file is None:
        return "No file uploaded.", list(load_voices_json().keys())

    try:
        uploaded_data = _read_uploaded_json(uploaded_file)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return "Uploaded file is not valid JSON.", list(load_voices_json().keys())
    except (OSError, AttributeError) as exc:
        return f"Could not read uploaded file: {exc}", list(load_voices_json().keys())

    if not isinstance(uploaded_data, dict):
        return (
            "Invalid JSON format. Expected a dictionary of voices.",
            list(load_voices_json().keys()),
        )

    # Reject malformed vectors before they reach voices.json; otherwise they
    # would only fail later, when the voice is used.
    bad_names = []
    for name, vector in uploaded_data.items():
        try:
            size = np.asarray(vector, dtype=np.float32).size
        except (TypeError, ValueError):
            size = -1
        if size != VECTOR_DIMENSION:
            bad_names.append(name)
    if bad_names:
        return (
            f"Invalid voice vectors (expected {VECTOR_DIMENSION} numbers each): "
            f"{', '.join(bad_names)}",
            list(load_voices_json().keys()),
        )

    voices_data = load_voices_json()
    voices_data.update(uploaded_data)
    save_voices_json(voices_data)
    print(f"Voices uploaded: {list(uploaded_data.keys())}")
    return "Voices uploaded successfully.", list(voices_data.keys())
# ------------------------------------------------------------------- | |
# GRADIO INTERFACE | |
# ------------------------------------------------------------------- | |
def create_combined_interface():
    """
    Build the StyleTTS2 Studio Gradio app.

    Two tabs are created:
      * "Text-to-Speech": synthesize text with a saved voice.
      * "Voice Studio": adjust one slider per entry of
        ANNOTATED_FEATURES_NAMES, synthesize the result and save it as a new
        voice.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` instance.
    """
    # Load the voices once to seed both dropdowns. demo.load below refreshes
    # the choices again every time the page is opened.
    voices_data = load_voices_json()
    voice_choices = list(voices_data.keys())
    default_voice = voice_choices[0] if voice_choices else None

    # Seed the sliders from the pre-selected voice. `.change` only fires when
    # the dropdown value changes, so without this the sliders would start at
    # 0.0 and "Generate Customized Audio" would reconstruct from all-zero
    # components instead of from the voice shown in the dropdown.
    initial_slider_values = update_sliders(default_voice)
    if len(initial_slider_values) != len(ANNOTATED_FEATURES_NAMES):
        initial_slider_values = [0.0] * len(ANNOTATED_FEATURES_NAMES)

    css = """
    h4 {
        text-align: center;
        display:block;
    }
    """

    with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        gr.Markdown("# StyleTTS2 Studio - Build custom voices")

        # -------------------------------------------------------
        # 1) Text-to-Speech Tab
        # -------------------------------------------------------
        with gr.Tab("Text-to-Speech"):
            gr.Markdown("### Generate Speech with Predefined Voices")
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to Synthesize",
                    value="How much wood could a woodchuck chuck if a woodchuck could chuck wood?",
                    lines=3,
                )
                voice_dropdown = gr.Dropdown(
                    choices=voice_choices,
                    label="Select Base Voice",
                    value=default_voice,
                    interactive=True,
                )
                speed_slider = gr.Slider(
                    minimum=50,
                    maximum=200,
                    step=1,
                    label="Speed (%)",
                    value=120,
                )
                generate_btn = gr.Button("Generate Audio")
                status_tts = gr.Textbox(label="Status", visible=False)
                audio_output = gr.Audio(label="Synthesized Audio")

            def on_generate_tts(text, voice, speed):
                """Synthesize with a saved voice; returns (audio, status message)."""
                if not voice:
                    return None, "No voice selected."
                speed_val = speed / 100  # Convert percentage to multiplier
                audio_result, msg = generate_audio_with_voice(text, voice, speed_val)
                if audio_result is None:
                    return None, msg
                return audio_result, "Audio generated successfully."

            generate_btn.click(
                fn=on_generate_tts,
                inputs=[text_input, voice_dropdown, speed_slider],
                outputs=[audio_output, status_tts],
            )

        # -------------------------------------------------------
        # 2) Voice Studio Tab
        # -------------------------------------------------------
        with gr.Tab("Voice Studio"):
            gr.Markdown("### Customize and Create New Voices")
            with gr.Column():
                text_input_studio = gr.Textbox(
                    label="Text to Synthesize",
                    value="Use the sliders to customize a voice!",
                    lines=3,
                )
                voice_dropdown_studio = gr.Dropdown(
                    choices=voice_choices,
                    label="Select Base Voice",
                    value=default_voice,
                )
                speed_slider_studio = gr.Slider(
                    minimum=50,
                    maximum=200,
                    step=1,
                    label="Speed (%)",
                    value=120,
                )
                # One slider per PCA component, labelled by annotated feature.
                pca_sliders = [
                    gr.Slider(
                        minimum=-2.0,
                        maximum=2.0,
                        value=initial_value,
                        step=0.1,
                        label=feature,
                    )
                    for feature, initial_value in zip(
                        ANNOTATED_FEATURES_NAMES, initial_slider_values
                    )
                ]
                generate_btn_studio = gr.Button("Generate Customized Audio")
                audio_output_studio = gr.Audio(label="Customized Synthesized Audio")
                new_style_name = gr.Textbox(label="New Style Name", value="")
                save_btn_studio = gr.Button("Save Customized Voice")
                status_text = gr.Textbox(label="Status", visible=True)
                # Holds the style vector of the last successful generation.
                style_vector_state_studio = gr.State()

            def on_generate_studio(text, voice, speed, *pca_values):
                """Synthesize from the sliders; returns (audio, status, style vector)."""
                if not voice:
                    return None, "No voice selected.", None
                speed_val = speed / 100
                audio_tuple, style_vector = generate_custom_audio(
                    text, voice, False, speed_val, *pca_values
                )
                if audio_tuple is None:
                    return None, "Failed to generate audio.", None
                return audio_tuple, "Audio generated successfully.", style_vector

            generate_btn_studio.click(
                fn=on_generate_studio,
                inputs=[text_input_studio, voice_dropdown_studio, speed_slider_studio]
                + pca_sliders,
                outputs=[audio_output_studio, status_text, style_vector_state_studio],
            )

            def on_save_style_studio(style_vector, style_name):
                """Save the last generated style, then update both dropdowns' choices."""
                # Report the actual missing piece: a vector only exists after a
                # successful "Generate Customized Audio".
                if not style_vector:
                    return (
                        gr.update(value="Please generate a customized voice before saving!"),
                        gr.update(),
                        gr.update(),
                    )
                if not style_name or not style_name.strip():
                    return (
                        gr.update(value="Please enter a name for the new voice!"),
                        gr.update(),
                        gr.update(),
                    )
                result = save_style_to_json(style_vector, style_name)
                # Reload so both dropdowns list the newly saved voice.
                new_choices = list(load_voices_json().keys())
                return (
                    gr.update(value=result),
                    gr.update(choices=new_choices),
                    gr.update(choices=new_choices),
                )

            save_btn_studio.click(
                fn=on_save_style_studio,
                inputs=[style_vector_state_studio, new_style_name],
                # Updates: status_text, voice_dropdown, voice_dropdown_studio
                outputs=[status_text, voice_dropdown, voice_dropdown_studio],
            )

            # Move the sliders to the newly selected base voice.
            voice_dropdown_studio.change(
                fn=update_sliders,
                inputs=voice_dropdown_studio,
                outputs=pca_sliders,
            )

        # -------------------------------------------------------
        # Reload the voice list on every page load
        # -------------------------------------------------------
        def on_page_load():
            """Refresh both dropdowns' choices from voices.json."""
            new_choices = list(load_voices_json().keys())
            return {
                voice_dropdown: gr.update(choices=new_choices),
                voice_dropdown_studio: gr.update(choices=new_choices),
            }

        demo.load(
            on_page_load, inputs=None, outputs=[voice_dropdown, voice_dropdown_studio]
        )

        gr.Markdown(
            "#### Based on [StyleTTS2](https://github.com/yl4579/StyleTTS2) and [artificial StyleTTS2](https://huggingface.co/dkounadis/artificial-styletts2/tree/main)"
        )

    return demo
if __name__ == "__main__":
    import sys
    import traceback

    try:
        interface = create_combined_interface()
        interface.launch(share=False)  # or share=True if you want a public share link
    except Exception as e:
        print(f"An error occurred while launching the interface: {e}")
        # Keep the full traceback and exit non-zero, so whatever runs this
        # script sees the failure instead of a clean exit.
        traceback.print_exc()
        sys.exit(1)