# Uploaded by alexpantex: scripts/api.py (via huggingface_hub), commit 37a1625 (verified).
import sys
sys.path.append(sys.path[0].replace('scripts', ''))
import os
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from scripts.run import config, search_engine
from scripts.preprocess import preprocess_text
# Application instance. The title, description and version are surfaced in
# the automatically generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.0.0",
    title="Prompt Search API",
    description="A RESTful API to find top-n most similar prompts.",
)
class QueryRequest(BaseModel):
    """Request body for the /search endpoint."""

    # Free-text prompt to search for; blank values are rejected by the endpoint.
    query: str
    # How many similar prompts to return; the endpoint rejects values <= 0.
    n_results: int = 5
class SimilarQuery(BaseModel):
    """One search hit: a stored prompt together with its similarity score."""

    # Text of the matching prompt.
    prompt: str
    # Similarity score as returned by the search engine (converted to float).
    score: float
class QueryResponse(BaseModel):
    """Response body for the /search endpoint."""

    # The query exactly as the client sent it (before preprocessing).
    query: str
    # Matching prompts in the order the search engine returned them.
    similar_queries: List[SimilarQuery]
@app.get("/")
def root():
    """Return a greeting that points clients at the /search endpoint."""
    welcome = "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."
    return {"message": welcome}
def start_api_server():
    """Serve ``app`` with uvicorn, blocking until the server stops.

    The bind address is read from the environment: ``HOST`` (default
    "0.0.0.0") and ``PORT`` (default 7861).
    """
    # Imported lazily so that importing this module does not require uvicorn.
    import uvicorn

    # NOTE(review): the original comment stated that HF Spaces uses 7860 by
    # default, yet the fallback here is 7861 — confirm which port is intended.
    bind_port = int(os.getenv("PORT", 7861))
    bind_host = os.getenv("HOST", "0.0.0.0")
    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info")
@app.post("/search", response_model=QueryResponse)
async def search_prompts(query_request: QueryRequest):
    """
    Accept a query prompt and return the top-n most similar stored prompts.

    Args:
        query_request: JSON body carrying the query prompt (``query``) and
            the number of results to return (``n_results``).

    Returns:
        A QueryResponse echoing the original query and listing the similar
        prompts, each with its similarity score.

    Raises:
        HTTPException(400): the query is blank or ``n_results`` is not positive.
        HTTPException(500): preprocessing, the search, or building the
            response failed.
    """
    import logging

    logger = logging.getLogger(__name__)

    query = query_request.query
    n_results = query_request.n_results

    # Validate up front so bad input yields a 400 instead of a 500.
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
    if n_results <= 0:
        raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")

    try:
        processed = preprocess_text(query)
        logger.debug("Preprocessed query: %r", processed)
        # NOTE(review): most_similar is called synchronously inside an async
        # endpoint, so the event loop is blocked while it runs. If it is
        # CPU-heavy, consider a plain `def` endpoint or a threadpool, but only
        # if search_engine is confirmed thread-safe.
        results = search_engine.most_similar(processed, n=n_results)
        logger.debug("Search results: %r", results)
        # Each result is indexed by 'prompt' and 'score'; float() coerces the
        # score to a plain Python float for the response model.
        similar = [{"prompt": r["prompt"], "score": float(r["score"])} for r in results]
        return QueryResponse(query=query, similar_queries=similar)
    except Exception as exc:
        # Keep the full traceback in the server log (print(e) used to drop it).
        logger.exception("Prompt search failed")
        # NOTE(review): str(exc) exposes internal error text to API clients;
        # consider returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(exc)) from exc