# wuhp's picture
# Create app.py
# 2353098 verified
import logging
import os
import re
import tempfile
import time
from threading import Thread
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
import spaces  # provides the @spaces.GPU decorator used below
import torch
import whisper  # speech-to-text (used by transcribe_audio)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from TTS.api import TTS
# ===========================
# Global Environment Settings
# ===========================
# Must be set before the TTS model below is constructed. Presumably read by the Coqui TTS
# library to accept its model licence non-interactively — verify against the TTS docs.
os.environ["COQUI_TOS_AGREED"] = "1"
# Device chosen once at import time; the TTS model below is moved to it immediately.
# The UI's "Force Use CPU" setting later reassigns this global, but that does NOT move any
# model that has already been loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Whisper speech-to-text model, loaded once at import (used by transcribe_audio).
whisper_model = whisper.load_model("base")
# Saved voice clones: voice name -> path of the cloned/reference audio file.
voice_bank: Dict[str, str] = {}
# ---------------------------
# Simple Response Cache
# ---------------------------
# Process-wide cache of generated text, populated and read by generate_response.
# Shared by every user/session of the app.
response_cache: Dict[str, str] = {}
# ===========================
# Voice Cloning Setup
# ===========================
# Coqui XTTS v2 multilingual voice-cloning model, placed on `device` once at import time.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
@spaces.GPU(enable_queue=True)
def clone(text, audio) -> Optional[str]:
    """Synthesize ``text`` in the voice of the reference recording ``audio`` (XTTS v2, English).

    Args:
        text: Text to speak.
        audio: Path to the reference speaker audio file.

    Returns:
        Path to a newly created ``.wav`` file, or ``None`` if an input is empty or synthesis
        fails (the error is logged).
    """
    if not text or not audio:
        logging.warning("TTS cloning skipped: both text and reference audio are required.")
        return None
    # Each call gets its own output file. The previous fixed "./output.wav" was shared by every
    # event that calls clone (Clone Voice, Test TTS, chat audio), so overlapping requests could
    # overwrite each other's audio.
    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path=output_path)
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        try:
            os.remove(output_path)  # best-effort cleanup of the unused temp file
        except OSError:
            pass
        return None
    return output_path
def save_voice(voice_name: str, voice_audio: str) -> None:
    """Register a cloned voice's audio file under a name in ``voice_bank``.

    Surrounding whitespace is stripped from ``voice_name``. A blank name or a missing audio
    path is ignored. Saving under an existing name replaces that entry.

    Args:
        voice_name: Display name for the voice.
        voice_audio: Path to the cloned-voice audio file.
    """
    name = (voice_name or "").strip()
    if name and voice_audio:
        # Mutating the dict needs no `global` statement; the name is never rebound.
        voice_bank[name] = voice_audio
def get_voice_options() -> List[str]:
    """List the names of all saved voice clones, in the order they were first saved."""
    return [name for name in voice_bank]
def refresh_voice_list() -> dict:
    """Build a Gradio update that reloads a voice dropdown with the saved voice names.

    The first saved voice is pre-selected. With no saved voices the selection is cleared
    (``None``) instead of being set to ``""``, which is not one of the (empty) choices.

    Returns:
        The dict produced by ``gr.update(choices=..., value=...)``.
    """
    options = get_voice_options()
    return gr.update(choices=options, value=options[0] if options else None)
# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Hugging Face model repo used for the agent chat (loaded by get_model_and_tokenizer).
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
# Lazily-filled caches, populated by get_model_and_tokenizer() under the key "7B".
models: Dict[str, AutoModelForCausalLM] = dict()
tokenizers: Dict[str, AutoTokenizer] = dict()
# 4-bit NF4 quantization with bfloat16 compute, passed to from_pretrained().
bnb_config_4bit = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    The model ``MODEL_ID`` is loaded once with ``bnb_config_4bit``. It is cached in the
    module-level ``models`` / ``tokenizers`` dicts under the key "7B", and later calls reuse it.

    Raises:
        Exception: Whatever ``from_pretrained`` raises if loading fails. The error is logged
            with its traceback first. Nothing is cached in that case, so the next call retries.
    """
    key = "7B"
    if key not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
        except Exception as e:
            # logging.exception keeps the traceback; a bare `raise` re-raises unchanged.
            logging.exception(f"Failed to load model and tokenizer: {e}")
            raise
        model.eval()
        tokenizers[key] = tokenizer
        # Stored last: membership in `models` is the "already loaded" check above, so the
        # tokenizer must already be in place when it becomes visible.
        models[key] = model
        logging.info("Loaded 7B model on demand.")
    return models[key], tokenizers[key]
# ---------------------------
# Prompt Templates
# ---------------------------
# Built-in prompt templates for the three default modes ("coding", "math", "writing").
# Each mode supplies the four rounds of MultiRoundAgent's default pipeline. Placeholders are
# filled with str.format, so any literal brace in a template must be doubled ({{ }}):
#   brainstorm -> {user_prompt}
#   round2     -> {brainstorm_response}, {user_prompt}
#   synthesis  -> {round2_response}
#   rationale  -> {final_response}
# NOTE: synthesis and rationale are not given {user_prompt}; rounds 3 and 4 only see the
# previous round's output. These strings also seed the "Prompt Config" textboxes in the UI.
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}
# Initial value of the UI's prompt_state. It holds two mappings:
#   "default": mode name -> {"brainstorm", "round2", "synthesis", "rationale"} templates
#   "custom":  mode name -> list of per-round templates (added from the "Custom Modes" tab)
initial_prompt_state = dict(
    default=default_prompts,
    custom={},
)
def detect_domain(user_prompt: str) -> str:
    """Guess which default prompt set fits ``user_prompt``: "math", "writing" or "coding".

    Domains are checked in priority order math -> writing -> coding, and "coding" is also the
    fallback. Keywords must match as whole words, optionally followed by a simple inflection
    (s/es/d/ed/ing). Plain substring matching misfired, e.g. "sum" in "summarize" or "product"
    in "production" both selected math.

    Args:
        user_prompt: The user's message.

    Returns:
        The detected domain name.
    """
    prompt_lower = user_prompt.lower()
    domain_keywords = (
        ("math", ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]),
        ("writing", ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]),
        ("coding", ["code", "program", "debug", "compile", "algorithm", "function"]),
    )
    for domain, keywords in domain_keywords:
        pattern = r"\b(?:" + "|".join(map(re.escape, keywords)) + r")(?:s|es|d|ed|ing)?\b"
        if re.search(pattern, prompt_lower):
            logging.info(f"Domain detected as: {domain}")
            return domain
    logging.info("No specific domain detected; defaulting to coding")
    return "coding"
class MemoryManager:
    """Ordered store of agent round outputs with naive, case-insensitive substring retrieval.

    NOTE: the module-level instance below is shared by every chat session, so retrieval can
    surface outputs produced for other users.
    """

    def __init__(self, max_items: Optional[int] = None) -> None:
        """Create an empty store.

        Args:
            max_items: If given, only the newest ``max_items`` entries are kept and older ones
                are dropped on ``store``. ``None`` keeps everything.
        """
        self.shared_memory: List[str] = []
        self.max_items = max_items

    def store(self, item: str) -> None:
        """Append ``item``, then trim the oldest entries beyond ``max_items``."""
        self.shared_memory.append(item)
        if self.max_items is not None:
            overflow = len(self.shared_memory) - self.max_items
            if overflow > 0:
                del self.shared_memory[:overflow]
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """Return up to ``top_k`` most recent items containing ``query`` (case-insensitive).

        Results are in insertion order (oldest first). A ``top_k`` of 0 or less returns ``[]``.
        """
        if top_k <= 0:
            # relevant[-0:] is the whole list, so top_k=0 used to return every match.
            return []
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]


# Bounded so a long-running server does not grow this list forever: every default-mode
# request stores four full round outputs.
global_memory_manager = MemoryManager(max_items=1000)
def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1,
                      max_cache_entries: int = 256) -> str:
    """Generate a completion for ``prompt`` and return only the newly generated text.

    Results are memoised in the module-level ``response_cache``, keyed by the prompt and every
    sampling setting. An identical request therefore returns the earlier (sampled) answer
    instead of sampling again.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: The matching tokenizer. The prompt is used as raw text; no chat template
            is applied.
        prompt: Prompt text.
        max_tokens: Maximum number of new tokens.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Passed through to ``generate``.
        num_beams: Number of beams; values > 1 are supported (beam-sample decoding).
        max_cache_entries: Upper bound on ``response_cache``; the oldest entries are evicted
            first. A value of 0 or less stops new results from being cached.

    Returns:
        The decoded completion with special tokens removed.
    """
    # repr() of a tuple is unambiguous (strings are quoted and escaped). The old "-"-joined key
    # could collide when a prompt ended in something like "-300".
    cache_key = repr((prompt, max_tokens, temperature, top_p, repetition_penalty, num_beams))
    cached = response_cache.get(cache_key)
    if cached is not None:
        logging.info("Returning cached response.")
        return cached

    # Calling the tokenizer (instead of .encode) also returns the attention mask for generate().
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]
    try:
        # Generate synchronously: only the final text is needed here, and TextIteratorStreamer
        # cannot be combined with beam search (generate() rejects a streamer when num_beams > 1).
        with torch.no_grad():
            output_ids = model.generate(
                **encoded,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                repetition_penalty=repetition_penalty,
                num_beams=num_beams,
            )
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise
    # generate() returns prompt + continuation for decoder-only models; keep the continuation.
    response = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    if max_cache_entries > 0:
        while len(response_cache) >= max_cache_entries:
            # dicts keep insertion order, so the first key is the oldest entry.
            response_cache.pop(next(iter(response_cache)))
        response_cache[cache_key] = response
    return response
class MultiRoundAgent:
    """Chain several prompt rounds through a causal LM, feeding each round's output into the next.

    ``prompt_templates`` selects the pipeline:

    * ``dict`` (default modes): a fixed four-round pipeline that uses these keys and placeholders:
      ``"brainstorm"`` (``{user_prompt}``), ``"round2"`` (``{brainstorm_response}``,
      ``{user_prompt}``), ``"synthesis"`` (``{round2_response}``) and ``"rationale"``
      (``{final_response}``).
    * ``list`` (custom modes): one ``str.format`` template per round. Each may use
      ``{user_prompt}`` and ``{prev_response}`` (the previous round's output, ``""`` in round 1).

    Every round's output is also recorded in ``memory_manager``.
    """

    def __init__(self, model, tokenizer, prompt_templates, memory_manager: "MemoryManager"):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        """Run the pipeline for ``user_prompt``, yielding progressively more complete output.

        Args:
            user_prompt: The user's message.
            params: Generation settings with keys ``"temp"``, ``"top_p"``,
                ``"max_new_tokens"``, ``"repetition_penalty"`` and ``"num_beams"``.
            show_raw: Default modes only; append every round's raw output to the final answer.

        Yields:
            For default modes, the growing round-3 synthesis and then the final round-4 answer.
            For custom modes, the accumulated output of all rounds completed so far.
            For any other template type, a single error message.
        """
        if isinstance(self.prompt_templates, dict):
            yield from self._run_default(user_prompt, params, show_raw)
        elif isinstance(self.prompt_templates, list):
            yield from self._run_custom(user_prompt, params)
        else:
            yield "Invalid prompt template format."

    def _generate(self, prompt: str, max_new_tokens: int, params: Dict) -> str:
        """Blocking (cached) generation via generate_response using the settings in ``params``."""
        return generate_response(self.model, self.tokenizer, prompt, max_new_tokens,
                                 params.get("temp"), params.get("top_p"),
                                 params.get("repetition_penalty", 1.0), params.get("num_beams", 1))

    def _stream(self, prompt: str, max_new_tokens: int, params: Dict) -> Generator[str, None, None]:
        """Yield the cumulative completion of ``prompt`` as tokens are produced.

        TextIteratorStreamer cannot be combined with beam search: generate() rejects a streamer
        when num_beams > 1. In that case the completion is produced in one blocking call and
        yielded once.
        """
        num_beams = params.get("num_beams", 1) or 1
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        gen_kwargs = dict(
            encoded,  # input_ids + attention_mask
            max_new_tokens=max_new_tokens,
            temperature=params.get("temp"),
            top_p=params.get("top_p"),
            # Sample like every other round. Round 3 used to omit this flag, so its
            # temperature/top_p were not guaranteed to take effect.
            do_sample=True,
            repetition_penalty=params.get("repetition_penalty", 1.0),
            num_beams=num_beams,
        )
        if num_beams > 1:
            with torch.no_grad():
                output_ids = self.model.generate(**gen_kwargs)
            prompt_len = encoded["input_ids"].shape[-1]
            yield self.tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
            return
        streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        # generate() runs in a worker thread and pushes text into the streamer, which this loop
        # drains. (torch.no_grad() is thread-local, so wrapping thread.start() in it would have
        # no effect on the worker; generate() disables gradients itself.)
        thread = Thread(target=self.model.generate, kwargs={**gen_kwargs, "streamer": streamer})
        thread.start()
        text = ""
        for chunk in streamer:
            text += chunk
            yield text
        thread.join()

    def _run_default(self, user_prompt: str, params: Dict, show_raw: bool) -> Generator[str, None, None]:
        """Run the fixed four-round pipeline: brainstorm, strategy, streamed synthesis, rationale."""
        templates = self.prompt_templates
        max_new_tokens = params.get("max_new_tokens")

        logging.info("--- Round 1 ---")
        r1 = self._generate(templates["brainstorm"].format(user_prompt=user_prompt), max_new_tokens, params)
        self.memory_manager.store(f"Round 1 Response: {r1}")

        logging.info("--- Round 2 ---")
        r2 = self._generate(templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt),
                            max_new_tokens + 100, params)
        self.memory_manager.store(f"Round 2 Response: {r2}")

        logging.info("--- Round 3 ---")
        r3 = ""
        try:
            # Streamed so the UI shows the synthesis while it is written. max(1, ...) keeps a
            # tiny token budget from turning into max_new_tokens=0.
            for r3 in self._stream(templates["synthesis"].format(round2_response=r2),
                                   max(1, max_new_tokens // 2), params):
                yield r3  # progressive updates
        except Exception as e:
            logging.error(f"Error during Round 3 streaming: {e}")
            raise
        self.memory_manager.store(f"Final Synthesis Response: {r3}")

        logging.info("--- Round 4 ---")
        r4 = self._generate(templates["rationale"].format(final_response=r3), 300, params)
        self.memory_manager.store(f"Round 4 Response: {r4}")
        if show_raw:
            yield (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\n"
                   f"Round 3:\n{r3}\n\nRound 4:\n{r4}\n")
        else:
            yield r4

    def _run_custom(self, user_prompt: str, params: Dict) -> Generator[str, None, None]:
        """Run a user-defined list of round templates, each seeing the previous round's output."""
        templates = self.prompt_templates
        if not templates:
            yield "This custom mode has no round prompts."
            return
        total_rounds = len(templates)
        prev_response = ""
        full_output = ""
        for round_num, round_template in enumerate(templates, start=1):
            logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
            try:
                # prev_response is "" in round 1, so a first-round template that mentions it no
                # longer raises KeyError.
                prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
            except (KeyError, IndexError, ValueError) as e:
                # Free-form templates may contain unknown {names} or stray braces. Report the
                # problem in the chat instead of failing the whole request.
                logging.error(f"Custom mode round {round_num} template error: {e!r}")
                yield f"{full_output}\n[Error] Round {round_num} prompt template is invalid: {e!r}"
                return
            response = self._generate(prompt, params.get("max_new_tokens"), params)
            self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
            full_output += f"\n--- Round {round_num} ---\n{response}"
            prev_response = response
            yield full_output
@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float,
                          num_beams: int) -> Generator[str, None, None]:
    """Run the multi-round agent on ``user_prompt`` and stream its progressive output.

    This is a generator function (``yield from``) rather than a plain function that returns the
    pipeline's generator. ``@spaces.GPU`` provides the GPU for the execution of the decorated
    call. Returning an un-started generator meant all generation ran only after that call had
    finished, outside the GPU context.

    ``memory_top_k`` and ``domain`` are accepted for interface compatibility but are unused.

    Yields:
        Progressive output strings from MultiRoundAgent.run_pipeline.
    """
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams,
    }
    yield from agent.run_pipeline(user_prompt, params, show_raw)
