|
import functools
import os

import chromadb
import fastapi
from chromadb.utils import embedding_functions
from fastapi.middleware.cors import CORSMiddleware

from utils import GAMES_DICT
|
# Direct the Hugging Face transformers model cache to a local "cache/" directory.
# Set at import time, before any embedding model is constructed in a request handler.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in favour of
# HF_HOME — confirm against the installed version.
os.environ.update(TRANSFORMERS_CACHE="cache/")
|
|
|
app = fastapi.FastAPI()

# Fully open CORS policy: any origin, any HTTP method and any request header is
# accepted. allow_credentials is not passed, so the middleware's default applies.
# NOTE(review): narrow allow_origins if this API is ever served beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
@app.get("/")
def hello_world():
    """Return a fixed greeting payload from the API root."""
    greeting = "Hello, World!"
    return dict(message=greeting)
|
|
|
@functools.lru_cache(maxsize=1)
def _get_collection():
    """Open the persisted "GameMaster" Chroma collection once and reuse it.

    Building the ``PersistentClient``, the ``all-mpnet-base-v2`` sentence-transformer
    embedding function and the collection handle on every request is wasted work
    (depending on the chromadb version, constructing the embedding function may also
    load the model from disk). ``lru_cache`` keeps the first successfully opened
    collection for the life of the process; if opening raises (e.g. the collection
    does not exist yet), nothing is cached and the next call retries.
    """
    chroma_client = chromadb.PersistentClient(path="Chromadb/")
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )
    return chroma_client.get_collection("GameMaster", embedding_function=embedding_fn)


@app.get("/{name}/{query}")
def hello_world(name: str, query: str):
    """Answer ``query`` with the best-matching stored document for game ``name``.

    Declared with a plain ``def`` so FastAPI runs the blocking Chroma call in its
    threadpool instead of on the event loop.

    Args:
        name: Key into ``GAMES_DICT``; the mapped value is used as the ``source``
            metadata filter for the vector search.
        query: Free text that is embedded and matched against the collection.

    Returns:
        ``{"message": <top document>}`` on a hit, ``{"message": "Game not found"}``
        for an unknown game, or ``{"message": "No answer found"}`` when the
        filtered search returns no documents.
    """
    # NOTE(review): this re-binds the module name ``hello_world`` that the "/" handler
    # also uses. Both routes still work because FastAPI registers each function when
    # its decorator runs, but linters flag it (F811). Renaming would also change the
    # generated OpenAPI operationId, so the name is kept for compatibility.
    if name not in GAMES_DICT:
        return {"message": "Game not found"}

    results = _get_collection().query(
        query_texts=[query],
        # Only the single best match is returned, so fetch just one.
        n_results=1,
        where={"source": GAMES_DICT[name]},
        include=["documents"],
    )

    # One query text was sent, so results["documents"][0] is that query's hit list.
    # It is empty when no stored document matches the ``where`` filter; indexing
    # [0][0] blindly would raise IndexError and surface as an HTTP 500.
    hits = results["documents"][0]
    if not hits:
        return {"message": "No answer found"}
    return {"message": hits[0]}
|
|