Spaces:
Sleeping
Sleeping
import logging
import os
import sys
from typing import List, Optional

# Put the directory above ``scripts`` on sys.path so the ``scripts.*`` imports
# below resolve when this file is run directly.
# NOTE(review): str.replace strips *every* occurrence of 'scripts' in the path.
sys.path.append(sys.path[0].replace('scripts', ''))

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from scripts.run import config, search_engine
from scripts.preprocess import preprocess_text
# ASGI application object; start_api_server hands this to uvicorn.run.
# title/description/version are FastAPI's OpenAPI metadata fields.
app = FastAPI(
    version="1.0.0",
    title="Prompt Search API",
    description="A RESTful API to find top-n most similar prompts.",
)
class QueryRequest(BaseModel):
    """Request body for a prompt-similarity search."""

    # Free-text prompt to find similar prompts for.
    query: str
    # How many similar prompts to return (default 5). Not constrained here;
    # search_prompts rejects values <= 0 with a 400.
    n_results: int = 5
class SimilarQuery(BaseModel):
    """A single search hit: a stored prompt and its similarity score."""

    # Text of the matching prompt.
    prompt: str
    # Score reported by the search engine for this match.
    # NOTE(review): whether higher means more similar depends on
    # search_engine.most_similar — confirm.
    score: float
class QueryResponse(BaseModel):
    """Response body for a prompt-similarity search."""

    # The query exactly as the client sent it (before preprocessing).
    query: str
    # Matches in the order produced by the search engine.
    similar_queries: List[SimilarQuery]
@app.get("/")
def root():
    """Return a welcome message that points clients at the search endpoint.

    Registered on ``GET /`` so the welcome message is actually reachable.
    """
    return {"message": "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."}
def start_api_server(
    host: Optional[str] = None,
    port: Optional[int] = None,
    log_level: str = "info",
) -> None:
    """Serve ``app`` with uvicorn.

    ``uvicorn.run`` does not return until the server stops.

    Args:
        host: Interface to bind. When omitted, the ``HOST`` environment
            variable is used, falling back to ``"0.0.0.0"`` if it is unset
            or empty.
        port: TCP port to listen on. When omitted, the ``PORT`` environment
            variable is used, falling back to 7861 if it is unset or empty.
        log_level: Log level passed through to uvicorn (default ``"info"``).
    """
    # Deferred import: uvicorn is only needed when actually serving.
    import uvicorn

    if host is None:
        # `or` also covers an empty HOST, which would otherwise be passed through as "".
        host = os.getenv("HOST") or "0.0.0.0"
    if port is None:
        # `or` also covers an empty PORT, which previously crashed int("").
        # NOTE(review): 7861 deliberately differs from HF Spaces' default port
        # 7860 — presumably to avoid colliding with another server in the same
        # Space. Confirm before changing.
        port = int(os.getenv("PORT") or 7861)
    uvicorn.run(app, host=host, port=port, log_level=log_level)
@app.post("/search", response_model=QueryResponse)
async def search_prompts(query_request: QueryRequest):
    """
    Accept a query prompt and return the top-n most similar stored prompts.

    Args:
        query_request: JSON body with the query prompt and the number of
            results to return.

    Returns:
        QueryResponse echoing the original (unpreprocessed) query and the
        matches, in the order produced by ``search_engine.most_similar``.

    Raises:
        HTTPException: 400 if the query is blank or ``n_results`` is not
            positive; 500 if preprocessing or the similarity search fails.
    """
    logger = logging.getLogger(__name__)
    query = query_request.query
    n_results = query_request.n_results

    if not query.strip():
        raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
    if n_results <= 0:
        raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")

    # NOTE(review): preprocess_text and most_similar are synchronous calls made
    # from an async handler, so they run on the event-loop thread and block all
    # other requests for as long as they take. Consider a plain `def` handler
    # or `run_in_threadpool` once search_engine is confirmed thread-safe.
    try:
        processed_query = preprocess_text(query)
        logger.debug("Preprocessed query: %r", processed_query)
        results = search_engine.most_similar(processed_query, n=n_results)
        logger.debug("Search results: %r", results)
        # Each result is expected to be a mapping with 'prompt' and 'score';
        # float() normalises scores that may arrive as non-builtin numeric types.
        similar = [
            SimilarQuery(prompt=r["prompt"], score=float(r["score"]))
            for r in results
        ]
    except Exception as exc:
        # Keep the full traceback in the server logs, but do not echo internal
        # error text (paths, library messages) back to untrusted clients.
        logger.exception("Prompt search failed for query %r", query)
        raise HTTPException(
            status_code=500, detail="Internal error while searching prompts."
        ) from exc

    return QueryResponse(query=query, similar_queries=similar)