File size: 5,318 Bytes
c069edf
6133a63
f8544e9
7fa4c88
05c34a8
9f559e5
fadc2ea
b9d94dc
fadc2ea
6133a63
f8544e9
 
e0f7998
c9eef99
f8544e9
 
c069edf
b9d94dc
c9eef99
e0f7998
7fa4c88
 
b9d94dc
f8544e9
 
 
 
 
 
 
 
 
 
 
 
 
7fa4c88
 
 
e4165c8
4f21ff8
9b9825b
05c34a8
9b9825b
d9f52bb
9b9825b
 
 
 
e4165c8
 
b9d94dc
 
4f21ff8
f8544e9
6133a63
d9f52bb
9fd4d92
6133a63
 
05c34a8
f8544e9
05c34a8
 
7fa4c88
 
 
 
 
f8544e9
f8c3935
 
6133a63
f8c3935
e372f0d
7fa4c88
f8544e9
 
 
 
6133a63
b9d94dc
9533a0b
cf4d675
e4165c8
db2e73b
5cd11f8
f8544e9
 
 
 
 
 
 
5cd11f8
c069edf
b9d94dc
f8544e9
 
 
 
 
 
 
b9d94dc
e033c91
b9d94dc
 
 
fadc2ea
 
db2e73b
b9d94dc
 
fadc2ea
 
b9d94dc
 
9f559e5
506360c
 
 
b9d94dc
506360c
 
b9d94dc
506360c
9f559e5
fc968b1
 
 
 
 
 
 
 
365f24d
 
c3629a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import asyncio
import gc
import io
import logging
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import login
from llama_cpp import Llama
from pydantic import BaseModel
from tqdm import tqdm

load_dotenv()
# Best-effort upgrade of llama-cpp-python at startup. Invoke pip through the
# running interpreter so it targets this environment rather than whatever
# `pip` happens to be first on PATH.
# NOTE: `llama_cpp` has already been imported above, so an upgrade only takes
# effect the next time the process starts.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "llama-cpp-python"],
    check=False,
)

app = FastAPI()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# Only authenticate when a token is configured: login() without a token falls
# back to an interactive prompt, which a headless server cannot answer.
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)

