Yhhxhfh committed
Commit 6133a63
Parent: c069edf

Update app.py

Files changed (1): app.py (+12, -32)
app.py CHANGED
@@ -1,26 +1,24 @@
 import os
-from pydantic import BaseModel
+import gc
+import tempfile
 from llama_cpp import Llama
 from concurrent.futures import ThreadPoolExecutor, as_completed
-import re
 import gradio as gr
-from dotenv import load_dotenv
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import JSONResponse
 from tqdm import tqdm
+from dotenv import load_dotenv
 from functools import lru_cache
 import urllib3
 
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
-# Install the llama-cpp-python library
 os.system("pip install llama-cpp-python")
 
 app = FastAPI()
 load_dotenv()
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
-# Global model configuration
 global_data = {
     'model_configs': [
         {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "filename": "gpt2-xl-q2_k.gguf", "name": "GPT-2 XL"},
@@ -36,7 +34,6 @@ global_data = {
     ]
 }
 
-# Model loading manager
 class ModelManager:
     def __init__(self):
         self.models = {}
@@ -52,10 +49,16 @@ class ModelManager:
         model_name = model_config['name']
         if model_name not in self.models:
             try:
-                self.models[model_name] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
+                tempdir = tempfile.TemporaryDirectory()
+                filepath = os.path.join(tempdir.name, model_config['filename'])
+                model = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
+                self.models[model_name] = model
+                model.model.model_path = filepath
             except Exception as e:
                 print(f"Error loading {model_name}: {e}")
                 self.models[model_name] = None
+            finally:
+                gc.collect()
 
     def get_model(self, model_name):
         return self.models.get(model_name)
@@ -65,37 +68,16 @@ model_manager = ModelManager()
 class ChatRequest(BaseModel):
     message: str
 
-# Input normalization
-def normalize_input(input_text):
-    return input_text.strip()
-
-# Remove duplicates from the response
-def remove_duplicates(text):
-    text = re.sub(r'(Hello there, how are you\? \[/INST\]){2,}', 'Hello there, how are you?', text)
-    text = re.sub(r'(How are you\? \[/INST\]){2,}', 'How are you?', text)
-    text = text.replace('[/INST]', '')
-    lines = text.split('\n')
-    unique_lines = []
-    seen_lines = set()
-    for line in lines:
-        if line not in seen_lines:
-            unique_lines.append(line)
-            seen_lines.add(line)
-    return '\n'.join(unique_lines)
-
-# Model response generation
 @lru_cache(maxsize=128)
 def generate_model_response(model, inputs):
     try:
         response = model(inputs, max_tokens=150)
-        return remove_duplicates(response['choices'][0]['text'])
+        return response['choices'][0]['text']
     except Exception as e:
-        print(f"Error generating response: {e}")
         return f"Error: Could not generate a response. Details: {e}"
 
-# Message processing
 async def process_message(message):
-    inputs = normalize_input(message)
+    inputs = message.strip()
     responses = {}
 
     with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
@@ -120,7 +102,6 @@ async def api_generate_multimodel(request: Request):
     except Exception as e:
         return JSONResponse({"error": str(e)}, status_code=500)
 
-# Gradio interface
 iface = gr.Interface(
     fn=process_message,
     inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
@@ -130,7 +111,6 @@ iface = gr.Interface(
     live=False
 )
 
-# Launch the server
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
     iface.launch(server_port=port)