Voice-Chat / app.py
NeoPy's picture
Update app.py
fe5270d verified
raw
history blame
11.1 kB
# ruff: noqa: E402
import json
import re
import tempfile
from importlib.resources import files
from groq import Groq
import os
import click
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
    import spaces
except ImportError:
    # `spaces` is unavailable: GPU wrapping below becomes a pass-through.
    USING_SPACES = False
else:
    USING_SPACES = True


def gpu_decorator(func):
    """Wrap ``func`` with ``spaces.GPU`` when the ``spaces`` package is importable.

    When ``spaces`` could not be imported, ``func`` is returned unchanged, so
    this decorator can be applied unconditionally.
    """
    if not USING_SPACES:
        return func
    return spaces.GPU(func)
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
)
# Name of the TTS backend; `tts_model_choice` is what the chat UI passes as
# the `model` argument when synthesizing replies.
DEFAULT_TTS_MODEL = "F5-TTS"
tts_model_choice = DEFAULT_TTS_MODEL

# [checkpoint URI, vocab URI, JSON-encoded architecture kwargs]. The kwargs
# match the values used when loading the DiT model below.
DEFAULT_TTS_MODEL_CFG = [
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(
        {
            "dim": 1024,
            "depth": 22,
            "heads": 16,
            "ff_mult": 2,
            "text_dim": 512,
            "conv_layers": 4,
        }
    ),
]
# Load models
vocoder = load_vocoder()


def load_f5tts(ckpt_path=None):
    """Load the F5-TTS base model (DiT architecture).

    Args:
        ckpt_path: Local checkpoint path. When None (the default), the
            checkpoint URI in ``DEFAULT_TTS_MODEL_CFG[0]`` is resolved with
            ``cached_path``.

    Returns:
        The object returned by ``load_model(DiT, cfg, ckpt_path)``.
    """
    # Resolve the default lazily. As a default-argument expression, the
    # download ran at definition time, even for callers supplying a path.
    if ckpt_path is None:
        ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
    # Reuse the shared architecture config instead of a duplicated literal.
    model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
    return load_model(DiT, model_cfg, ckpt_path)


F5TTS_ema_model = load_f5tts()
# Module-level slots for an optional local chat model/tokenizer; both start
# as None. Chat replies are produced through the Groq client below.
chat_model_state = None
chat_tokenizer_state = None

# Groq API client used by generate_response. The key comes from the
# `Groq_TOKEN` environment variable. If it is unset, api_key=None is passed
# and the SDK's own default applies (NOTE(review): confirm the SDK's fallback).
groq_token = os.getenv("Groq_TOKEN")
client = Groq(api_key=groq_token)
def generate_response(messages):
    """Return the assistant reply for ``messages`` from the Groq chat API.

    Args:
        messages: Either a single user prompt string, or a list of
            ``{"role": ..., "content": ...}`` dicts forming the conversation.

    Returns:
        The reply text, or ``""`` if the response carried no text content.
    """
    # Not wrapped in @gpu_decorator: the body only performs a remote API
    # call, so there is no local GPU work to schedule.
    if isinstance(messages, str):
        # A bare string is sent as a single user turn.
        messages = [{"role": "user", "content": messages}]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="llama-3.3-70b-versatile",
        stream=False,
    )
    # The message content can be None in the chat-completions schema.
    # Normalize it so callers always get a string to store in the conversation.
    return chat_completion.choices[0].message.content or ""
