# New-Place / main.py
import random
import re
from os import getenv
from typing import Any, Dict

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from json_repair import repair_json
from pydantic import BaseModel
from word_forms.word_forms import get_word_forms

app = FastAPI()

# Download the Punkt sentence-tokenizer data and load the English model once at startup.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""
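
# Illustrative request body for /generate-response/; every value below is a
# hypothetical example of the caller-supplied templates, not a required format:
# {
#   "model": "HuggingFaceH4/zephyr-7b-beta",
#   "system_prompt_template": "<|system|>\n{SystemPrompt}</s>\n",
#   "prompt_template": "<|user|>\n{Prompt}</s>\n<|assistant|>\n",
#   "end_token": "</s>",
#   "system_prompt": "You are a friendly conversation partner.",
#   "user_input": "Hello there. How are you?",
#   "json_prompt": "Answer as JSON with ###New response### and ###Sentence count### keys.",
#   "history": ""
# }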


class WordCheckData(BaseModel):
    string: str
    word: str
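
# Illustrative request body for /check-word/ (values are hypothetical):
# {"string": "The cats are running.", "word": "run"}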
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
client = InferenceClient(model=data.model, token=HF_TOKEN)
sentences = tokenizer.tokenize(data.user_input)
data_dict = {'###New response###': [], '###Sentence count###': 0}
for i, sentence in enumerate(sentences):
data_dict["###New response###"].append(sentence)
data_dict["###Sentence count###"] = i + 1
data.history += data.prompt_template.replace("{Prompt}", str(data_dict))
inputs = (
data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
data.history
)
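
    # For example, with system_prompt_template = "<|system|>\n{SystemPrompt}</s>\n"
    # and prompt_template = "<|user|>\n{Prompt}</s>\n<|assistant|>\n" (hypothetical
    # chat formats; the real templates are supplied by the caller), `inputs` becomes
    # one flat prompt string that is sent to text_generation below.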

    # A fresh random seed per request keeps sampling non-deterministic across calls.
    seed = random.randint(0, 2**32 - 1)

    try:
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed,
        )

        # The model is asked to answer in JSON; repair_json fixes minor syntax
        # errors and returns a Python object, or a plain string when the output
        # is not recoverable as JSON.
        repaired_response = repair_json(str(response), return_objects=True)
        if isinstance(repaired_response, str):
            raise HTTPException(status_code=500, detail="Invalid response from model")

        # Strip the '###' markers from the model's keys.
        cleaned_response = {key.replace("###", ""): value for key, value in repaired_response.items()}
strings = ""
for i, text in enumerate(cleaned_response["New response"]):
if i != len(cleaned_response["New response"]) - 1:
strings += text + " "
else:
strings += text
sentences = tokenizer.tokenize(strings)
cleaned_response["New response"] = sentences

        # Honor the model's reported sentence count but cap it at 3;
        # if it is missing, fall back to the actual number of sentences.
        if cleaned_response.get("Sentence count"):
            if cleaned_response["Sentence count"] > 3:
                cleaned_response["Sentence count"] = 3
        else:
            cleaned_response["Sentence count"] = len(cleaned_response["New response"])

        data.history += str(cleaned_response)
        return {
            "response": cleaned_response,
            "history": data.history + data.end_token,
        }
    except HTTPException:
        # Let the explicit 500 raised above propagate unchanged instead of
        # being swallowed by the generic handler below.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")
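
# Illustrative call (host, port, and payload file are deployment-specific
# assumptions, not part of the original file):
#   curl -X POST http://localhost:7860/generate-response/ \
#        -H "Content-Type: application/json" -d @payload.json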
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
input_string = data.string.lower()
word = data.word.lower()
forms = get_word_forms(word)
all_forms = set()
for words in forms.values():
all_forms.update(words)
# Split the input string into words using regular expression to handle spaces and punctuation
words_in_string = re.findall(r'\b\w+\b', input_string)
found = False
for word_in_string in words_in_string:
if word_in_string in all_forms:
found = True
break
result = {
"found": found
}
return result
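

# A minimal sketch for running the app locally, assuming it is served with
# uvicorn (HF Spaces conventionally expose port 7860; both details are
# assumptions, not part of the original file):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)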