import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # Hugging Face Spaces package providing the @spaces.GPU decorator for GPU allocation
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator
import time
# NEW: Import whisper for speech-to-text.
import whisper
# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"
# Global device override (will be updated from UI later)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")
# Global dictionary for storing saved voice clones.
voice_bank: Dict[str, str] = {}
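# Illustrative entry (hypothetical path): voice_bank["narrator"] = "/tmp/gradio/.../reference.wav"
# The bank lives only in process memory, so saved voices are lost whenever the Space restarts.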
# ---------------------------
# Simple Response Cache
# ---------------------------
response_cache: Dict[str, str] = {}
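# The cache is keyed on the full prompt plus every sampling parameter (see generate_response below),
# so identical requests skip regeneration. It is unbounded; a size cap or an LRU policy could be
# substituted if memory use becomes a concern.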
# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
@spaces.GPU(enable_queue=True)
def clone(text, audio):
"""
Generate a voice-cloned audio file given text and a reference audio file.
Returns the path to the output audio file.
"""
try:
tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path="./output.wav")
return "./output.wav"
except Exception as e:
logging.error(f"TTS cloning failed: {e}")
return None
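# Illustrative usage (file paths are hypothetical):
#   clone("Hello there!", "samples/reference_voice.wav")  # -> "./output.wav", or None on failure
# Note that ./output.wav is overwritten on every call, so concurrent requests can clobber each other.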
def save_voice(voice_name: str, voice_audio: str) -> None:
"""
Save a cloned voice under the given name.
"""
global voice_bank
if voice_name and voice_audio:
voice_bank[voice_name] = voice_audio
def get_voice_options() -> List[str]:
"""
Returns a list of saved voice names.
"""
return list(voice_bank.keys())
def refresh_voice_list() -> gr.update:
"""
Returns an update with the latest voice list.
"""
options = get_voice_options()
new_val = options[0] if options else ""
return gr.update(choices=options, value=new_val)
# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}
bnb_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
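# NF4 4-bit weights with bfloat16 compute keep the 7B checkpoint at roughly 5-6 GB of GPU memory
# (approximate figure), which is what makes loading it on demand inside a Space practical.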
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
# Warm-up: if the model isn’t loaded, load it now.
if "7B" not in models:
logging.info(f"Loading 7B model: {MODEL_ID} on demand")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=bnb_config_4bit,
torch_dtype=torch.bfloat16,
device_map='auto',
trust_remote_code=True,
)
model.eval()
models["7B"] = model
tokenizers["7B"] = tokenizer
logging.info("Loaded 7B model on demand.")
except Exception as e:
logging.error(f"Failed to load model and tokenizer: {e}")
raise e
return models["7B"], tokenizers["7B"]
# ---------------------------
# Prompt Templates
# ---------------------------
default_prompts = {
"coding": {
"brainstorm": (
"**Round 1: Brainstorm & Analysis**\n"
"Please analyze the following coding challenge or question. Consider the overall problem, "
"potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
"**User Request:**\n{user_prompt}\n"
),
"round2": (
"**Round 2: Detailed Reasoning & Strategy**\n"
"Based on your initial analysis, please break down the problem into logical steps. "
"Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
"**Initial Analysis:**\n{brainstorm_response}\n\n"
"**User Request:**\n{user_prompt}\n"
),
"synthesis": (
"**Round 3: Synthesis & Implementation**\n"
"Taking into account the steps outlined previously, synthesize a coherent solution. "
"Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
"**Detailed Strategy:**\n{round2_response}\n"
),
"rationale": (
"**Round 4: Reflection & Final Output**\n"
"Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
"Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
"**Final Draft:**\n{final_response}\n"
)
},
"math": {
"brainstorm": (
"**Round 1: Problem Analysis & Exploration**\n"
"Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
"Detail your initial reasoning and potential methods to tackle the problem.\n\n"
"**Problem:**\n{user_prompt}\n"
),
"round2": (
"**Round 2: Detailed Reasoning & Methodology**\n"
"Based on your initial exploration, break down the problem into sequential steps or methodologies. "
"Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
"**Initial Analysis:**\n{brainstorm_response}\n\n"
"**Problem:**\n{user_prompt}\n"
),
"synthesis": (
"**Round 3: Synthesis & Step-by-Step Solution**\n"
"Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
"ensuring that your logical progression is easy to follow.\n\n"
"**Detailed Methodology:**\n{round2_response}\n"
),
"rationale": (
"**Round 4: Reflection & Final Explanation**\n"
"Present your final solution along with a detailed explanation of the reasoning behind each step. "
"Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
"**Final Solution:**\n{final_response}\n"
)
},
"writing": {
"brainstorm": (
"**Round 1: Creative Exploration & Conceptualization**\n"
"Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
"Outline your initial thoughts and reasoning behind various creative choices.\n\n"
"**Writing Prompt:**\n{user_prompt}\n"
),
"round2": (
"**Round 2: Detailed Outline & Narrative Structure**\n"
"Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
"Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
"**Initial Brainstorming:**\n{brainstorm_response}\n\n"
"**Writing Prompt:**\n{user_prompt}\n"
),
"synthesis": (
"**Round 3: Draft Synthesis & Refinement**\n"
"Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
"Explain your thought process as you refine the narrative.\n\n"
"**Outline & Strategy:**\n{round2_response}\n"
),
"rationale": (
"**Round 4: Reflection & Final Editing**\n"
"Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
"Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
"**Final Draft:**\n{final_response}\n"
)
}
}
# The prompt state now contains both default and custom modes.
initial_prompt_state = {
"default": default_prompts,
"custom": {} # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}
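# Illustrative custom mode (hypothetical name and prompts), in the shape produced by the "Custom Modes" tab:
#   initial_prompt_state["custom"]["review"] = [
#       "Round 1: critique the following text.\n{user_prompt}",
#       "Round 2: rewrite the text using the critique above.\n{prev_response}",
#   ]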
def detect_domain(user_prompt: str) -> str:
prompt_lower = user_prompt.lower()
math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
if any(kw in prompt_lower for kw in math_keywords):
logging.info("Domain detected as: math")
return "math"
elif any(kw in prompt_lower for kw in writing_keywords):
logging.info("Domain detected as: writing")
return "writing"
elif any(kw in prompt_lower for kw in coding_keywords):
logging.info("Domain detected as: coding")
return "coding"
else:
logging.info("No specific domain detected; defaulting to coding")
return "coding"
class MemoryManager:
def __init__(self) -> None:
self.shared_memory: List[str] = []
def store(self, item: str) -> None:
self.shared_memory.append(item)
logging.info(f"[Memory Stored]: {item[:50]}...")
def retrieve(self, query: str, top_k: int = 3) -> List[str]:
query_lower = query.lower()
relevant = [item for item in self.shared_memory if query_lower in item.lower()]
if not relevant:
logging.info("[Memory Retrieval]: No relevant memories found.")
else:
logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
return relevant[-top_k:]
global_memory_manager = MemoryManager()
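# Retrieval is a case-insensitive substring match over stored strings, returning at most the
# `top_k` most recent hits, e.g.:
#   global_memory_manager.store("Round 4 Response: ...")
#   global_memory_manager.retrieve("Round 4 Response:", top_k=3)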
def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
# Check cache first
cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
if cache_key in response_cache:
logging.info("Returning cached response.")
return response_cache[cache_key]
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
)
thread = Thread(target=model.generate, kwargs=kwargs)
with torch.no_grad():
thread.start()
response = ""
try:
for text in streamer:
response += text
except Exception as e:
logging.error(f"Error during generation: {e}")
raise e
thread.join()
# Cache the response
response_cache[cache_key] = response
return response
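# model.generate() runs in a background thread while TextIteratorStreamer yields decoded tokens here,
# so the caller receives the full text once the stream is drained. The torch.no_grad() wrapper around
# thread.start() only affects this thread (grad mode is thread-local); generate() already disables
# gradients internally, so the wrapper is effectively a no-op.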
class MultiRoundAgent:
def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
"""
prompt_templates can be a dict (for default modes) or a list (for custom modes)
"""
self.model = model
self.tokenizer = tokenizer
self.prompt_templates = prompt_templates
self.memory_manager = memory_manager
def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
if isinstance(self.prompt_templates, dict):
# Default fixed 4-round pipeline
logging.info("--- Round 1 ---")
prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"),
params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
self.memory_manager.store(f"Round 1 Response: {r1}")
logging.info("--- Round 2 ---")
prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100,
params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
self.memory_manager.store(f"Round 2 Response: {r2}")
logging.info("--- Round 3 ---")
prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
kwargs_r3 = dict(
input_ids=input_ids_r3,
streamer=streamer_r3,
max_new_tokens=params.get("max_new_tokens") // 2,
temperature=params.get("temp"),
top_p=params.get("top_p"),
repetition_penalty=params.get("repetition_penalty"),
num_beams=params.get("num_beams")
)
thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
with torch.no_grad():
thread_r3.start()
r3 = ""
try:
for text in streamer_r3:
r3 += text
yield r3 # Progressive updates
except Exception as e:
logging.error(f"Error during Round 3 streaming: {e}")
raise e
thread_r3.join()
self.memory_manager.store(f"Final Synthesis Response: {r3}")
logging.info("--- Round 4 ---")
prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"),
params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
self.memory_manager.store(f"Round 4 Response: {r4}")
final_output = (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n") if show_raw else r4
yield final_output
elif isinstance(self.prompt_templates, list):
# Custom mode: iterate over rounds.
prev_response = ""
full_output = ""
total_rounds = len(self.prompt_templates)
for idx, round_template in enumerate(self.prompt_templates):
round_num = idx + 1
logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
if idx == 0:
prompt = round_template.format(user_prompt=user_prompt)
else:
prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
response = generate_response(self.model, self.tokenizer, prompt, params.get("max_new_tokens"),
params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
full_output += f"\n--- Round {round_num} ---\n{response}"
prev_response = response
yield full_output
else:
yield "Invalid prompt template format."
@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
prompt_templates, domain: str, show_raw: bool, repetition_penalty: float, num_beams: int) -> Generator[str, None, None]:
model, tokenizer = get_model_and_tokenizer()
agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
params = {
"temp": temp,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
"num_beams": num_beams
}
return agent.run_pipeline(user_prompt, params, show_raw)
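# @spaces.GPU(duration=180) requests a GPU allocation of up to 180 seconds for the whole multi-round
# pipeline; very long generations should lower max_new_tokens or the number of rounds to stay inside it.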
def handle_explanation_request(user_prompt: str, history: List) -> str:
retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
if retrieved:
for item in retrieved:
explanation_prompt += f"- {item}\n"
else:
explanation_prompt += "No stored final output found.\n"
explanation_prompt += "\nRecent related exchanges:\n"
for chat in history:
    # History items may be {"role", "content"} dicts (messages-format Chatbot) or legacy [user, assistant] pairs.
    if isinstance(chat, dict):
        if "explain" in str(chat.get("content", "")).lower():
            explanation_prompt += f"{chat.get('role', 'user').capitalize()}: {chat.get('content', '')}\n"
    elif ("explain" in str(chat[0]).lower()) or (chat[1] and "explain" in str(chat[1]).lower()):
        explanation_prompt += f"User: {chat[0]}\nAssistant: {chat[1]}\n"
explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
model, tokenizer = get_model_and_tokenizer()
explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
return explanation
def format_history(history: List) -> List[Dict[str, str]]:
messages = []
for item in history:
if isinstance(item, (list, tuple)) and len(item) == 2:
user_msg, assistant_msg = item
if user_msg == "__final_agent_response__":
continue
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
elif isinstance(item, dict):
messages.append(item)
return messages
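# Converts legacy [user, assistant] pairs into the {"role", "content"} dicts expected by
# gr.Chatbot(type="messages"); entries that are already dicts pass through unchanged.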
def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
if "explain" in message.lower():
explanation = handle_explanation_request(message, history)
history = history + [[message, explanation]]
yield format_history(history)
return
try:
temp = float(param_state.get("temperature", 0.5))
top_p = float(param_state.get("top_p", 0.9))
max_new_tokens = int(param_state.get("max_new_tokens", 300))
repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
num_beams = int(param_state.get("num_beams", 1))
memory_top_k = int(param_state.get("memory_top_k", 2))
show_raw = bool(param_state.get("show_raw_output", False))
except Exception as e:
logging.error(f"Parameter conversion error: {e}")
temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False
if mode in prompt_state.get("default", {}):
prompt_templates = prompt_state["default"][mode]
elif mode in prompt_state.get("custom", {}):
prompt_templates = prompt_state["custom"][mode]
else:
detected = detect_domain(message)
prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
mode = detected
history = history + [[message, ""]]
# Show a loading status
yield format_history(history)
for partial_response in swarm_agent_iterative(
user_prompt=message,
temp=temp,
top_p=top_p,
max_new_tokens=max_new_tokens,
memory_top_k=memory_top_k,
prompt_templates=prompt_templates,
domain=mode,
show_raw=show_raw,
repetition_penalty=repetition_penalty,
num_beams=num_beams
):
history[-1][1] = partial_response
yield format_history(history)
yield format_history(history)
def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
"""
Generate an audio response using the cloned voice.
If the provided voice_reference is a key in the voice bank, its stored file path is used.
"""
if latest_text:
if voice_reference in voice_bank:
audio_path = clone(latest_text, voice_bank[voice_reference])
else:
audio_path = clone(latest_text, voice_reference)
return audio_path
return None
# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
"""
Transcribe the provided audio file to text using the Whisper model.
"""
try:
result = whisper_model.transcribe(audio_file)
transcription = result.get("text", "").strip()
logging.info(f"Transcription result: {transcription}")
return transcription
except Exception as e:
logging.error(f"Transcription error: {e}")
return "Transcription failed."
# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
try:
get_model_and_tokenizer()
logging.info("Model warm-up complete.")
except Exception as e:
logging.error(f"Model warm-up failed: {e}")
warmup_model()
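# Warming up at import time trades a slower Space startup for a fast first chat request;
# comment this call out to defer model loading until the first generation instead.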
# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
primary_hue="pink",
secondary_hue="pink",
neutral_hue="purple",
font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
background_fill_primary='white',
shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
shadow_spread='3px',
block_background_fill='*background_fill_primary',
block_border_width='1px',
block_border_width_dark='1px',
block_label_background_fill='*background_fill_primary',
block_label_background_fill_dark='*background_fill_secondary',
block_label_text_color='*neutral_500',
block_label_text_color_dark='*neutral_200',
block_label_margin='0',
block_label_padding='*spacing_sm *spacing_lg',
block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
block_label_text_size='*text_sm',
block_label_text_weight='400',
block_title_background_fill='none',
block_title_background_fill_dark='none',
block_title_text_color='*neutral_500',
block_title_text_color_dark='*neutral_200',
block_title_padding='0',
block_title_radius='none',
block_title_text_weight='400',
panel_border_width='0',
panel_border_width_dark='0',
checkbox_background_color_selected='*color_accent',
checkbox_background_color_selected_dark='*color_accent',
checkbox_border_color='*neutral_300',
checkbox_border_color_dark='*neutral_700',
checkbox_border_color_focus='*color_accent',
checkbox_border_color_focus_dark='*color_accent',
checkbox_border_color_selected='*color_accent',
checkbox_border_color_selected_dark='*color_accent',
checkbox_border_width='*input_border_width',
checkbox_shadow='*input_shadow',
checkbox_label_background_fill_selected='*checkbox_label_background_fill',
checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
checkbox_label_shadow='none',
checkbox_label_text_color_selected='*checkbox_label_text_color',
input_background_fill='*neutral_100',
input_border_color='*border_color_primary',
input_shadow='none',
input_shadow_dark='none',
input_shadow_focus='*input_shadow',
input_shadow_focus_dark='*input_shadow',
slider_color='*color_accent',
slider_color_dark='*color_accent',
button_primary_background_fill_hover='*primary_600',
button_primary_background_fill_hover_dark='*primary_700',
button_primary_shadow='none',
button_primary_shadow_hover='*button_primary_shadow',
button_primary_shadow_active='*button_primary_shadow',
button_primary_shadow_dark='none',
button_secondary_background_fill='*neutral_200',
button_secondary_background_fill_hover='*neutral_300',
button_secondary_background_fill_hover_dark='*neutral_700',
button_secondary_text_color='black',
button_secondary_shadow='*button_primary_shadow',
button_secondary_shadow_hover='*button_secondary_shadow',
button_secondary_shadow_active='*button_secondary_shadow',
button_secondary_shadow_dark='*button_primary_shadow'
)
# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
# Shared states for project settings, prompt configuration, and voice selection.
param_state = gr.State({
"temperature": 0.5,
"top_p": 0.9,
"max_new_tokens": 300,
"memory_top_k": 2,
"show_raw_output": False,
"repetition_penalty": 1.0,
"num_beams": 1,
"use_cpu": False # Toggle for device override
})
prompt_state = gr.State(initial_prompt_state)
selected_voice = gr.State(value="") # holds the currently selected voice
# A status display to show device info.
device_status = gr.Markdown(f"**Running on:** {device.upper()}")
with gr.Tabs():
# ----- Tab 1: Voice Setup -----
with gr.Tab("Voice Setup"):
gr.Markdown("<h2 style='text-align: center; padding-top: 10px;'>Voice Setup</h2>")
with gr.Column(variant="panel"):
gr.Markdown("<p style='text-align: center;'>Clone a voice and save it with a custom name. Test TTS using your cloned voices.</p>")
with gr.Row():
text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
with gr.Row():
audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
with gr.Row():
clone_btn = gr.Button("Clone Voice")
with gr.Row():
output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
with gr.Row():
voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
with gr.Row():
save_voice_btn = gr.Button("Save Voice")
save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
with gr.Row():
refresh_voice_btn_setup = gr.Button("Refresh Voice List")
voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
set_voice_btn = gr.Button("Set Selected Voice")
refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
gr.Markdown("<p style='text-align: center;'>(The selected voice will be used for TTS responses in Chat.)</p>")
gr.Markdown("<hr>")
gr.Markdown("<h3 style='text-align: center;'>TTS Test</h3>")
with gr.Row():
tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...", elem_classes="full-width")
with gr.Row():
tts_test_btn = gr.Button("Test TTS")
tts_test_output = gr.Audio(label="TTS Output", type="filepath")
tts_test_btn.click(fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
inputs=[tts_test_input, audio_input, selected_voice],
outputs=tts_test_output)
# ----- Tab 2: Chat -----
with gr.Tab("Chat"):
gr.Markdown("""
<div style="text-align: center; padding: 10px;">
<h1>DeepSeek Agent Swarm Chat</h1>
<p>Multi-round agent with prompt chaining. Ask me anything!</p>
</div>
""")
with gr.Column():
with gr.Row():
mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
with gr.Row():
chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2, elem_id="msg_input")
with gr.Row():
chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
transcribe_btn = gr.Button("Transcribe Audio")
transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
with gr.Row():
send_btn = gr.Button("Send", variant="primary")
export_btn = gr.Button("Generate Chat Transcript")
chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
with gr.Row():
use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS", interactive=True)
refresh_voice_btn_chat = gr.Button("Refresh Voice List")
refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")
def chat_wrapper(message, history, param_state, prompt_state, mode):
    # gradio_interface is a generator, so yield its progressive updates and let the Chatbot
    # stream; this also avoids leaving a stale "**Generating response...**" row in the history.
    yield from gradio_interface(message, history, param_state, prompt_state, mode)
send_btn.click(fn=chat_wrapper,
inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
outputs=[chatbot])
def conditional_tts(latest_text, use_tts, selected_voice_val):
if use_tts:
return generate_agent_audio(latest_text, selected_voice_val)
return None
def get_latest_text(chat_history):
for msg in reversed(chat_history):
if msg.get("role") == "assistant" and msg.get("content"):
return msg["content"]
return ""
latest_text_state = gr.State(value="")
gen_audio_btn = gr.Button("Generate Audio from Agent Response")
gen_audio_btn.click(fn=lambda chat: get_latest_text(chat),
inputs=[chatbot],
outputs=latest_text_state)
gen_audio_btn.click(fn=conditional_tts,
inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown],
outputs=agent_audio)
def export_transcript(history):
    # The Chatbot uses the messages format, so entries are {"role", "content"} dicts;
    # legacy [user, assistant] pairs are handled as well. The transcript is appended back
    # to the chat as an assistant message because this button writes to the Chatbot component.
    transcript = ""
    for item in history:
        if isinstance(item, dict):
            transcript += f"{item.get('role', 'user').capitalize()}: {item.get('content', '')}\n\n"
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            transcript += f"User: {item[0]}\nAssistant: {item[1]}\n\n"
    return history + [{"role": "assistant", "content": transcript or "No messages to export yet."}]
export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=chatbot)
# ----- Tab 3: Project Settings -----
with gr.Tab("Project Settings"):
gr.Markdown("<h2 style='text-align: center;'>Project Settings</h2>")
with gr.Tabs():
with gr.Tab("Generation Parameters"):
gr.Markdown("<h3>Generation Parameters</h3>")
with gr.Row():
temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
with gr.Row():
max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
with gr.Row():
rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
with gr.Row():
show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
save_params_btn = gr.Button("Save Generation Parameters")
def save_params(t, p, m, k, rp, nb, s, use_cpu):
global device
if use_cpu:
device = "cpu"
else:
device = "cuda" if torch.cuda.is_available() else "cpu"
return {
"temperature": t,
"top_p": p,
"max_new_tokens": m,
"memory_top_k": k,
"repetition_penalty": rp,
"num_beams": nb,
"show_raw_output": s,
"use_cpu": use_cpu
}
save_params_btn.click(
save_params,
inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
outputs=param_state,
)
save_params_btn.click(fn=lambda params: f"**Running on:** {device.upper()}", inputs=param_state, outputs=device_status)
gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")
with gr.Tab("Prompt Config (Default Modes)"):
gr.Markdown("<h3>Prompt Configurations for Default Modes</h3>")
with gr.Tabs():
with gr.Tab("Coding"):
prompt_brainstorm_box_code = gr.Textbox(
value=default_prompts["coding"]["brainstorm"],
label="Brainstorm Prompt (Coding)",
lines=8,
)
prompt_round2_box_code = gr.Textbox(
value=default_prompts["coding"]["round2"],
label="Round 2 Prompt (Coding)",
lines=8,
)
prompt_synthesis_box_code = gr.Textbox(
value=default_prompts["coding"]["synthesis"],
label="Synthesis Prompt (Coding)",
lines=8,
)
prompt_rationale_box_code = gr.Textbox(
value=default_prompts["coding"]["rationale"],
label="Rationale Prompt (Coding)",
lines=8,
)
with gr.Tab("Math"):
prompt_brainstorm_box_math = gr.Textbox(
value=default_prompts["math"]["brainstorm"],
label="Brainstorm Prompt (Math)",
lines=8,
)
prompt_round2_box_math = gr.Textbox(
value=default_prompts["math"]["round2"],
label="Round 2 Prompt (Math)",
lines=8,
)
prompt_synthesis_box_math = gr.Textbox(
value=default_prompts["math"]["synthesis"],
label="Synthesis Prompt (Math)",
lines=8,
)
prompt_rationale_box_math = gr.Textbox(
value=default_prompts["math"]["rationale"],
label="Rationale Prompt (Math)",
lines=8,
)
with gr.Tab("Writing"):
prompt_brainstorm_box_writing = gr.Textbox(
value=default_prompts["writing"]["brainstorm"],
label="Brainstorm Prompt (Writing)",
lines=8,
)
prompt_round2_box_writing = gr.Textbox(
value=default_prompts["writing"]["round2"],
label="Round 2 Prompt (Writing)",
lines=8,
)
prompt_synthesis_box_writing = gr.Textbox(
value=default_prompts["writing"]["synthesis"],
label="Synthesis Prompt (Writing)",
lines=8,
)
prompt_rationale_box_writing = gr.Textbox(
value=default_prompts["writing"]["rationale"],
label="Rationale Prompt (Writing)",
lines=8,
)
save_prompts_btn = gr.Button("Save Default Prompt Configurations")
def save_default_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat, current_prompt_state):
return {
"default": {
"coding": {
"brainstorm": code_brain,
"round2": code_r2,
"synthesis": code_syn,
"rationale": code_rat,
},
"math": {
"brainstorm": math_brain,
"round2": math_r2,
"synthesis": math_syn,
"rationale": math_rat,
},
"writing": {
"brainstorm": writing_brain,
"round2": writing_r2,
"synthesis": writing_syn,
"rationale": writing_rat,
}
},
"custom": prompt_state.value.get("custom", {})
}
save_prompts_btn.click(
save_default_prompts,
inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing, prompt_state],
outputs=prompt_state,
)
with gr.Tab("Custom Modes"):
gr.Markdown("<h3>Create / Edit Custom Modes</h3>")
gr.Markdown(
"Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), "
"and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` "
"(for the first round) and `{prev_response}` (for subsequent rounds)."
)
with gr.Row():
custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")
def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, current_prompt_state):
if not name:
return gr.update(), current_prompt_state
rounds = []
round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
for i in range(round_count):
if round_prompts[i].strip():
rounds.append(round_prompts[i])
custom_modes = current_prompt_state.get("custom", {})
custom_modes[name] = rounds
new_prompt_state = {
"default": current_prompt_state.get("default", {}),
"custom": custom_modes
}
return gr.update(value=""), new_prompt_state
save_custom_mode_btn = gr.Button("Save Custom Mode")
save_custom_mode_btn.click(
save_custom_mode,
inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3, custom_round4,
custom_round5, custom_round6, custom_round7, custom_round8, custom_round9, custom_round10,
prompt_state],
outputs=[custom_mode_name, prompt_state]
)
def update_mode_choices(current_prompt_state):
default_modes = list(current_prompt_state.get("default", {}).keys())
custom_modes = list(current_prompt_state.get("custom", {}).keys())
all_modes = default_modes + custom_modes
default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
return gr.update(choices=all_modes, value=default_choice)
refresh_mode_selector_btn = gr.Button("Refresh Mode List")
refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)
gr.Markdown("<hr>")
gr.Markdown("<p style='text-align: center;'>These settings affect the entire project.</p>")
gr.Markdown("<hr><p style='text-align: center;'>Agent Chat using DeepSeek Agent Swarm</p>")
if __name__ == "__main__":
    demo.launch(share=True)