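"""Gradio voice-chat demo: transcribes the user's spoken message, generates a reply
with an LLM, and speaks the reply back in the user's reference voice via F5-TTS."""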
# ruff: noqa: E402
import json
import os
import re
import tempfile
from importlib.resources import files

import click
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from cached_path import cached_path
from groq import Groq
from transformers import AutoModelForCausalLM, AutoTokenizer
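# Optional Hugging Face Spaces support: when the `spaces` package is available,
# GPU-heavy functions can be wrapped with spaces.GPU; otherwise they run unchanged.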
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    if USING_SPACES:
        return spaces.GPU(func)
    return func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)
DEFAULT_TTS_MODEL = "F5-TTS"
tts_model_choice = DEFAULT_TTS_MODEL

DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]
# Load models
vocoder = load_vocoder()


def load_f5tts(ckpt_path=str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))):
    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()
chat_model_state = None
chat_tokenizer_state = None

DEFAULT_SYSTEM_PROMPT = (
    "You are not an AI assistant, you are whoever the user says you are. "
    "You must stay in character. Keep your responses concise since they will be spoken out loud."
)

groq_token = os.getenv("Groq_TOKEN", None)
client = Groq(api_key=groq_token)
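# The Groq API key is read from the `Groq_TOKEN` environment variable (e.g. a Space secret).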
def generate_response(messages):
    """Generate a chat response using the Groq API.

    Accepts either a plain string (treated as a single user message) or a full
    list of chat messages in OpenAI/Groq format.
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": messages}] if isinstance(messages, str) else messages,
        model="llama-3.3-70b-versatile",
        stream=False,
    )
    return chat_completion.choices[0].message.content
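# Example usage (either form is accepted):
#   generate_response("Hello!")
#   generate_response([{"role": "user", "content": "Hello!"}])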
def process_audio_input(audio_path, text, history, conv_state):
    """Handle a spoken or typed user message: transcribe if needed, then get the LLM reply."""
    if not audio_path and not text.strip():
        return history, conv_state

    # If audio was recorded, transcribe it (preprocess_ref_audio_text returns (audio, text)).
    if audio_path:
        text = preprocess_ref_audio_text(audio_path, text)[1]

    if not text.strip():
        return history, conv_state

    conv_state.append({"role": "user", "content": text})
    history.append((text, None))

    response = generate_response(conv_state)

    conv_state.append({"role": "assistant", "content": response})
    history[-1] = (text, response)

    # The input widgets are cleared by the chained .then() callbacks in the event wiring below.
    return history, conv_state
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    ema_model = F5TTS_ema_model  # this app always uses the F5-TTS model

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Optionally trim silences from the generated audio
    if remove_silence:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram to a temporary PNG
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text
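# Example (hypothetical file path):
#   infer("ref.wav", "", "Hello there!", DEFAULT_TTS_MODEL, remove_silence=False)
# returns ((sample_rate, waveform), spectrogram_path, ref_text).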
with gr.Blocks() as app_chat:
    gr.Markdown("""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Load the chat model.
3. Record your message through your microphone.
4. The AI will respond using the reference voice.
""")

    if not USING_SPACES:
        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
        chat_interface_container = gr.Column(visible=False)
        # NOTE: chat replies in this app come from the Groq API (see generate_response);
        # the locally loaded model below is not used by the chat handlers in this section.
        def load_chat_model():
            global chat_model_state, chat_tokenizer_state
            if chat_model_state is None:
                gr.Info("Loading chat model...")
                model_name = "deepseek-ai/Janus-Pro-7B"
                chat_model_state = AutoModelForCausalLM.from_pretrained(
                    model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
                )
                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                gr.Info("Chat model loaded.")
            return gr.update(visible=False), gr.update(visible=True)

        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
    else:
        chat_interface_container = gr.Column()
        model_name = "deepseek-ai/Janus-Pro-7B"
        chat_model_state = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=DEFAULT_SYSTEM_PROMPT,
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")

        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        conversation_state = gr.State(
            value=[
                {
                    "role": "system",
                    "content": DEFAULT_SYSTEM_PROMPT,
                }
            ]
        )
    def generate_audio_response(history, ref_audio, ref_text, remove_silence):
        """Synthesize the latest AI reply in the reference voice."""
        if not history or not ref_audio:
            return None, ref_text

        last_user_message, last_ai_response = history[-1]
        if not last_ai_response:
            return None, ref_text

        audio_result, _, ref_text_out = infer(
            ref_audio,
            ref_text,
            last_ai_response,
            tts_model_choice,
            remove_silence,
            cross_fade_duration=0.15,
            speed=1.0,
            show_info=print,
        )
        return audio_result, ref_text_out

    def clear_conversation():
        """Reset the chat history and the conversation state."""
        return [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]

    def update_system_prompt(new_prompt):
        """Restart the conversation with a new system prompt."""
        return [], [{"role": "system", "content": new_prompt}]
    # Event wiring: each input path runs process_audio_input, then speaks the reply,
    # then clears the corresponding input widget.
    audio_input_chat.stop_recording(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat],
    ).then(lambda: None, None, audio_input_chat)

    text_input_chat.submit(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat],
    ).then(lambda: None, None, text_input_chat)

    send_btn_chat.click(
        process_audio_input,
        inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
        outputs=[audio_output_chat, ref_text_chat],
    ).then(lambda: None, None, text_input_chat)

    clear_btn_chat.click(clear_conversation, outputs=[chatbot_interface, conversation_state])

    system_prompt_chat.change(
        update_system_prompt, inputs=system_prompt_chat, outputs=[chatbot_interface, conversation_state]
    )
app = app_chat


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
@click.option("--root_path", "-r", default=None, help="Root path of the app, needed behind a reverse proxy")
def main(port, host, share, api, root_path):
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()
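# Example launch (assuming the click options above): python app.py --port 7860 --host 0.0.0.0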