import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # assumed custom module providing GPU decorators
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator
import time

# NEW: Import whisper for speech-to-text.
import whisper

# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"

# Global device override (will be updated from UI later)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")

# Global dictionary for storing saved voice clones.
voice_bank: Dict[str, str] = {}

# ---------------------------
# Simple Response Cache
# ---------------------------
response_cache: Dict[str, str] = {}

# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
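# Note: the XTTS model above and the Whisper model are loaded onto whatever
# `device` resolves to at startup. The "Force Use CPU" toggle in Project
# Settings only rewrites the global `device` string; models that are already
# loaded are not moved automatically.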
def clone(text, audio):
    """
    Generate a voice-cloned audio file given text and a reference audio file.
    Returns the path to the output audio file.
    """
    try:
        tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path="./output.wav")
        return "./output.wav"
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        return None


def save_voice(voice_name: str, voice_audio: str) -> None:
    """
    Save a cloned voice under the given name.
    """
    global voice_bank
    if voice_name and voice_audio:
        voice_bank[voice_name] = voice_audio


def get_voice_options() -> List[str]:
    """
    Returns a list of saved voice names.
    """
    return list(voice_bank.keys())


def refresh_voice_list() -> gr.update:
    """
    Returns an update with the latest voice list.
    """
    options = get_voice_options()
    new_val = options[0] if options else ""
    return gr.update(choices=options, value=new_val)
# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    # Warm-up: if the model isn't loaded, load it now.
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]
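# The 4-bit NF4 quantization above is intended to keep the distilled 7B model
# within a single-GPU memory budget; `device_map='auto'` lets accelerate decide
# layer placement. Loading is memoized in `models` / `tokenizers`, so repeated
# calls after the first are cheap.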
# ---------------------------
# Prompt Templates
# ---------------------------
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}

# The prompt state contains both default and custom modes.
initial_prompt_state = {
    "default": default_prompts,
    "custom": {}  # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}
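# Example of a saved custom-mode entry (illustrative only; the mode name and
# prompts below are made up):
# initial_prompt_state["custom"]["socratic"] = [
#     "Ask three probing questions about: {user_prompt}",
#     "Answer your own questions from the previous round:\n{prev_response}",
# ]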
def detect_domain(user_prompt: str) -> str:
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"


class MemoryManager:
    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]


global_memory_manager = MemoryManager()
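# Retrieval here is plain case-insensitive substring matching over stored
# strings (no embeddings or ranking); only the most recent `top_k` matches are
# returned.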
def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
    # Check the cache first.
    cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
    if cache_key in response_cache:
        logging.info("Returning cached response.")
        return response_cache[cache_key]

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )

    # Run generation in a worker thread so the streamer can be consumed here.
    # torch.no_grad() is entered inside the thread because grad mode is
    # thread-local; wrapping thread.start() in it would have no effect.
    def _generate():
        with torch.no_grad():
            model.generate(**kwargs)

    thread = Thread(target=_generate)
    thread.start()
    response = ""
    try:
        for text in streamer:
            response += text
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise e
    thread.join()

    # Cache the response.
    response_cache[cache_key] = response
    return response
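# Because do_sample=True, outputs are stochastic; the cache simply replays the
# first response generated for an identical (prompt, settings) combination.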
class MultiRoundAgent:
    def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
        """
        prompt_templates can be a dict (for default modes) or a list (for custom modes).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        if isinstance(self.prompt_templates, dict):
            # Default fixed 4-round pipeline.
            logging.info("--- Round 1 ---")
            prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
            r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 1 Response: {r1}")

            logging.info("--- Round 2 ---")
            prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
            r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100,
                                   params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 2 Response: {r2}")

            logging.info("--- Round 3 ---")
            prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
            input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
            streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
            kwargs_r3 = dict(
                input_ids=input_ids_r3,
                streamer=streamer_r3,
                max_new_tokens=params.get("max_new_tokens") // 2,
                temperature=params.get("temp"),
                top_p=params.get("top_p"),
                repetition_penalty=params.get("repetition_penalty"),
                num_beams=params.get("num_beams")
            )

            # As in generate_response, enter no_grad inside the worker thread.
            def _generate_r3():
                with torch.no_grad():
                    self.model.generate(**kwargs_r3)

            thread_r3 = Thread(target=_generate_r3)
            thread_r3.start()
            r3 = ""
            try:
                for text in streamer_r3:
                    r3 += text
                    yield r3  # Progressive updates
            except Exception as e:
                logging.error(f"Error during Round 3 streaming: {e}")
                raise e
            thread_r3.join()
            self.memory_manager.store(f"Final Synthesis Response: {r3}")

            logging.info("--- Round 4 ---")
            prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
            r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 4 Response: {r4}")

            final_output = (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n") if show_raw else r4
            yield final_output

        elif isinstance(self.prompt_templates, list):
            # Custom mode: iterate over the configured rounds.
            prev_response = ""
            full_output = ""
            total_rounds = len(self.prompt_templates)
            for idx, round_template in enumerate(self.prompt_templates):
                round_num = idx + 1
                logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
                if idx == 0:
                    prompt = round_template.format(user_prompt=user_prompt)
                else:
                    prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
                response = generate_response(self.model, self.tokenizer, prompt, params.get("max_new_tokens"),
                                             params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
                self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
                full_output += f"\n--- Round {round_num} ---\n{response}"
                prev_response = response
                yield full_output
        else:
            yield "Invalid prompt template format."
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float, num_beams: int) -> Generator[str, None, None]:
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams
    }
    return agent.run_pipeline(user_prompt, params, show_raw)
def handle_explanation_request(user_prompt: str, history: List) -> str:
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
    if retrieved:
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
    else:
        explanation_prompt += "No stored final output found.\n"
    explanation_prompt += "\nRecent related exchanges:\n"
    # History entries may be [user, assistant] pairs or message dicts
    # (the Chatbot component uses type="messages"); handle both shapes.
    for chat in history:
        if isinstance(chat, dict):
            content = chat.get("content", "") or ""
            if "explain" in content.lower():
                explanation_prompt += f"{chat.get('role', 'user').capitalize()}: {content}\n"
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_msg, assistant_msg = chat
            if "explain" in (user_msg or "").lower() or (assistant_msg and "explain" in assistant_msg.lower()):
                explanation_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
    model, tokenizer = get_model_and_tokenizer()
    explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
    return explanation


def format_history(history: List) -> List[Dict[str, str]]:
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg == "__final_agent_response__":
                continue
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages
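# format_history normalizes everything to the messages format expected by
# gr.Chatbot(type="messages"), e.g.:
# [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]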
def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]
    # Show a loading status while the first round runs.
    yield format_history(history)
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=mode,
        show_raw=show_raw,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams
    ):
        history[-1][1] = partial_response
        yield format_history(history)
    yield format_history(history)
def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
    """
    Generate an audio response using the cloned voice.
    If the provided voice_reference is a key in the voice bank, its stored file path is used.
    """
    if latest_text:
        if voice_reference in voice_bank:
            audio_path = clone(latest_text, voice_bank[voice_reference])
        else:
            audio_path = clone(latest_text, voice_reference)
        return audio_path
    return None


# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
    """
    Transcribe the provided audio file to text using the Whisper model.
    """
    try:
        result = whisper_model.transcribe(audio_file)
        transcription = result.get("text", "").strip()
        logging.info(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."
# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    try:
        get_model_and_tokenizer()
        logging.info("Model warm-up complete.")
    except Exception as e:
        logging.error(f"Model warm-up failed: {e}")


warmup_model()
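# Warm-up runs at import time so the 7B model is already resident before the
# first chat request; expect a noticeable startup delay and GPU memory use.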
# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)
# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice

    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("<h2 style='text-align: center; padding-top: 10px;'>Voice Setup</h2>")
            with gr.Column(variant="panel"):
                gr.Markdown("<p style='text-align: center;'>Clone a voice and save it with a custom name. Test TTS using your cloned voices.</p>")
                with gr.Row():
                    text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
                with gr.Row():
                    audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
                with gr.Row():
                    clone_btn = gr.Button("Clone Voice")
                with gr.Row():
                    output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
                clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
                with gr.Row():
                    voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
                with gr.Row():
                    save_voice_btn = gr.Button("Save Voice")
                save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
                with gr.Row():
                    refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                    voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
                    set_voice_btn = gr.Button("Set Selected Voice")
                refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
                set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
                gr.Markdown("<p style='text-align: center;'>(The selected voice will be used for TTS responses in Chat.)</p>")
                gr.Markdown("<hr>")
                gr.Markdown("<h3 style='text-align: center;'>TTS Test</h3>")
                with gr.Row():
                    tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...", elem_classes="full-width")
                with gr.Row():
                    tts_test_btn = gr.Button("Test TTS")
                    tts_test_output = gr.Audio(label="TTS Output", type="filepath")
                tts_test_btn.click(fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
                                   inputs=[tts_test_input, audio_input, selected_voice],
                                   outputs=tts_test_output)

        # ----- Tab 2: Chat -----
        with gr.Tab("Chat"):
            gr.Markdown("""
<div style="text-align: center; padding: 10px;">
  <h1>DeepSeek Agent Swarm Chat</h1>
  <p>Multi-round agent with prompt chaining. Ask me anything!</p>
</div>
""")
            with gr.Column():
                with gr.Row():
                    mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2, elem_id="msg_input")
                with gr.Row():
                    chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
                    transcribe_btn = gr.Button("Transcribe Audio")
                transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    export_btn = gr.Button("Generate Chat Transcript")
                chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
                with gr.Row():
                    use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
                    chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS", interactive=True)
                    refresh_voice_btn_chat = gr.Button("Refresh Voice List")
                refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
                agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")
                def chat_wrapper(message, history, param_state, prompt_state, mode):
                    # Stream partial agent responses straight to the Chatbot instead of
                    # collecting them and returning only the final state.
                    yield from gradio_interface(message, history, param_state, prompt_state, mode)

                send_btn.click(fn=chat_wrapper,
                               inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
                               outputs=[chatbot])

                def conditional_tts(latest_text, use_tts, selected_voice_val):
                    if use_tts:
                        return generate_agent_audio(latest_text, selected_voice_val)
                    return None

                def get_latest_text(chat_history):
                    for msg in reversed(chat_history):
                        if msg.get("role") == "assistant" and msg.get("content"):
                            return msg["content"]
                    return ""

                latest_text_state = gr.State(value="")
                gen_audio_btn = gr.Button("Generate Audio from Agent Response")
                gen_audio_btn.click(fn=lambda chat: get_latest_text(chat),
                                    inputs=[chatbot],
                                    outputs=latest_text_state)
                gen_audio_btn.click(fn=conditional_tts,
                                    inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown],
                                    outputs=agent_audio)

                def export_transcript(history):
                    # History items arrive in messages format; fall back to pair format just in case.
                    lines = []
                    for item in history:
                        if isinstance(item, dict):
                            lines.append(f"{item.get('role', 'user').capitalize()}: {item.get('content', '')}")
                        elif isinstance(item, (list, tuple)) and len(item) == 2:
                            lines.append(f"User: {item[0]}\nAssistant: {item[1]}")
                    return "\n\n".join(lines)

                # A plain string cannot be rendered into the Chatbot, so show the transcript in a read-only textbox.
                transcript_output = gr.Textbox(label="Chat Transcript", lines=8, interactive=False)
                export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=transcript_output)
        # ----- Tab 3: Project Settings -----
        with gr.Tab("Project Settings"):
            gr.Markdown("<h2 style='text-align: center;'>Project Settings</h2>")
            with gr.Tabs():
                with gr.Tab("Generation Parameters"):
                    gr.Markdown("<h3>Generation Parameters</h3>")
                    with gr.Row():
                        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
                        top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
                    with gr.Row():
                        max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
                        memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
                    with gr.Row():
                        rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
                        num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
                    with gr.Row():
                        show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
                        use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
                    save_params_btn = gr.Button("Save Generation Parameters")

                    def save_params(t, p, m, k, rp, nb, s, use_cpu):
                        global device
                        if use_cpu:
                            device = "cpu"
                        else:
                            device = "cuda" if torch.cuda.is_available() else "cpu"
                        return {
                            "temperature": t,
                            "top_p": p,
                            "max_new_tokens": m,
                            "memory_top_k": k,
                            "repetition_penalty": rp,
                            "num_beams": nb,
                            "show_raw_output": s,
                            "use_cpu": use_cpu
                        }

                    save_params_btn.click(
                        save_params,
                        inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
                        outputs=param_state,
                    )
                    save_params_btn.click(fn=lambda params: f"**Running on:** {device.upper()}", inputs=param_state, outputs=device_status)
                    gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")
with gr.Tab("Prompt Config (Default Modes)"): | |
gr.Markdown("<h3>Prompt Configurations for Default Modes</h3>") | |
with gr.Tabs(): | |
with gr.Tab("Coding"): | |
prompt_brainstorm_box_code = gr.Textbox( | |
value=default_prompts["coding"]["brainstorm"], | |
label="Brainstorm Prompt (Coding)", | |
lines=8, | |
) | |
prompt_round2_box_code = gr.Textbox( | |
value=default_prompts["coding"]["round2"], | |
label="Round 2 Prompt (Coding)", | |
lines=8, | |
) | |
prompt_synthesis_box_code = gr.Textbox( | |
value=default_prompts["coding"]["synthesis"], | |
label="Synthesis Prompt (Coding)", | |
lines=8, | |
) | |
prompt_rationale_box_code = gr.Textbox( | |
value=default_prompts["coding"]["rationale"], | |
label="Rationale Prompt (Coding)", | |
lines=8, | |
) | |
with gr.Tab("Math"): | |
prompt_brainstorm_box_math = gr.Textbox( | |
value=default_prompts["math"]["brainstorm"], | |
label="Brainstorm Prompt (Math)", | |
lines=8, | |
) | |
prompt_round2_box_math = gr.Textbox( | |
value=default_prompts["math"]["round2"], | |
label="Round 2 Prompt (Math)", | |
lines=8, | |
) | |
prompt_synthesis_box_math = gr.Textbox( | |
value=default_prompts["math"]["synthesis"], | |
label="Synthesis Prompt (Math)", | |
lines=8, | |
) | |
prompt_rationale_box_math = gr.Textbox( | |
value=default_prompts["math"]["rationale"], | |
label="Rationale Prompt (Math)", | |
lines=8, | |
) | |
with gr.Tab("Writing"): | |
prompt_brainstorm_box_writing = gr.Textbox( | |
value=default_prompts["writing"]["brainstorm"], | |
label="Brainstorm Prompt (Writing)", | |
lines=8, | |
) | |
prompt_round2_box_writing = gr.Textbox( | |
value=default_prompts["writing"]["round2"], | |
label="Round 2 Prompt (Writing)", | |
lines=8, | |
) | |
prompt_synthesis_box_writing = gr.Textbox( | |
value=default_prompts["writing"]["synthesis"], | |
label="Synthesis Prompt (Writing)", | |
lines=8, | |
) | |
prompt_rationale_box_writing = gr.Textbox( | |
value=default_prompts["writing"]["rationale"], | |
label="Rationale Prompt (Writing)", | |
lines=8, | |
) | |
save_prompts_btn = gr.Button("Save Default Prompt Configurations") | |
def save_default_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat): | |
return { | |
"default": { | |
"coding": { | |
"brainstorm": code_brain, | |
"round2": code_r2, | |
"synthesis": code_syn, | |
"rationale": code_rat, | |
}, | |
"math": { | |
"brainstorm": math_brain, | |
"round2": math_r2, | |
"synthesis": math_syn, | |
"rationale": math_rat, | |
}, | |
"writing": { | |
"brainstorm": writing_brain, | |
"round2": writing_r2, | |
"synthesis": writing_syn, | |
"rationale": writing_rat, | |
} | |
}, | |
"custom": prompt_state.value.get("custom", {}) | |
} | |
save_prompts_btn.click( | |
save_default_prompts, | |
inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code, | |
prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math, | |
prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing], | |
outputs=prompt_state, | |
) | |
with gr.Tab("Custom Modes"): | |
gr.Markdown("<h3>Create / Edit Custom Modes</h3>") | |
gr.Markdown( | |
"Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), " | |
"and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` " | |
"(for the first round) and `{prev_response}` (for subsequent rounds)." | |
) | |
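                    # Example round prompts for a custom mode (illustrative only):
                    #   Round 1: "List the key requirements for: {user_prompt}"
                    #   Round 2: "Draft a solution that satisfies these requirements:\n{prev_response}"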
                    with gr.Row():
                        custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
                        custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
                    custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
                    custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
                    custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
                    custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
                    custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
                    custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
                    custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
                    custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")

                    def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, current_prompt_state):
                        if not name:
                            return gr.update(), current_prompt_state
                        rounds = []
                        round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
                        # Slider values may arrive as floats; cast before using range().
                        for i in range(int(round_count)):
                            if round_prompts[i].strip():
                                rounds.append(round_prompts[i])
                        custom_modes = current_prompt_state.get("custom", {})
                        custom_modes[name] = rounds
                        new_prompt_state = {
                            "default": current_prompt_state.get("default", {}),
                            "custom": custom_modes
                        }
                        return gr.update(value=""), new_prompt_state

                    save_custom_mode_btn = gr.Button("Save Custom Mode")
                    save_custom_mode_btn.click(
                        save_custom_mode,
                        inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3, custom_round4,
                                custom_round5, custom_round6, custom_round7, custom_round8, custom_round9, custom_round10,
                                prompt_state],
                        outputs=[custom_mode_name, prompt_state]
                    )

                    def update_mode_choices(current_prompt_state):
                        default_modes = list(current_prompt_state.get("default", {}).keys())
                        custom_modes = list(current_prompt_state.get("custom", {}).keys())
                        all_modes = default_modes + custom_modes
                        default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
                        return gr.update(choices=all_modes, value=default_choice)

                    refresh_mode_selector_btn = gr.Button("Refresh Mode List")
                    refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)
gr.Markdown("<hr>") | |
gr.Markdown("<p style='text-align: center;'>These settings affect the entire project.</p>") | |
gr.Markdown("<hr><p style='text-align: center;'>Agent Chat using DeepSeek Agent Swarm</p>") | |
if __name__ == "__main__": | |
demo.launch(share=True) |