import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # Hugging Face Spaces package providing the @spaces.GPU decorator (ZeroGPU)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator

# NEW: Import whisper for speech-to-text.
import whisper

# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"
# Global device override (will be updated from UI later)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")

# Global dictionary for storing saved voice clones.
voice_bank: Dict[str, str] = {}

# ---------------------------
# Simple Response Cache
# ---------------------------
response_cache: Dict[str, str] = {}
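# The cache key (built in generate_response) includes the prompt and every
# sampling parameter, so changing temperature, top_p, etc. bypasses the cache.
# It is unbounded and in-memory; restarting the app clears it.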

# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

@spaces.GPU(enable_queue=True)
def clone(text, audio):
    """
    Generate a voice-cloned audio file given text and a reference audio file.
    Returns the path to the output audio file.
    """
    try:
        tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path="./output.wav")
        return "./output.wav"
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        return None
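
# Illustrative call (the reference path is hypothetical):
#   clone("Hello there!", "./reference_voice.wav")  # -> "./output.wav", or None on failure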

def save_voice(voice_name: str, voice_audio: str) -> None:
    """
    Save a cloned voice under the given name.
    """
    global voice_bank
    if voice_name and voice_audio:
        voice_bank[voice_name] = voice_audio

def get_voice_options() -> List[str]:
    """
    Returns a list of saved voice names.
    """
    return list(voice_bank.keys())

def refresh_voice_list() -> dict:
    """
    Returns an update with the latest voice list.
    """
    options = get_voice_options()
    new_val = options[0] if options else None  # None avoids setting a value absent from (empty) choices
    return gr.update(choices=options, value=new_val)
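
# Illustrative flow for the voice-bank helpers above (the file path is hypothetical):
#   save_voice("narrator", "./output.wav")  # register a cloned reference
#   get_voice_options()                     # -> ["narrator"]
#   refresh_voice_list()                    # -> update(choices=["narrator"], value="narrator")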

# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    # Warm-up: if the model isn’t loaded, load it now.
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]
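
# The model is loaded lazily on first use and cached in `models`/`tokenizers`,
# so subsequent calls are cheap; warmup_model() below triggers the load once at startup.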

# ---------------------------
# Prompt Templates
# ---------------------------
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}

# The prompt state now contains both default and custom modes.
initial_prompt_state = {
    "default": default_prompts,
    "custom": {}  # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}
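
# Shape of a custom entry (hypothetical example):
#   initial_prompt_state["custom"]["review"] = [
#       "Summarize the request: {user_prompt}",
#       "Critique the summary above: {prev_response}",
#   ]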

def detect_domain(user_prompt: str) -> str:
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]

    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"

class MemoryManager:
    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]

global_memory_manager = MemoryManager()
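# Illustrative usage (substring match; the most recent hits are returned):
#   global_memory_manager.store("Round 1 Response: ...")
#   global_memory_manager.retrieve("round 1", top_k=2)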

def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
    # Check cache first
    cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
    if cache_key in response_cache:
        logging.info("Returning cached response.")
        return response_cache[cache_key]

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )
    # transformers streamers do not support beam search; force a single beam
    # when streaming so generate() does not raise.
    if num_beams > 1:
        logging.warning("Streaming does not support beam search; forcing num_beams=1.")
        kwargs["num_beams"] = 1
    # Run generation in a worker thread and stream tokens back. A torch.no_grad()
    # context here would not apply to the worker thread (autograd state is
    # thread-local); model.generate disables gradients itself.
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()
    response = ""
    try:
        for text in streamer:
            response += text
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise e
    thread.join()
    # Cache the response
    response_cache[cache_key] = response
    return response
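
# Illustrative call (values are arbitrary; assumes the model has been loaded,
# e.g. via warmup_model()):
#   model, tokenizer = get_model_and_tokenizer()
#   text = generate_response(model, tokenizer, "Explain recursion briefly.",
#                            max_tokens=64, temperature=0.7, top_p=0.9)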

class MultiRoundAgent:
    def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
        """
        prompt_templates can be a dict (for default modes) or a list (for custom modes)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        if isinstance(self.prompt_templates, dict):
            # Default fixed 4-round pipeline
            logging.info("--- Round 1 ---")
            prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
            r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 1 Response: {r1}")

            logging.info("--- Round 2 ---")
            prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
            r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100,
                                   params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 2 Response: {r2}")

            logging.info("--- Round 3 ---")
            prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
            input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
            streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
            kwargs_r3 = dict(
                input_ids=input_ids_r3,
                streamer=streamer_r3,
                max_new_tokens=params.get("max_new_tokens") // 2,
                temperature=params.get("temp"),
                top_p=params.get("top_p"),
                repetition_penalty=params.get("repetition_penalty"),
                num_beams=params.get("num_beams")
            )
            # Streamers do not support beam search; force a single beam here as well.
            if (kwargs_r3.get("num_beams") or 1) > 1:
                kwargs_r3["num_beams"] = 1
            # As in generate_response: no_grad() would not propagate to the
            # worker thread, and generate() already disables gradients.
            thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
            thread_r3.start()
            r3 = ""
            try:
                for text in streamer_r3:
                    r3 += text
                    yield r3  # Progressive updates
            except Exception as e:
                logging.error(f"Error during Round 3 streaming: {e}")
                raise e
            thread_r3.join()
            self.memory_manager.store(f"Final Synthesis Response: {r3}")

            logging.info("--- Round 4 ---")
            prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
            r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 4 Response: {r4}")

            final_output = (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n") if show_raw else r4
            yield final_output

        elif isinstance(self.prompt_templates, list):
            # Custom mode: iterate over rounds.
            prev_response = ""
            full_output = ""
            total_rounds = len(self.prompt_templates)
            for idx, round_template in enumerate(self.prompt_templates):
                round_num = idx + 1
                logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
                if idx == 0:
                    prompt = round_template.format(user_prompt=user_prompt)
                else:
                    prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
                response = generate_response(self.model, self.tokenizer, prompt, params.get("max_new_tokens"),
                                             params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
                self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
                full_output += f"\n--- Round {round_num} ---\n{response}"
                prev_response = response
                yield full_output
        else:
            yield "Invalid prompt template format."

@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float, num_beams: int) -> Generator[str, None, None]:
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams
    }
    # Use `yield from` so this is itself a generator function; wrappers that
    # special-case generators (such as spaces.GPU) then track the full stream.
    yield from agent.run_pipeline(user_prompt, params, show_raw)

def handle_explanation_request(user_prompt: str, history: List) -> str:
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
    if retrieved:
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
    else:
        explanation_prompt += "No stored final output found.\n"
    explanation_prompt += "\nRecent related exchanges:\n"
    # History may contain dict messages (Chatbot "messages" format) or legacy
    # [user, assistant] pairs; indexing a dict with chat[0] would raise, so
    # handle both shapes explicitly.
    for chat in history:
        if isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_msg, assistant_msg = chat[0] or "", chat[1] or ""
            if "explain" in user_msg.lower() or "explain" in assistant_msg.lower():
                explanation_prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
        elif isinstance(chat, dict) and "explain" in str(chat.get("content", "")).lower():
            explanation_prompt += f"{chat.get('role', 'user').capitalize()}: {chat.get('content', '')}\n"
    explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
    model, tokenizer = get_model_and_tokenizer()
    explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
    return explanation

def format_history(history: List) -> List[Dict[str, str]]:
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg == "__final_agent_response__":
                continue
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages
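
# format_history accepts mixed entries and normalizes them to the Chatbot
# "messages" format, e.g.:
#   [["hi", "hello"], {"role": "user", "content": "more"}]
#   -> [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"},
#       {"role": "user", "content": "more"}]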

def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]
    # Yield once immediately so the user's message (with an empty assistant slot) appears while generation runs.
    yield format_history(history)
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=mode,
        show_raw=show_raw,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams
    ):
        history[-1][1] = partial_response
        yield format_history(history)
    yield format_history(history)

def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
    """
    Generate an audio response using the cloned voice.
    If the provided voice_reference is a key in the voice bank, its stored file path is used.
    """
    if latest_text:
        if voice_reference in voice_bank:
            audio_path = clone(latest_text, voice_bank[voice_reference])
        else:
            audio_path = clone(latest_text, voice_reference)
        return audio_path
    return None

# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
    """
    Transcribe the provided audio file to text using the Whisper model.
    """
    try:
        result = whisper_model.transcribe(audio_file)
        transcription = result.get("text", "").strip()
        logging.info(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."

# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    try:
        get_model_and_tokenizer()
        logging.info("Model warm-up complete.")
    except Exception as e:
        logging.error(f"Model warm-up failed: {e}")

warmup_model()

# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)

# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice

    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("<h2 style='text-align: center; padding-top: 10px;'>Voice Setup</h2>")
            with gr.Column(variant="panel"):
                gr.Markdown("<p style='text-align: center;'>Clone a voice and save it with a custom name. Test TTS using your cloned voices.</p>")
                with gr.Row():
                    text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
                with gr.Row():
                    audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
                with gr.Row():
                    clone_btn = gr.Button("Clone Voice")
                with gr.Row():
                    output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
                clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
                with gr.Row():
                    voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
                with gr.Row():
                    save_voice_btn = gr.Button("Save Voice")
                save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
                with gr.Row():
                    refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                    voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
                    set_voice_btn = gr.Button("Set Selected Voice")
                refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
                set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
                gr.Markdown("<p style='text-align: center;'>(The selected voice will be used for TTS responses in Chat.)</p>")
                gr.Markdown("<hr>")
                gr.Markdown("<h3 style='text-align: center;'>TTS Test</h3>")
                with gr.Row():
                    tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...", elem_classes="full-width")
                with gr.Row():
                    tts_test_btn = gr.Button("Test TTS")
                    tts_test_output = gr.Audio(label="TTS Output", type="filepath")
                tts_test_btn.click(fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
                                   inputs=[tts_test_input, audio_input, selected_voice],
                                   outputs=tts_test_output)

        # ----- Tab 2: Chat -----
        with gr.Tab("Chat"):
            gr.Markdown("""
                <div style="text-align: center; padding: 10px;">
                  <h1>DeepSeek Agent Swarm Chat</h1>
                  <p>Multi-round agent with prompt chaining. Ask me anything!</p>
                </div>
            """)
            with gr.Column():
                with gr.Row():
                    mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2, elem_id="msg_input")
                with gr.Row():
                    chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
                    transcribe_btn = gr.Button("Transcribe Audio")
                transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                export_btn = gr.Button("Generate Chat Transcript")
                chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
                with gr.Row():
                    use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
                    chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS", interactive=True)
                    refresh_voice_btn_chat = gr.Button("Refresh Voice List")
                refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
                agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")

                def chat_wrapper(message, history, param_state, prompt_state, mode):
                    # gradio_interface is a generator; yield each partial history
                    # so the Chatbot streams progressive updates instead of only
                    # showing the final response.
                    yield from gradio_interface(message, history, param_state, prompt_state, mode)

                send_btn.click(fn=chat_wrapper,
                               inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
                               outputs=[chatbot])

                def conditional_tts(latest_text, use_tts, selected_voice_val):
                    if use_tts:
                        return generate_agent_audio(latest_text, selected_voice_val)
                    return None

                def get_latest_text(chat_history):
                    for msg in reversed(chat_history):
                        if msg.get("role") == "assistant" and msg.get("content"):
                            return msg["content"]
                    return ""
                latest_text_state = gr.State(value="")
                gen_audio_btn = gr.Button("Generate Audio from Agent Response")
                # Chain the two steps with .then() so TTS runs only after the
                # latest assistant text has been extracted; two independent
                # .click() handlers are not guaranteed to run in order.
                gen_audio_btn.click(fn=get_latest_text,
                                    inputs=[chatbot],
                                    outputs=latest_text_state).then(
                    fn=conditional_tts,
                    inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown],
                    outputs=agent_audio)
                def export_transcript(history):
                    # History is in the Chatbot "messages" format (dicts); also
                    # accept legacy [user, assistant] pairs.
                    lines = []
                    for item in history:
                        if isinstance(item, dict):
                            lines.append(f"{item.get('role', 'user').capitalize()}: {item.get('content', '')}")
                        elif isinstance(item, (list, tuple)) and len(item) == 2:
                            lines.append(f"User: {item[0]}\nAssistant: {item[1]}")
                    # Return the transcript as an assistant message so the Chatbot
                    # (the declared output) can render it.
                    return history + [{"role": "assistant", "content": "\n\n".join(lines)}]
                export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=chatbot)

        # ----- Tab 3: Project Settings -----
        with gr.Tab("Project Settings"):
            gr.Markdown("<h2 style='text-align: center;'>Project Settings</h2>")
            with gr.Tabs():
                with gr.Tab("Generation Parameters"):
                    gr.Markdown("<h3>Generation Parameters</h3>")
                    with gr.Row():
                        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
                        top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
                    with gr.Row():
                        max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
                        memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
                    with gr.Row():
                        rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
                        num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
                    with gr.Row():
                        show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
                        use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
                    save_params_btn = gr.Button("Save Generation Parameters")
                    def save_params(t, p, m, k, rp, nb, s, use_cpu):
                        global device
                        if use_cpu:
                            device = "cpu"
                        else:
                            device = "cuda" if torch.cuda.is_available() else "cpu"
                        return {
                            "temperature": t,
                            "top_p": p,
                            "max_new_tokens": m,
                            "memory_top_k": k,
                            "repetition_penalty": rp,
                            "num_beams": nb,
                            "show_raw_output": s,
                            "use_cpu": use_cpu
                        }
                    # Chain the status update with .then() so it reads the device
                    # value *after* save_params has applied the CPU override.
                    save_params_btn.click(
                        save_params,
                        inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
                        outputs=param_state,
                    ).then(fn=lambda: f"**Running on:** {device.upper()}", outputs=device_status)
                    gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")
                with gr.Tab("Prompt Config (Default Modes)"):
                    gr.Markdown("<h3>Prompt Configurations for Default Modes</h3>")
                    with gr.Tabs():
                        with gr.Tab("Coding"):
                            prompt_brainstorm_box_code = gr.Textbox(
                                value=default_prompts["coding"]["brainstorm"],
                                label="Brainstorm Prompt (Coding)",
                                lines=8,
                            )
                            prompt_round2_box_code = gr.Textbox(
                                value=default_prompts["coding"]["round2"],
                                label="Round 2 Prompt (Coding)",
                                lines=8,
                            )
                            prompt_synthesis_box_code = gr.Textbox(
                                value=default_prompts["coding"]["synthesis"],
                                label="Synthesis Prompt (Coding)",
                                lines=8,
                            )
                            prompt_rationale_box_code = gr.Textbox(
                                value=default_prompts["coding"]["rationale"],
                                label="Rationale Prompt (Coding)",
                                lines=8,
                            )
                        with gr.Tab("Math"):
                            prompt_brainstorm_box_math = gr.Textbox(
                                value=default_prompts["math"]["brainstorm"],
                                label="Brainstorm Prompt (Math)",
                                lines=8,
                            )
                            prompt_round2_box_math = gr.Textbox(
                                value=default_prompts["math"]["round2"],
                                label="Round 2 Prompt (Math)",
                                lines=8,
                            )
                            prompt_synthesis_box_math = gr.Textbox(
                                value=default_prompts["math"]["synthesis"],
                                label="Synthesis Prompt (Math)",
                                lines=8,
                            )
                            prompt_rationale_box_math = gr.Textbox(
                                value=default_prompts["math"]["rationale"],
                                label="Rationale Prompt (Math)",
                                lines=8,
                            )
                        with gr.Tab("Writing"):
                            prompt_brainstorm_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["brainstorm"],
                                label="Brainstorm Prompt (Writing)",
                                lines=8,
                            )
                            prompt_round2_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["round2"],
                                label="Round 2 Prompt (Writing)",
                                lines=8,
                            )
                            prompt_synthesis_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["synthesis"],
                                label="Synthesis Prompt (Writing)",
                                lines=8,
                            )
                            prompt_rationale_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["rationale"],
                                label="Rationale Prompt (Writing)",
                                lines=8,
                            )
                    save_prompts_btn = gr.Button("Save Default Prompt Configurations")
                    def save_default_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat, current_prompt_state):
                        return {
                            "default": {
                                "coding": {
                                    "brainstorm": code_brain,
                                    "round2": code_r2,
                                    "synthesis": code_syn,
                                    "rationale": code_rat,
                                },
                                "math": {
                                    "brainstorm": math_brain,
                                    "round2": math_r2,
                                    "synthesis": math_syn,
                                    "rationale": math_rat,
                                },
                                "writing": {
                                    "brainstorm": writing_brain,
                                    "round2": writing_r2,
                                    "synthesis": writing_syn,
                                    "rationale": writing_rat,
                                }
                            },
                            "custom": prompt_state.value.get("custom", {})
                        }
                    save_prompts_btn.click(
                        save_default_prompts,
                        inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
                                prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
                                prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing, prompt_state],
                        outputs=prompt_state,
                    )
                with gr.Tab("Custom Modes"):
                    gr.Markdown("<h3>Create / Edit Custom Modes</h3>")
                    gr.Markdown(
                        "Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), "
                        "and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` "
                        "(for the first round) and `{prev_response}` (for subsequent rounds)."
                    )
                    with gr.Row():
                        custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
                        custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
                    custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
                    custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
                    custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
                    custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
                    custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
                    custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
                    custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
                    custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")

                    def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, current_prompt_state):
                        if not name:
                            return gr.update(), current_prompt_state
                        rounds = []
                        round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
                        for i in range(int(round_count)):  # slider values may arrive as floats
                            if round_prompts[i].strip():
                                rounds.append(round_prompts[i])
                        custom_modes = current_prompt_state.get("custom", {})
                        custom_modes[name] = rounds
                        new_prompt_state = {
                            "default": current_prompt_state.get("default", {}),
                            "custom": custom_modes
                        }
                        return gr.update(value=""), new_prompt_state

                    save_custom_mode_btn = gr.Button("Save Custom Mode")
                    save_custom_mode_btn.click(
                        save_custom_mode,
                        inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3, custom_round4,
                                custom_round5, custom_round6, custom_round7, custom_round8, custom_round9, custom_round10,
                                prompt_state],
                        outputs=[custom_mode_name, prompt_state]
                    )

                    def update_mode_choices(current_prompt_state):
                        default_modes = list(current_prompt_state.get("default", {}).keys())
                        custom_modes = list(current_prompt_state.get("custom", {}).keys())
                        all_modes = default_modes + custom_modes
                        default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
                        return gr.update(choices=all_modes, value=default_choice)

                    refresh_mode_selector_btn = gr.Button("Refresh Mode List")
                    refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)

            gr.Markdown("<hr>")
            gr.Markdown("<p style='text-align: center;'>These settings affect the entire project.</p>")

    gr.Markdown("<hr><p style='text-align: center;'>Agent Chat using DeepSeek Agent Swarm</p>")

if __name__ == "__main__":
    demo.launch(share=True)