# Shared runtime configuration.
#   model_configs:        one dict per model, holding the Hugging Face `repo_id`
#                         to load and the display `name` used in replies.
#   training_data:        in-memory text buffer read by auto_train_model().
#   auto_train_threshold: number of processed messages between auto-train passes.
global_data = {
    'model_configs': [
        {"repo_id": repo_id, "name": name}
        for repo_id, name in (
            ("Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "GPT-2 XL"),
            ("Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "Gemma 2-27B"),
            ("Ffftdtd5dtft/Phi-3-mini-128k-instruct-Q2_K-GGUF", "Phi-3 Mini 128K Instruct"),
            ("Ffftdtd5dtft/starcoder2-3b-Q2_K-GGUF", "Starcoder2 3B"),
            ("Ffftdtd5dtft/Qwen2-1.5B-Instruct-Q2_K-GGUF", "Qwen2 1.5B Instruct"),
            ("Ffftdtd5dtft/Mistral-Nemo-Instruct-2407-Q2_K-GGUF", "Mistral Nemo Instruct 2407"),
            ("Ffftdtd5dtft/Phi-3-mini-128k-instruct-IQ2_XXS-GGUF", "Phi 3 Mini 128K Instruct XXS"),
            ("Ffftdtd5dtft/TinyLlama-1.1B-Chat-v1.0-IQ1_S-GGUF", "TinyLlama 1.1B Chat"),
            ("Ffftdtd5dtft/Meta-Llama-3.1-8B-Q2_K-GGUF", "Meta Llama 3.1-8B"),
            ("Ffftdtd5dtft/codegemma-2b-IQ1_S-GGUF", "Codegemma 2B"),
        )
    ],
    'training_data': io.StringIO(),
    'auto_train_threshold': 10,
}

class ModelManager:
    """Load every configured GGUF model once and hand models out by display name.

    Models are keyed by the ``name`` field of each entry in
    ``global_data['model_configs']``. A model that fails to load is stored as
    ``None`` (and the failure is logged) so callers can skip it.
    """

    def __init__(self):
        # Display name -> loaded ``Llama`` instance, or ``None`` if loading failed.
        self.models = {}
        self.load_models_once()

    def load_models_once(self):
        """Load all configured models in parallel, unless models are already loaded."""
        if self.models:
            return
        configs = global_data['model_configs']
        if not configs:
            # ThreadPoolExecutor rejects max_workers=0; nothing to load anyway.
            return
        # One worker per model so the downloads/loads overlap.
        with ThreadPoolExecutor(max_workers=len(configs)) as executor:
            futures = [executor.submit(self._load_model, config) for config in configs]
            for future in tqdm(as_completed(futures), total=len(futures), desc="Loading models"):
                future.result()

    def _load_model(self, model_config):
        """Load one model into ``self.models``; on failure log it and record ``None``.

        ``model_config`` must provide ``name`` and ``repo_id``. An optional
        ``filename`` (exact name or glob) selects the GGUF file inside the repo
        and defaults to ``"*.gguf"``.
        """
        model_name = model_config['name']
        if model_name in self.models:
            return
        try:
            # Llama.from_pretrained requires `filename` (a name or glob). Without it
            # the call raises TypeError before downloading anything.
            # NOTE(review): Hub auth is presumably taken from the module-level
            # login(); the old `use_auth_token` kwarg was not used for the download.
            model = Llama.from_pretrained(
                repo_id=model_config['repo_id'],
                filename=model_config.get('filename', '*.gguf'),
            )
        except Exception:
            # Best-effort: one bad model must not abort startup, but say why it failed.
            logging.getLogger(__name__).exception("Failed to load model %r", model_name)
            self.models[model_name] = None
        else:
            self.models[model_name] = model
        finally:
            gc.collect()

    def get_model(self, model_name: str):
        """Return the loaded model for *model_name*, or ``None`` if absent or failed."""
        return self.models.get(model_name)

# Module-level singleton. Constructing it loads every configured model at import
# time, because ModelManager.__init__ calls load_models_once().
model_manager = ModelManager()

class ChatRequest(BaseModel):
    # Request-body schema: a single required `message` string.
    # NOTE(review): not referenced in this file — api_generate_multimodel parses
    # the raw JSON body itself. Confirm whether it is used elsewhere.
    message: str

async def generate_model_response(model, inputs: str, max_tokens: int = 150) -> str:
    """Run one completion on *model* and return the generated text.

    The blocking llama.cpp call runs in the event loop's default executor, so it
    does not freeze the loop and several models can generate concurrently.

    Args:
        model: Callable model; ``model(prompt, max_tokens=...)`` must return a
            dict with ``['choices'][0]['text']``.
        inputs: Prompt text.
        max_tokens: Maximum number of tokens to generate (default 150).

    Returns:
        The generated text. On any exception, returns an
        ``"Error: Could not generate a response. Details: ..."`` string instead
        of raising, so one failing model does not abort the whole request.
    """
    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: model(inputs, max_tokens=max_tokens)
        )
        return response['choices'][0]['text']
    except Exception as e:
        return f"Error: Could not generate a response. Details: {e}"

# Messages processed since the last auto-training pass.
interaction_count = 0

async def process_message(message: str) -> str:
    """Send *message* to every loaded model and return the replies as Markdown.

    Replies appear in ``global_data['model_configs']`` order, each under a
    ``**<name>:**`` header; models that failed to load are skipped. Every
    ``auto_train_threshold`` calls, ``auto_train_model()`` is awaited and the
    counter resets.
    """
    global interaction_count
    inputs = message.strip()

    # Pair each display name with its model's pending reply up front, so every
    # reply is labelled with the model that actually produced it.
    names = []
    pending = []
    for config in global_data['model_configs']:
        model = model_manager.get_model(config['name'])
        if model is not None:
            names.append(config['name'])
            pending.append(generate_model_response(model, inputs))

    # generate_model_response is a coroutine function, so gather the coroutines
    # directly. A concurrent.futures.Future from a thread pool is not awaitable.
    replies = await asyncio.gather(*pending)
    responses = dict(zip(names, replies))

    interaction_count += 1

    if interaction_count >= global_data['auto_train_threshold']:
        await auto_train_model()
        interaction_count = 0

    return "\n\n".join(f"**{model}:**\n{response}" for model, response in responses.items())

async def auto_train_model():
    """Placeholder "auto-training" step, run every ``auto_train_threshold`` messages.

    If the shared ``training_data`` buffer holds any text, it is printed and the
    coroutine pauses for one second; no training actually happens. With an
    empty buffer it returns immediately.
    """
    buffered = global_data['training_data'].getvalue()
    if not buffered:
        return
    print("Auto training model with the following data:")
    print(buffered)
    await asyncio.sleep(1)

@app.post("/generate_multimodel")
# @spaces.GPU()  # Temporarily removed; keep it commented out if it causes errors.
async def api_generate_multimodel(request: Request):
    """Handle ``POST /generate_multimodel`` by fanning a message out to every loaded model.

    Expects a JSON object body with a non-empty string ``message``.

    Returns:
        ``{"response": <markdown>}`` on success, or ``{"error": <details>}``
        with HTTP 500 if generation fails unexpectedly.

    Raises:
        HTTPException: 400 when the body is not valid JSON, or ``message`` is
            missing or empty (including non-object bodies), or is not a string.
    """
    try:
        try:
            data = await request.json()
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
            raise HTTPException(status_code=400, detail="Invalid JSON body") from None
        # A JSON array/scalar body has no .get(); treat it as a missing message.
        message = data.get("message") if isinstance(data, dict) else None
        if not message:
            raise HTTPException(status_code=400, detail="Missing message")
        if not isinstance(message, str):
            raise HTTPException(status_code=400, detail="message must be a string")
        response = await process_message(message)
        return JSONResponse({"response": response})
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# Gradio UI: a two-line textbox in, the combined Markdown replies from
# process_message out. live=False means it runs only on explicit submit.
iface = gr.Interface(
    title="Multi-Model LLM API",
    description="Enter a message and get responses from multiple LLMs.",
    fn=process_message,
    inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
    outputs=gr.Markdown(),
    live=False,
)

@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook; deliberately does nothing.

    Models are already loaded when the module is imported (``model_manager``
    is created at module level), so no startup work remains.
    """

@app.on_event("shutdown")
async def shutdown_event():
    """FastAPI shutdown hook: drop the loaded models, then collect garbage.

    ``gc.collect()`` alone cannot reclaim the models while ``model_manager``
    still references them. Clearing the registry first lets them be freed
    while the interpreter is still fully alive. That presumably releases
    llama.cpp's native memory via the model objects' finalizers.
    """
    model_manager.models.clear()
    gc.collect()

if __name__ == "__main__":
    # Serve the Gradio UI on $PORT, defaulting to 7860.
    # NOTE(review): the FastAPI `app` (and its /generate_multimodel route) is not
    # passed to this launcher, so it is not served by this entry point. Run `app`
    # under an ASGI server if the JSON API is needed.
    iface.launch(server_port=int(os.environ.get("PORT", 7860)))