Update main.py
main.py CHANGED
@@ -1,61 +1,94 @@
-import
-from
+from fastapi import FastAPI, HTTPException
+from typing import Any
 from pydantic import BaseModel
-from
+from os import getenv
+from huggingface_hub import InferenceClient
+import random
+from json_repair import repair_json
+import nltk

-# Pydantic object for request validation
-class Validation(BaseModel):
-    inputs: str
-    temperature: float = 0.0
-    max_new_tokens: int = 1048
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
-
-# Initialize FastAPI app
 app = FastAPI()

-
-
-
-
-
+nltk.download('punkt')
+
+tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+
+HF_TOKEN = getenv("HF_TOKEN")
+MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
+FALLBACK_MODELS = [
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
+]
+
+class InputData(BaseModel):
+    model: str
+    system_prompt_template: str
+    prompt_template: str
+    system_prompt: str
+    user_input: str
+    json_prompt: str
+    history: str = ""
+
+@app.post("/generate-response/")
+async def generate_response(data: InputData) -> Any:
+    client = InferenceClient(model=data.model, token=HF_TOKEN)
+
+    sentences = tokenizer.tokenize(data.user_input)
+    data_dict = {'###New response###': [], '###Sentence count###': 0}
+    for i, sentence in enumerate(sentences):
+        data_dict["###New response###"].append(sentence)
+        data_dict["###Sentence count###"] = i + 1
+
+    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

-
+    inputs = (
+        data.system_prompt_template.replace("{SystemPrompt}",
+                                            data.system_prompt) +
+        data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
+        data.history)

-
-    raise RuntimeError("No models found in the models directory")
+    seed = random.randint(0, 2**32 - 1)

-
+    models_to_try = [data.model] + FALLBACK_MODELS

-
-for model_name in model_dirs:
-    model_path = os.path.join(model_base_path, model_name)
-    try:
-
-
-    except Exception as e:
-        print(f"
-
-
-# Function to get model dependency
-def get_model(model_name: str):
-    if model_name not in models:
-        raise HTTPException(status_code=404, detail="Model not found")
-    return models[model_name]
-
-# Create an endpoint for each model
-for model_name in model_dirs:
-    @app.post(f"/{model_name}")
-    async def generate_response(item: Validation, model=Depends(lambda: get_model(model_name))):
-        try:
-            response = model(item.inputs,
-                             temperature=item.temperature,
-                             max_new_tokens=item.max_new_tokens,
-                             top_p=item.top_p,
-                             repetition_penalty=item.repetition_penalty)
-            return response
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
-
-# Setup endpoints
-setup_endpoints(app)
+    for model in models_to_try:
+        try:
+            response = client.text_generation(inputs,
+                                              temperature=1.0,
+                                              max_new_tokens=1000,
+                                              seed=seed)
+
+            strict_response = str(response)
+
+            repaired_response = repair_json(strict_response,
+                                            return_objects=True)
+
+            if isinstance(repaired_response, str):
+                raise HTTPException(status_code=500, detail="Invalid response from model")
+            else:
+                cleaned_response = {}
+                for key, value in repaired_response.items():
+                    cleaned_key = key.replace("###", "")
+                    cleaned_response[cleaned_key] = value
+
+                for i, text in enumerate(cleaned_response["New response"]):
+                    if i <= 2:
+                        sentences = tokenizer.tokenize(text)
+                        if sentences:
+                            cleaned_response["New response"][i] = sentences[0]
+                    else:
+                        del cleaned_response["New response"][i]
+                if cleaned_response.get("Sentence count"):
+                    if cleaned_response["Sentence count"] > 3:
+                        cleaned_response["Sentence count"] = 3
+                else:
+                    cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+
+                data.history += str(cleaned_response)
+
+                return cleaned_response
+
+        except Exception as e:
+            print(f"Model {model} failed with error: {e}")
+
+    raise HTTPException(status_code=500, detail="All models failed to generate response")
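A note on the retry loop in the new code: `client` is created once from `data.model` before the loop, so every pass of `for model in models_to_try:` re-sends the request to the same endpoint; only the error message mentions the fallback name. If switching models on failure is the intent, a minimal sketch of that variant (my assumption about the intent, not what this commit does) would bind the client per candidate:

# Hypothetical variant, not part of this commit: construct the client
# inside the loop so each retry actually targets the next candidate model.
from fastapi import HTTPException
from huggingface_hub import InferenceClient

def generate_with_fallback(inputs: str, models_to_try: list[str],
                           seed: int, token: str) -> str:
    for model in models_to_try:
        client = InferenceClient(model=model, token=token)  # re-bound per model
        try:
            return client.text_generation(inputs,
                                          temperature=1.0,
                                          max_new_tokens=1000,
                                          seed=seed)
        except Exception as e:
            print(f"Model {model} failed with error: {e}")
    raise HTTPException(status_code=500, detail="All models failed to generate response")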
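For context on the `repair_json` check: with `return_objects=True` the function returns the parsed Python object when it can make sense of the text, and a plain string otherwise, which is what the `isinstance(repaired_response, str)` branch treats as failure. A standalone illustration of that contract (inputs are illustrative):

from json_repair import repair_json

# A near-JSON model reply is repaired and parsed into a dict
# (here the missing closing brace gets fixed).
obj = repair_json('{"###New response###": ["Hi there."], "###Sentence count###": 1',
                  return_objects=True)
print(type(obj))   # <class 'dict'>

# Free text that cannot be interpreted as a JSON object comes back
# as a str, which the endpoint uses as its failure signal.
bad = repair_json("no json here", return_objects=True)
print(type(bad))   # <class 'str'>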
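Finally, a minimal sketch of calling the new endpoint once the app is served (assuming the usual `uvicorn main:app` on port 8000; every payload value below is an illustrative placeholder, including the Mistral-style templates):

import requests

payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "system_prompt_template": "<s>[INST] {SystemPrompt} [/INST]",  # placeholder template
    "prompt_template": "[INST] {Prompt} [/INST]",                  # placeholder template
    "system_prompt": "You are a terse assistant.",
    "user_input": "Tell me about Paris. Where should I start?",
    "json_prompt": "Answer as JSON with keys ###New response### and ###Sentence count###.",
    "history": "",
}

r = requests.post("http://localhost:8000/generate-response/", json=payload)
print(r.json())   # e.g. {"New response": [...], "Sentence count": 2}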