Yhhxhfh committed on
Commit c069edf
1 Parent(s): 6838103

Update app.py

Files changed (1)
  1. app.py +16 -26
app.py CHANGED
@@ -1,27 +1,27 @@
+import os
 from pydantic import BaseModel
 from llama_cpp import Llama
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import re
 import gradio as gr
-import os
-import urllib3
-import pickle
-from functools import lru_cache
 from dotenv import load_dotenv
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import JSONResponse
 from tqdm import tqdm
+from functools import lru_cache
+import urllib3
 
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
+# Install the llama-cpp-python library
+os.system("pip install llama-cpp-python")
+
 app = FastAPI()
 load_dotenv()
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
+# Global model configuration
 global_data = {
-    'tokens': {'eos': 'eos_token', 'pad': 'pad_token', 'padding': 'padding_token',
-               'unk': 'unk_token', 'bos': 'bos_token', 'sep': 'sep_token',
-               'cls': 'cls_token', 'mask': 'mask_token'},
     'model_configs': [
         {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "filename": "gpt2-xl-q2_k.gguf", "name": "GPT-2 XL"},
         {"repo_id": "Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "filename": "gemma-2-27b-q2_k.gguf", "name": "Gemma 2-27B"},
@@ -36,14 +36,10 @@ global_data = {
     ]
 }
 
-response_cache = {}
-model_cache_dir = "model_cache"
-os.makedirs(model_cache_dir, exist_ok=True)
-
+# Model loading management
 class ModelManager:
     def __init__(self):
         self.models = {}
-        self.model_cache_dir = model_cache_dir
         self.load_all_models()
 
     def load_all_models(self):
@@ -54,16 +50,9 @@ class ModelManager:
 
     def _load_model(self, model_config):
         model_name = model_config['name']
-        cache_file = os.path.join(self.model_cache_dir, f"{model_name}.pkl")
         if model_name not in self.models:
             try:
-                if os.path.exists(cache_file):
-                    with open(cache_file, "rb") as f:
-                        self.models[model_name] = pickle.load(f)
-                else:
-                    self.models[model_name] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
-                    with open(cache_file, "wb") as f:
-                        pickle.dump(self.models[model_name], f)
+                self.models[model_name] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
             except Exception as e:
                 print(f"Error loading {model_name}: {e}")
                 self.models[model_name] = None
@@ -76,9 +65,11 @@ model_manager = ModelManager()
 class ChatRequest(BaseModel):
     message: str
 
+# Input normalization
 def normalize_input(input_text):
     return input_text.strip()
 
+# Remove duplicates from the response
 def remove_duplicates(text):
     text = re.sub(r'(Hello there, how are you\? \[/INST\]){2,}', 'Hello there, how are you?', text)
     text = re.sub(r'(How are you\? \[/INST\]){2,}', 'How are you?', text)
@@ -92,6 +83,7 @@ def remove_duplicates(text):
         seen_lines.add(line)
     return '\n'.join(unique_lines)
 
+# Model response generation
 @lru_cache(maxsize=128)
 def generate_model_response(model, inputs):
     try:
@@ -101,11 +93,9 @@ def generate_model_response(model, inputs):
         print(f"Error generating response: {e}")
         return f"Error: Could not generate a response. Details: {e}"
 
+# Message processing
 async def process_message(message):
    inputs = normalize_input(message)
-    if inputs in response_cache:
-        return response_cache[inputs]
-
     responses = {}
 
     with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
@@ -114,9 +104,7 @@ async def process_message(message):
             model_name = global_data['model_configs'][i]['name']
             responses[model_name] = future.result()
 
-    formatted_response = "\n\n".join([f"**{model}:**\n{response}" for model, response in responses.items()])
-    response_cache[inputs] = formatted_response
-    return formatted_response
+    return "\n\n".join([f"**{model}:**\n{response}" for model, response in responses.items()])
 
 @app.post("/generate_multimodel")
 async def api_generate_multimodel(request: Request):
@@ -132,6 +120,7 @@ async def api_generate_multimodel(request: Request):
     except Exception as e:
         return JSONResponse({"error": str(e)}, status_code=500)
 
+# Gradio interface
 iface = gr.Interface(
     fn=process_message,
     inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
@@ -141,6 +130,7 @@ iface = gr.Interface(
     live=False
 )
 
+# Launch server
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
     iface.launch(server_port=port)
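
For reference, a minimal sketch of how the updated /generate_multimodel endpoint could be exercised once the Space is running. The local URL and the exact JSON handling inside api_generate_multimodel are assumptions (its body is collapsed in the diff above); the "message" field matches the ChatRequest model shown.

# Minimal usage sketch; host/port and request shape are assumed, not confirmed by the diff.
import requests

resp = requests.post(
    "http://localhost:7860/generate_multimodel",  # hypothetical local URL (default PORT is 7860)
    json={"message": "Hello there, how are you?"},
)
print(resp.json())  # expected: one "**<model name>:**" block per configured model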