def _iter_exchanges(history: List) -> Generator[Tuple[str, str], None, None]:
    """Yield ``(user_text, assistant_text)`` pairs from a chat history in either Gradio format.

    Accepts legacy ``[user, assistant]`` pairs and ``{"role": ..., "content": ...}`` message dicts
    (the ``type="messages"`` Chatbot format, which is what the Chat tab passes in). Missing or
    ``None`` text becomes ``""``.
    """
    def _text(value) -> str:
        return "" if value is None else str(value)

    pending_user = None
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            if role == "user":
                if pending_user is not None:
                    yield pending_user, ""
                pending_user = _text(item.get("content"))
            elif role == "assistant":
                yield (pending_user or ""), _text(item.get("content"))
                pending_user = None
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            if pending_user is not None:
                yield pending_user, ""
                pending_user = None
            yield _text(item[0]), _text(item[1])
    if pending_user is not None:
        yield pending_user, ""


@spaces.GPU(duration=60)
def handle_explanation_request(user_prompt: str, history: List) -> str:
    """Answer an "explain ..." message from stored final outputs and earlier explain-related turns.

    The prompt combines:
    * the last three memories containing "Round 4 Response:";
    * every earlier exchange in which either side mentions "explain";
    * the current request.
    It is then sent to the chat model with a 300-token budget.

    Decorated with ``@spaces.GPU`` like the other model-running entry points, because this path
    also runs the chat model.

    Args:
        user_prompt: The current user message.
        history: Chat history, as pairs or message dicts.

    Returns:
        The generated explanation.
    """
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    parts = ["Below are previous final outputs and related context from our conversation:\n"]
    if retrieved:
        parts.extend(f"- {item}\n" for item in retrieved)
    else:
        parts.append("No stored final output found.\n")
    parts.append("\nRecent related exchanges:\n")
    # Indexing chat[0]/chat[1] failed on the messages-format dicts the Chatbot actually provides.
    for user_text, assistant_text in _iter_exchanges(history):
        if "explain" in user_text.lower() or "explain" in assistant_text.lower():
            parts.append(f"User: {user_text}\nAssistant: {assistant_text}\n")
    # The current request used to be dropped, so the model never saw what was being asked.
    parts.append(f"\nCurrent request:\n{user_prompt}\n")
    parts.append("\nBased on the above context, please provide a detailed explanation of the creative choices.")
    model, tokenizer = get_model_and_tokenizer()
    return generate_response(model, tokenizer, "".join(parts), 300, 0.7, 0.9)
