# ruff: noqa: E402
"""Gradio voice-chat demo: chat replies come from the Groq API and are spoken
back with F5-TTS in the voice of a user-supplied reference clip."""

import contextlib
import json
import os
import re
import tempfile
from importlib.resources import files

import click
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from cached_path import cached_path
from groq import Groq
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Wrap *func* with ``spaces.GPU`` when running on Hugging Face Spaces.

    When the ``spaces`` package is not importable, *func* is returned unchanged.
    """
    if USING_SPACES:
        return spaces.GPU(func)
    return func


from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)

DEFAULT_TTS_MODEL = "F5-TTS"
tts_model_choice = DEFAULT_TTS_MODEL

# [checkpoint URI, vocab URI, JSON-encoded DiT config] of the default TTS model.
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]

# System prompt used for new conversations and as the textbox default.
DEFAULT_SYSTEM_PROMPT = (
    "You are not an AI assistant, you are whoever the user says you are. "
    "You must stay in character. Keep your responses concise since they will be spoken out loud."
)

# Load models
vocoder = load_vocoder()


def load_f5tts(ckpt_path=None):
    """Load the F5-TTS base DiT model.

    Args:
        ckpt_path: Local checkpoint path. When ``None`` the default checkpoint
            from ``DEFAULT_TTS_MODEL_CFG`` is resolved with ``cached_path``
            inside the call, so passing an explicit path skips that download.

    Returns:
        The model returned by ``load_model`` for the DiT architecture.
    """
    if ckpt_path is None:
        ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()

# NOTE(review): these are filled in by the chat-model loading code further down
# but nothing in this file reads them — replies come from Groq (generate_response).
chat_model_state = None
chat_tokenizer_state = None

# Groq client for chat replies; the key is read from the "Groq_TOKEN" env var.
groq_token = os.getenv("Groq_TOKEN", None)
client = Groq(api_key=groq_token)


def generate_response(messages):
    """Return the assistant reply for *messages* from the Groq chat API.

    Args:
        messages: A single user message string, or a list of
            ``{"role": ..., "content": ...}`` chat messages.

    Returns:
        The text content of the first completion choice.
    """
    # Deliberately not wrapped in gpu_decorator: the inference runs remotely,
    # so presumably requesting a Spaces GPU here would only waste GPU time.
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="llama-3.3-70b-versatile",
        stream=False,
    )
    return chat_completion.choices[0].message.content


@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
    """Add one user turn (recorded or typed) and the model's reply.

    Args:
        audio_path: Path of the recorded message, or a falsy value if none.
        text: Typed message text; may be empty or ``None``.
        history: Chatbot history as ``(user, assistant)`` pairs.
        conv_state: Chat messages sent to the LLM, starting with the system prompt.

    Returns:
        ``(history, conv_state, "")``: the updated chat history and
        conversation, plus an empty string that clears the text box.
    """
    text = text or ""
    if not audio_path and not text.strip():
        return history, conv_state, ""

    if audio_path:
        # NOTE(review): reuses the reference-audio preprocessing helper to get a
        # transcript; presumably a non-empty *text* takes precedence over
        # transcribing the recording — confirm this is the intended behavior.
        text = preprocess_ref_audio_text(audio_path, text)[1]

    if not text.strip():
        return history, conv_state, ""

    # Build new lists rather than mutating the inputs, so an exception from the
    # API call does not leave a half-finished turn in the stored state.
    conv_state = conv_state + [{"role": "user", "content": text}]
    response = generate_response(conv_state)
    conv_state = conv_state + [{"role": "assistant", "content": response}]
    history = (history or []) + [(text, response)]
    return history, conv_state, ""


def _remove_silence(wave, sample_rate):
    """Return *wave* with silences removed, using a temp WAV that is always deleted."""
    # Close the handle before writing so the path can be reopened on any platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        wav_path = f.name
    try:
        sf.write(wav_path, wave, sample_rate)
        remove_silence_for_generated_wav(wav_path)
        trimmed, _ = torchaudio.load(wav_path)
    finally:
        with contextlib.suppress(OSError):
            os.remove(wav_path)
    return trimmed.squeeze().cpu().numpy()


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    """Synthesize *gen_text* in the voice of the reference audio.

    Args:
        ref_audio_orig: Path to the reference audio clip.
        ref_text: Transcript of the reference clip; may be empty.
        gen_text: Text to synthesize.
        model: Accepted for interface compatibility; F5-TTS is always used.
        remove_silence: Whether to strip silences from the generated audio.
        cross_fade_duration: Forwarded to ``infer_process``.
        nfe_step: Forwarded to ``infer_process``.
        speed: Forwarded to ``infer_process``.
        show_info: Callback used for info messages.

    Returns:
        ``((sample_rate, wave), spectrogram_path, ref_text)``. The spectrogram
        PNG is a temporary file owned by the caller. When an input is missing,
        a warning is shown and ``(gr.update(), gr.update(), ref_text)`` is returned.
    """
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text

    if not (gen_text or "").strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    ema_model = F5TTS_ema_model  # Use F5-TTS by default

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        final_wave = _remove_silence(final_wave, final_sample_rate)

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text


with gr.Blocks() as app_chat:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Load the chat model.
3. Record your message through your microphone.
4. The AI will respond using the reference voice.
"""
    )

    if not USING_SPACES:
        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")

        chat_interface_container = gr.Column(visible=False)

        @gpu_decorator
        def load_chat_model():
            """Load the local chat model once, then reveal the chat UI.

            Returns updates that hide the load button and show the chat container.
            """
            global chat_model_state, chat_tokenizer_state
            if chat_model_state is None:
                gr.Info("Loading chat model...")
                model_name = "deepseek-ai/Janus-Pro-7B"
                # NOTE(review): this model is never used for replies (Groq is);
                # also, trust_remote_code is set on the tokenizer here but on the
                # model in the Spaces branch — confirm which is intended.
                chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                gr.Info("Chat model loaded.")

            return gr.update(visible=False), gr.update(visible=True)

        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])

    else:
        chat_interface_container = gr.Column()
        # NOTE(review): loaded eagerly at import on Spaces, yet never read.
        model_name = "deepseek-ai/Janus-Pro-7B"
        chat_model_state = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=DEFAULT_SYSTEM_PROMPT,
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")

        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        conversation_state = gr.State(value=[{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}])

        @gpu_decorator
        def generate_audio_response(history, ref_audio, ref_text, remove_silence):
            """Speak the latest assistant reply in the reference voice.

            Returns:
                ``(audio, ref_text)`` for ``[audio_output_chat, ref_text_chat]``,
                or ``(None, gr.update())`` when there is nothing to speak — one
                value per wired output, leaving the reference text unchanged.
            """
            if not history or not ref_audio:
                return None, gr.update()

            _, last_ai_response = history[-1]
            if not last_ai_response:
                return None, gr.update()

            audio_result, spectrogram_path, ref_text_out = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                tts_model_choice,
                remove_silence,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,
            )
            # The chat UI never shows the spectrogram, so drop its temp file.
            if isinstance(spectrogram_path, str):
                with contextlib.suppress(OSError):
                    os.remove(spectrogram_path)
            return audio_result, ref_text_out

        def clear_conversation(system_prompt=DEFAULT_SYSTEM_PROMPT):
            """Empty the chat and restart the conversation with *system_prompt*."""
            return [], [{"role": "system", "content": system_prompt}]

        def update_system_prompt(new_prompt):
            """Empty the chat and restart the conversation with *new_prompt*."""
            return [], [{"role": "system", "content": new_prompt}]

        def _wire_message_event(listener, component_to_clear):
            """Attach the send -> speak -> clear chain to one event *listener*."""
            listener(
                process_audio_input,
                inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
                outputs=[chatbot_interface, conversation_state, text_input_chat],
            ).then(
                generate_audio_response,
                inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
                outputs=[audio_output_chat, ref_text_chat],
            ).then(lambda: None, None, component_to_clear)

        _wire_message_event(audio_input_chat.stop_recording, audio_input_chat)
        _wire_message_event(text_input_chat.submit, text_input_chat)
        _wire_message_event(send_btn_chat.click, text_input_chat)

        # Pass the current prompt so clearing keeps a user-customized system prompt.
        clear_btn_chat.click(
            clear_conversation,
            inputs=system_prompt_chat,
            outputs=[chatbot_interface, conversation_state],
        )

        system_prompt_chat.change(
            update_system_prompt,
            inputs=system_prompt_chat,
            outputs=[chatbot_interface, conversation_state],
        )


app = app_chat


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
@click.option("--api/--no-api", "-a", default=True, help="Allow API access (use --no-api to disable)")
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
def main(port, host, share, api, root_path):
    """Launch the voice-chat app with the given server options."""
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()