from pydantic import BaseModel
from llama_cpp import Llama
import os
import gradio as gr  # Not suitable for production
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import spaces
import asyncio
import random
# from llama_cpp.tokenizers import LlamaTokenizer
from peft import PeftModel, LoraConfig, get_peft_model
import torch
from multiprocessing import Process, Queue
from google.cloud import storage
import json
app = FastAPI()
load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
GOOGLE_CLOUD_BUCKET = os.getenv("GOOGLE_CLOUD_BUCKET")
GOOGLE_CLOUD_CREDENTIALS = os.getenv("GOOGLE_CLOUD_CREDENTIALS")

gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
storage_client = storage.Client.from_service_account_info(gcp_credentials)
bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)
MODEL_NAMES = {
    "starcoder": "starcoder2-3b-q2_k.gguf",
    "gemma_2b_it": "gemma-2-2b-it-q2_k.gguf",
    "llama_3_2_1b": "Llama-3.2-1B.Q2_K.gguf",
    "gemma_2b_imat": "gemma-2-2b-iq1_s-imat.gguf",
    "phi_3_mini": "phi-3-mini-128k-instruct-iq2_xxs-imat.gguf",
    "qwen2_0_5b": "qwen2-0.5b-iq1_s-imat.gguf",
}
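
# Optional startup sanity check - a sketch, not part of the original flow: verify
# that every configured GGUF blob actually exists in the bucket, so a typo in
# MODEL_NAMES fails fast instead of surfacing later as a load error.
def check_model_blobs_exist(names=MODEL_NAMES):
    missing = [path for path in names.values() if not bucket.blob(path).exists()]
    if missing:
        print(f"Warning: GGUF blobs missing from bucket: {missing}")
    return not missing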
class ModelManager:
    def __init__(self):
        # llama.cpp loader settings; generation-time settings (max_tokens,
        # repeat_penalty, stop) are kept separate because they are per-call
        # arguments, not constructor arguments.
        self.params = {"n_ctx": 2048, "n_batch": 512, "n_threads": 1, "seed": -1}
        self.gen_params = {"max_tokens": 512, "repeat_penalty": 1.1, "stop": ["</s>"]}
        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")  # Load from GCS for production
        self.request_queue = Queue()
        self.response_queue = Queue()
        self.models = {}  # Dictionary to hold multiple models
        self.load_models()
        self.start_processing_processes()

    def load_model_from_bucket(self, bucket_path):
        # llama.cpp can only open a model from a file on disk, so the GGUF blob
        # is downloaded to a local path instead of being passed as raw bytes.
        blob = bucket.blob(bucket_path)
        local_path = os.path.join("/tmp", os.path.basename(bucket_path))
        try:
            if not os.path.exists(local_path):
                blob.download_to_filename(local_path)
            model = Llama(model_path=local_path, **self.params)
            return model
        except Exception as e:
            print(f"Error loading model {bucket_path}: {e}")
            return None

    def load_models(self):
        for name, path in MODEL_NAMES.items():
            model = self.load_model_from_bucket(path)
            if model:
                self.models[name] = model
    def save_model_to_bucket(self, model, bucket_path):
        # save_pretrained() writes files to a directory and returns None, so the
        # adapter files are written locally first and then uploaded individually.
        try:
            local_dir = "/tmp/finetuned_adapter"
            model.save_pretrained(local_dir)
            for file_name in os.listdir(local_dir):
                blob = bucket.blob(f"{bucket_path}/{file_name}")
                blob.upload_from_filename(os.path.join(local_dir, file_name))
        except Exception as e:
            print(f"Error saving model: {e}")

    def train_model(self):
        # Placeholder only - this needs a complete overhaul for production use.
        # PEFT/LoRA expects a transformers model and a tokenizer, not a quantized
        # GGUF loaded through llama.cpp (see the commented-out tokenizer above).
        config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
        base_model_path = "llama-2-7b-chat/llama-2-7b-chat.Q4_K_M.gguf"
        try:
            base_model = self.load_model_from_bucket(base_model_path)
            if base_model:
                model = get_peft_model(base_model, config)
                # Placeholder training data - needs a robust data loading mechanism
                for batch in [{"question": ["a"], "answer": ["b"]}, {"question": ["c"], "answer": ["d"]}]:
                    inputs = self.tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True)
                    labels = self.tokenizer(batch["answer"], return_tensors="pt", padding=True, truncation=True)
                    outputs = model(**inputs, labels=labels.input_ids)
                    loss = outputs.loss
                    loss.backward()
                self.save_model_to_bucket(model, "llama_finetuned")
                del model
            del base_model
        except Exception as e:
            print(f"Error during training: {e}")
    def generate_text(self, prompt, model_name, **overrides):
        # llama-cpp-python models are called directly with the prompt string;
        # sampling settings come from self.gen_params unless overridden per request.
        if model_name in self.models:
            model = self.models[model_name]
            params = {**self.gen_params, **overrides}
            output = model(prompt, **params)
            return output["choices"][0]["text"]
        else:
            return "Error: Model not found."

    def start_processing_processes(self):
        # The worker process inherits the loaded models via fork on Linux.
        p = Process(target=self.process_requests)
        p.daemon = True
        p.start()

    def process_requests(self):
        while True:
            request_data = self.request_queue.get()
            if request_data is None:
                break
            inputs, model_name, top_p, top_k, temperature, max_tokens = request_data
            try:
                response = self.generate_text(inputs, model_name, top_p=top_p, top_k=top_k, temperature=temperature, max_tokens=max_tokens)
                self.response_queue.put(response)
            except Exception as e:
                print(f"Error during inference: {e}")
                self.response_queue.put("Error generating text.")
model_manager = ModelManager()

class ChatRequest(BaseModel):
    message: str
    model_name: str

async def generate_streaming_response(inputs, model_name):
    top_p = 0.9
    top_k = 50
    temperature = 0.7
    # Leave room in the context window for the prompt; since no separate
    # tokenizer is loaded, the selected model's own tokenizer is used.
    prompt_tokens = len(model_manager.models[model_name].tokenize(inputs.encode("utf-8")))
    max_tokens = max(model_manager.params["n_ctx"] - prompt_tokens, 1)
    model_manager.request_queue.put((inputs, model_name, top_p, top_k, temperature, max_tokens))
    full_text = model_manager.response_queue.get()

    async def stream_response():
        yield full_text

    return StreamingResponse(stream_response(), media_type="text/plain")
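
# The handler above returns the whole completion as a single chunk. As a minimal
# sketch of true token-by-token streaming (assumptions: the model is called
# in-process here, bypassing the worker queue, and the blocking loop would need
# a thread executor in a real async server), llama-cpp-python supports stream=True:
async def generate_streaming_response_tokens(inputs, model_name):
    model = model_manager.models[model_name]

    async def stream_response():
        # Each chunk carries an incremental piece of text in choices[0]["text"].
        for chunk in model(inputs, max_tokens=512, stream=True):
            yield chunk["choices"][0]["text"]

    return StreamingResponse(stream_response(), media_type="text/plain")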
async def process_message(message, model_name):
    inputs = message.strip()
    return await generate_streaming_response(inputs, model_name)

@app.post("/generate")  # route path is illustrative; the original handler was never registered on the FastAPI app
async def api_generate_multimodel(request: Request):
    data = await request.json()
    message = data["message"]
    model_name = data.get("model_name", list(MODEL_NAMES.keys())[0])
    if model_name not in MODEL_NAMES:
        return {"error": "Invalid model name"}
    return await process_message(message, model_name)

# NOTE: process_message returns a StreamingResponse; for the Gradio UI the generated
# text itself would need to be returned (or yielded) instead.
iface = gr.Interface(fn=process_message, inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model")], outputs=gr.Markdown(), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.")  # gradio is not suitable for production
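
# One possible refinement (a sketch, not enabled here): mount the Gradio UI onto
# the FastAPI app so the API routes and the UI are served by a single ASGI
# process. This assumes a Gradio version that provides gr.mount_gradio_app:
#   app = gr.mount_gradio_app(app, iface, path="/ui")
#   # run with: uvicorn app:app --host 0.0.0.0 --port 7860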
if __name__ == "__main__":
    iface.launch()
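
# Example call against the illustrative /generate route registered above (the
# route path and port are assumptions, not part of the original file):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate",
#       json={"message": "Hello", "model_name": "qwen2_0_5b"},
#   )
#   print(r.text)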