@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
    """Turn one user turn (recorded audio and/or typed text) into a chat exchange.

    Args:
        audio_path: Path of the recorded clip, or a falsy value if none.
        text: Typed text. When audio is present, it is forwarded to
            preprocess_ref_audio_text together with the clip, and element [1]
            of that result becomes the message (presumably a transcription
            when ``text`` is empty; confirm against f5_tts).
        history: Chatbot list of ``(user, assistant)`` pairs; mutated in place.
        conv_state: List of role/content dicts sent to the LLM; mutated in place.

    Returns:
        ``(history, conv_state, "")``. If there is nothing to send, both
        containers are returned untouched.
    """
    if not (audio_path or text.strip()):
        return history, conv_state, ""

    if audio_path:
        text = preprocess_ref_audio_text(audio_path, text)[1]

    if not text.strip():
        return history, conv_state, ""

    conv_state.append({"role": "user", "content": text})
    # Show the user turn right away; the reply slot is filled in below.
    history.append((text, None))

    reply = generate_response(conv_state)

    conv_state.append({"role": "assistant", "content": reply})
    history[-1] = (text, reply)
    return history, conv_state, ""
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    """Synthesize ``gen_text`` in the voice of the reference clip using F5-TTS.

    Args:
        ref_audio_orig: Path to the reference audio clip.
        ref_text: Transcript of the reference clip. It is passed through
            preprocess_ref_audio_text, whose returned text replaces it.
        gen_text: Text to speak.
        model: Accepted for interface compatibility but not consulted;
            ``F5TTS_ema_model`` is always used.
        remove_silence: If truthy, run remove_silence_for_generated_wav on
            the generated audio.
        cross_fade_duration, nfe_step, speed: Forwarded to infer_process.
        show_info: Callable used for informational messages.

    Returns:
        ``((sample_rate, waveform), spectrogram_png_path, ref_text)`` on
        success, or ``(gr.update(), gr.update(), ref_text)`` when the
        reference audio or text to generate is missing.
    """
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text
    # Treat None like blank text instead of crashing on .strip().
    if not gen_text or not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    ema_model = F5TTS_ema_model  # only F5-TTS is loaded; `model` is ignored

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        # Only reserve a name here, then close the handle before other code
        # reopens the file by path. This is portable across platforms.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            silence_path = f.name
        try:
            sf.write(silence_path, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(silence_path)
            final_wave, _ = torchaudio.load(silence_path)
        finally:
            # The scratch file was previously never deleted, so every call
            # leaked one .wav in the temp dir.
            try:
                os.remove(silence_path)
            except OSError:
                pass  # best-effort cleanup; the audio is already in memory
        final_wave = final_wave.squeeze().cpu().numpy()

    # The spectrogram file is kept (delete=False) because its path is returned.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text
with gr.Blocks() as app_chat:
    gr.Markdown("""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Record your message through your microphone, or type it.
3. The AI will respond using the reference voice.
""")

    # Single source for the default system prompt (previously repeated three times).
    _DEFAULT_SYSTEM_PROMPT = (
        "You are not an AI assistant, you are whoever the user says you are. "
        "You must stay in character. Keep your responses concise since they will be spoken out loud."
    )

    # No local chat model is loaded. Replies come from the Groq API via
    # generate_response, so a locally loaded 7B model (previously loaded
    # eagerly at import on Spaces) was never consulted.
    chat_interface_container = gr.Column()

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=_DEFAULT_SYSTEM_PROMPT,
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")

        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send Message")
                clear_btn_chat = gr.Button("Clear Conversation")

        conversation_state = gr.State(
            value=[
                {
                    "role": "system",
                    "content": _DEFAULT_SYSTEM_PROMPT,
                }
            ]
        )

        @gpu_decorator
        def generate_audio_response(history, ref_audio, ref_text, remove_silence):
            """Speak the latest assistant reply in the reference voice.

            Returns ``(audio, ref_text)`` for the two wired outputs. Every
            path must return a pair: a bare None makes Gradio reject the
            result for having too few output values.
            """
            if not history or not ref_audio:
                return None, ref_text
            _, last_ai_response = history[-1]
            if not last_ai_response:
                return None, ref_text
            audio_result, _, ref_text_out = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                tts_model_choice,
                remove_silence,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,
            )
            # ref_text_out is written back to the textbox so later turns reuse it.
            return audio_result, ref_text_out

        def clear_conversation(system_prompt=_DEFAULT_SYSTEM_PROMPT):
            """Empty the chat, keeping the system prompt currently in the textbox."""
            return [], [{"role": "system", "content": system_prompt}]

        def update_system_prompt(new_prompt):
            """Restart the conversation with ``new_prompt`` as the system message."""
            return [], [{"role": "system", "content": new_prompt}]

        def _wire_chat_turn(trigger, reset_component):
            """Attach the shared chat pipeline to one UI event.

            The steps are: append the user turn and fetch the reply, speak the
            reply, then clear ``reset_component``.
            """
            trigger(
                process_audio_input,
                inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
                outputs=[chatbot_interface, conversation_state],
            ).then(
                generate_audio_response,
                inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
                outputs=[audio_output_chat, ref_text_chat],
            ).then(lambda: None, None, reset_component)

        _wire_chat_turn(audio_input_chat.stop_recording, audio_input_chat)
        _wire_chat_turn(text_input_chat.submit, text_input_chat)
        _wire_chat_turn(send_btn_chat.click, text_input_chat)

        clear_btn_chat.click(
            clear_conversation,
            inputs=[system_prompt_chat],
            outputs=[chatbot_interface, conversation_state],
        )
        system_prompt_chat.change(
            update_system_prompt,
            inputs=system_prompt_chat,
            outputs=[chatbot_interface, conversation_state],
        )
app = app_chat


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option("--share", "-s", default=False, is_flag=True, help="Share the app via Gradio share link")
# Declared as an on/off flag pair: `is_flag=True` with default=True could
# never be turned off. `--api` and `-a` still work as before.
@click.option("--api/--no-api", "-a", default=True, help="Allow API access (use --no-api to disable)")
@click.option("--root_path", "-r", default=None, type=str, help="Root path for the application")
def main(port, host, share, api, root_path):
    """Launch the Gradio app with the queue enabled and the given server options."""
    app.queue(api_open=api).launch(
        server_name=host,
        server_port=port,
        share=share,
        show_api=api,
        root_path=root_path,
    )


if __name__ == "__main__":
    main()