def format_history(history: List) -> List[Dict[str, str]]:
    """Convert a chat history to Gradio ``type="messages"`` dicts.

    Accepts ``[user, assistant]`` pairs (as appended by gradio_interface) and message dicts. Dicts
    are passed through unchanged. A pair whose user text is the ``"__final_agent_response__"``
    sentinel is dropped. Empty or ``None`` texts are skipped, so the chat shows no blank bubbles.

    Args:
        history: Mixed list of pairs and message dicts (``None`` is treated as empty).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages: List[Dict[str, str]] = []
    for item in history or []:
        if isinstance(item, dict):
            messages.append(item)
            continue
        if not (isinstance(item, (list, tuple)) and len(item) == 2):
            continue
        user_msg, assistant_msg = item
        if user_msg == "__final_agent_response__":
            continue
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
    """Handle one chat turn, yielding the updated chat (messages format) as it progresses.

    Routing:
    * An empty message re-renders the history unchanged.
    * A message containing "explain" goes to handle_explanation_request.
    * Anything else runs the multi-round agent in ``mode``. ``mode`` may be a default or a
      custom mode name; unknown names fall back to keyword-based domain detection.

    Generation settings come from ``param_state``; if any value fails to convert, built-in
    defaults are used instead. Failures during generation are logged and reported in the chat
    as a short message instead of an empty reply.

    Args:
        message: The user's message.
        history: Current chat history (pairs and/or message dicts).
        param_state: Generation settings from the Project Settings tab.
        prompt_state: ``{"default": {...}, "custom": {...}}`` prompt configuration.
        mode: Selected mode name.

    Yields:
        The full chat history formatted by format_history after each step.
    """
    history = list(history or [])
    if not message or not message.strip():
        yield format_history(history)
        return

    if "explain" in message.lower():
        try:
            explanation = handle_explanation_request(message, history)
        except Exception:
            logging.exception("Explanation request failed")
            explanation = "Sorry, something went wrong while generating the explanation."
        yield format_history(history + [[message, explanation]])
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except (AttributeError, TypeError, ValueError) as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]
    # Show the user's message immediately while the agent works.
    yield format_history(history)
    try:
        for partial_response in swarm_agent_iterative(
            user_prompt=message,
            temp=temp,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            memory_top_k=memory_top_k,
            prompt_templates=prompt_templates,
            domain=mode,
            show_raw=show_raw,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
        ):
            history[-1][1] = partial_response
            yield format_history(history)
    except Exception:
        # Report the failure in the chat instead of leaving an empty reply; details go to the log.
        logging.exception("Agent pipeline failed")
        history[-1][1] = "Sorry, something went wrong while generating a response."
        yield format_history(history)
def generate_agent_audio(latest_text: str, voice_reference: str) -> Optional[str]:
    """Speak ``latest_text`` in a cloned voice.

    ``voice_reference`` is either the name of a voice saved in ``voice_bank`` (its stored audio
    path is used) or a path to a reference audio file.

    Returns:
        Path to the generated audio, or ``None`` in any of these cases: there is no text, no
        voice is given, or synthesis fails.
    """
    if not latest_text or not voice_reference:
        # Nothing to say or no voice selected: skip the TTS call, which could only fail.
        return None
    reference_path = voice_bank.get(voice_reference, voice_reference)
    return clone(latest_text, reference_path)
# Speech-to-text using the Whisper model loaded at startup.
def transcribe_audio(audio_file: str) -> str:
    """Transcribe an audio file to text with ``whisper_model``.

    Args:
        audio_file: Path to the recording, or ``None``/``""`` when nothing was recorded.

    Returns:
        The stripped transcription. Returns ``""`` when no file was provided, and the string
        "Transcription failed." if Whisper raises (the error is logged).
    """
    if not audio_file:
        # Button pressed with no recording/upload: leave the chat input empty instead of
        # writing an error string into it.
        return ""
    try:
        result = whisper_model.transcribe(audio_file)
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."
    transcription = (result.get("text") or "").strip()
    logging.info(f"Transcription result: {transcription}")
    return transcription
# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    """Load the chat model at import time so the first chat request does not pay the load cost.

    A failure is logged rather than raised, so the UI still starts. Nothing is cached on
    failure, so get_model_and_tokenizer() tries again on the next request.
    """
    try:
        get_model_and_tokenizer()
    except Exception as err:
        logging.error(f"Model warm-up failed: {err}")
    else:
        logging.info("Model warm-up complete.")


warmup_model()
# ===========================
# Custom Gradio Theme
# ===========================
# Soft base theme with pink primary/secondary hues and purple neutrals. The .set() call then
# overrides individual theme variables; keys ending in "_dark" are the dark-mode variants.
# Values written as '*name' appear to reference other theme variables (Gradio theme syntax;
# confirm in the Gradio theming docs before relying on it).
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    # Page / block backgrounds and shadows.
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    # Block labels.
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    # Block titles.
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    # Checkboxes.
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    # Inputs and sliders.
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    # Buttons.
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)
# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False,  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice
    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("<h2 style='text-align: center; padding-top: 10px;'>Voice Setup</h2>")
            with gr.Column(variant="panel"):
                gr.Markdown("<p style='text-align: center;'>Clone a voice and save it with a custom name. Test TTS using your cloned voices.</p>")
                with gr.Row():
                    text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
                with gr.Row():
                    audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
                with gr.Row():
                    clone_btn = gr.Button("Clone Voice")
                with gr.Row():
                    output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
                clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
                with gr.Row():
                    voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
                with gr.Row():
                    save_voice_btn = gr.Button("Save Voice")
                save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
                with gr.Row():
                    refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                    voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
                    set_voice_btn = gr.Button("Set Selected Voice")
                refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
                set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
                gr.Markdown("<p style='text-align: center;'>(The selected voice will be used for TTS responses in Chat.)</p>")
                gr.Markdown("<hr>")
                gr.Markdown("<h3 style='text-align: center;'>TTS Test</h3>")
                with gr.Row():
                    tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...", elem_classes="full-width")
                with gr.Row():
                    tts_test_btn = gr.Button("Test TTS")
                    tts_test_output = gr.Audio(label="TTS Output", type="filepath")
                # An uploaded reference clip, if present, takes precedence over the saved voice.
                tts_test_btn.click(fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
                                   inputs=[tts_test_input, audio_input, selected_voice],
                                   outputs=tts_test_output)

        # ----- Tab 2: Chat -----
        with gr.Tab("Chat"):
            gr.Markdown("""
<div style="text-align: center; padding: 10px;">
<h1>DeepSeek Agent Swarm Chat</h1>
<p>Multi-round agent with prompt chaining. Ask me anything!</p>
</div>
""")
            with gr.Column():
                with gr.Row():
                    mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2, elem_id="msg_input")
                with gr.Row():
                    chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
                    transcribe_btn = gr.Button("Transcribe Audio")
                transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    export_btn = gr.Button("Generate Chat Transcript")
                chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
                with gr.Row():
                    use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
                    chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS", interactive=True)
                    refresh_voice_btn_chat = gr.Button("Refresh Voice List")
                refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
                # "Set Selected Voice" on the Voice Setup tab promises the voice is used in Chat,
                # but Chat TTS reads this dropdown, so keep the two in sync.
                set_voice_btn.click(
                    fn=lambda voice: gr.update(choices=get_voice_options(), value=voice),
                    inputs=[voice_dropdown_setup],
                    outputs=chat_voice_dropdown,
                )
                agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")

                def chat_wrapper(message, history, param_state, prompt_state, mode):
                    """Stream one chat turn into the Chatbot, yielding every intermediate history."""
                    # Previously this returned only the final history, and it appended a
                    # ["", "**Generating response...**"] placeholder that stayed in the transcript
                    # on every turn. gradio_interface already yields a loading state.
                    yield from gradio_interface(message, list(history or []), param_state, prompt_state, mode)

                send_btn.click(fn=chat_wrapper,
                               inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
                               outputs=[chatbot])

                def conditional_tts(latest_text, use_tts, selected_voice_val):
                    """Return a TTS audio path for ``latest_text`` when TTS is enabled, else None."""
                    if use_tts:
                        return generate_agent_audio(latest_text, selected_voice_val)
                    return None

                def get_latest_text(chat_history):
                    """Return the content of the most recent non-empty assistant message ("" if none)."""
                    for msg in reversed(chat_history or []):
                        if isinstance(msg, dict) and msg.get("role") == "assistant" and msg.get("content"):
                            return msg["content"]
                    return ""

                latest_text_state = gr.State(value="")
                gen_audio_btn = gr.Button("Generate Audio from Agent Response")
                # Chain with .then(): two separate .click() listeners run independently, so the TTS
                # step could read latest_text_state before it had been updated.
                gen_audio_btn.click(fn=get_latest_text,
                                    inputs=[chatbot],
                                    outputs=latest_text_state).then(
                    fn=conditional_tts,
                    inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown],
                    outputs=agent_audio,
                )

                def export_transcript(history):
                    """Render the chat as plain "User: ..." / "Assistant: ..." text.

                    Handles message dicts (the type="messages" Chatbot format) as well as legacy
                    [user, assistant] pairs.
                    """
                    parts = []
                    for item in history or []:
                        if isinstance(item, dict):
                            role = item.get("role")
                            if role == "user":
                                parts.append(f"User: {item.get('content', '')}\n")
                            elif role == "assistant":
                                parts.append(f"Assistant: {item.get('content', '')}\n\n")
                        elif isinstance(item, (list, tuple)) and len(item) == 2:
                            parts.append(f"User: {item[0]}\nAssistant: {item[1]}\n\n")
                    return "".join(parts)

                # The transcript gets its own textbox: a messages-type Chatbot cannot display a
                # plain string.
                transcript_box = gr.Textbox(label="Chat Transcript", lines=10, interactive=False)
                export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=transcript_box)

        # ----- Tab 3: Project Settings -----
        with gr.Tab("Project Settings"):
            gr.Markdown("<h2 style='text-align: center;'>Project Settings</h2>")
            with gr.Tabs():
                with gr.Tab("Generation Parameters"):
                    gr.Markdown("<h3>Generation Parameters</h3>")
                    with gr.Row():
                        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
                        top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
                    with gr.Row():
                        max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
                        memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
                    with gr.Row():
                        rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
                        num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
                    with gr.Row():
                        show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
                        use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
                    save_params_btn = gr.Button("Save Generation Parameters")

                    def save_params(t, p, m, k, rp, nb, s, use_cpu):
                        """Build the new param_state dict and update the global ``device`` preference."""
                        global device
                        # NOTE(review): this only reassigns the global. The TTS and chat models
                        # were placed on a device at import time and are not moved, so "Force Use
                        # CPU" currently changes only the status text.
                        if use_cpu:
                            device = "cpu"
                        else:
                            device = "cuda" if torch.cuda.is_available() else "cpu"
                        return {
                            "temperature": t,
                            "top_p": p,
                            "max_new_tokens": m,
                            "memory_top_k": k,
                            "repetition_penalty": rp,
                            "num_beams": nb,
                            "show_raw_output": s,
                            "use_cpu": use_cpu,
                        }

                    # .then() guarantees the status is rendered after save_params has updated
                    # `device`; a second independent .click() listener could run first.
                    save_params_btn.click(
                        save_params,
                        inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
                        outputs=param_state,
                    ).then(
                        fn=lambda params: f"**Running on:** {device.upper()}",
                        inputs=param_state,
                        outputs=device_status,
                    )
                    gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")

                with gr.Tab("Prompt Config (Default Modes)"):
                    gr.Markdown("<h3>Prompt Configurations for Default Modes</h3>")
                    with gr.Tabs():
                        with gr.Tab("Coding"):
                            prompt_brainstorm_box_code = gr.Textbox(
                                value=default_prompts["coding"]["brainstorm"],
                                label="Brainstorm Prompt (Coding)",
                                lines=8,
                            )
                            prompt_round2_box_code = gr.Textbox(
                                value=default_prompts["coding"]["round2"],
                                label="Round 2 Prompt (Coding)",
                                lines=8,
                            )
                            prompt_synthesis_box_code = gr.Textbox(
                                value=default_prompts["coding"]["synthesis"],
                                label="Synthesis Prompt (Coding)",
                                lines=8,
                            )
                            prompt_rationale_box_code = gr.Textbox(
                                value=default_prompts["coding"]["rationale"],
                                label="Rationale Prompt (Coding)",
                                lines=8,
                            )
                        with gr.Tab("Math"):
                            prompt_brainstorm_box_math = gr.Textbox(
                                value=default_prompts["math"]["brainstorm"],
                                label="Brainstorm Prompt (Math)",
                                lines=8,
                            )
                            prompt_round2_box_math = gr.Textbox(
                                value=default_prompts["math"]["round2"],
                                label="Round 2 Prompt (Math)",
                                lines=8,
                            )
                            prompt_synthesis_box_math = gr.Textbox(
                                value=default_prompts["math"]["synthesis"],
                                label="Synthesis Prompt (Math)",
                                lines=8,
                            )
                            prompt_rationale_box_math = gr.Textbox(
                                value=default_prompts["math"]["rationale"],
                                label="Rationale Prompt (Math)",
                                lines=8,
                            )
                        with gr.Tab("Writing"):
                            prompt_brainstorm_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["brainstorm"],
                                label="Brainstorm Prompt (Writing)",
                                lines=8,
                            )
                            prompt_round2_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["round2"],
                                label="Round 2 Prompt (Writing)",
                                lines=8,
                            )
                            prompt_synthesis_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["synthesis"],
                                label="Synthesis Prompt (Writing)",
                                lines=8,
                            )
                            prompt_rationale_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["rationale"],
                                label="Rationale Prompt (Writing)",
                                lines=8,
                            )
                    save_prompts_btn = gr.Button("Save Default Prompt Configurations")

                    def save_default_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat, current_prompt_state):
                        """Return a new prompt state with the edited default templates.

                        The session's custom modes are preserved. They used to be read from
                        ``prompt_state.value``, the State's *initial* value, which silently
                        discarded custom modes saved during the session.
                        """
                        return {
                            "default": {
                                "coding": {
                                    "brainstorm": code_brain,
                                    "round2": code_r2,
                                    "synthesis": code_syn,
                                    "rationale": code_rat,
                                },
                                "math": {
                                    "brainstorm": math_brain,
                                    "round2": math_r2,
                                    "synthesis": math_syn,
                                    "rationale": math_rat,
                                },
                                "writing": {
                                    "brainstorm": writing_brain,
                                    "round2": writing_r2,
                                    "synthesis": writing_syn,
                                    "rationale": writing_rat,
                                },
                            },
                            "custom": (current_prompt_state or {}).get("custom", {}),
                        }

                    save_prompts_btn.click(
                        save_default_prompts,
                        inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
                                prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
                                prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing,
                                prompt_state],
                        outputs=prompt_state,
                    )

                with gr.Tab("Custom Modes"):
                    gr.Markdown("<h3>Create / Edit Custom Modes</h3>")
                    gr.Markdown(
                        "Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), "
                        "and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` "
                        "(for the first round) and `{prev_response}` (for subsequent rounds)."
                    )
                    with gr.Row():
                        custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
                        custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
                    custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
                    custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
                    custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
                    custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
                    custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
                    custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
                    custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
                    custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")

                    def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, current_prompt_state):
                        """Save the first ``round_count`` non-blank round prompts as custom mode ``name``.

                        Returns an update that clears the name box, plus the new prompt state.
                        Nothing changes in these cases:
                        * the name is blank;
                        * the name collides with a default mode (gradio_interface checks default
                          modes first, so such a custom mode could never run);
                        * no non-blank prompts are given.
                        """
                        name = (name or "").strip()
                        round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
                        # The slider value may arrive as a float; range()/slicing need an int.
                        rounds = [p for p in round_prompts[:int(round_count)] if p and p.strip()]
                        default_modes = current_prompt_state.get("default", {})
                        if not name or name in default_modes or not rounds:
                            return gr.update(), current_prompt_state
                        # Copy instead of mutating the incoming state dict in place.
                        custom_modes = dict(current_prompt_state.get("custom", {}))
                        custom_modes[name] = rounds
                        new_prompt_state = {
                            "default": default_modes,
                            "custom": custom_modes,
                        }
                        return gr.update(value=""), new_prompt_state

                    save_custom_mode_btn = gr.Button("Save Custom Mode")
                    save_custom_mode_btn.click(
                        save_custom_mode,
                        inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3, custom_round4,
                                custom_round5, custom_round6, custom_round7, custom_round8, custom_round9, custom_round10,
                                prompt_state],
                        outputs=[custom_mode_name, prompt_state],
                    )

                    def update_mode_choices(current_prompt_state):
                        """Refresh the Chat tab's mode selector with all default and custom mode names."""
                        default_modes = list(current_prompt_state.get("default", {}).keys())
                        custom_modes = list(current_prompt_state.get("custom", {}).keys())
                        all_modes = default_modes + custom_modes
                        default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
                        return gr.update(choices=all_modes, value=default_choice)

                    refresh_mode_selector_btn = gr.Button("Refresh Mode List")
                    refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)
            gr.Markdown("<hr>")
            gr.Markdown("<p style='text-align: center;'>These settings affect the entire project.</p>")
    gr.Markdown("<hr><p style='text-align: center;'>Agent Chat using DeepSeek Agent Swarm</p>")


if __name__ == "__main__":
    demo.launch(share=True)