"""FastAPI application exposing a prompt-similarity search endpoint."""

import logging
import os
import sys

# Append a copy of sys.path[0] with every occurrence of the substring
# 'scripts' removed. This must run before the `scripts.*` imports below.
# NOTE(review): presumably this makes the project root importable when the
# module is launched from inside the `scripts` directory; `str.replace` also
# strips 'scripts' from any other path component — confirm this is acceptable.
sys.path.append(sys.path[0].replace('scripts', ''))

from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from scripts.run import config, search_engine
from scripts.preprocess import preprocess_text

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Prompt Search API",
    description="A RESTful API to find top-n most similar prompts.",
    version="1.0.0"
)


class QueryRequest(BaseModel):
    """Request body for the `/search` endpoint."""

    # Free-text prompt to search for.
    query: str
    # Number of similar prompts to return; defaults to 5.
    n_results: int = 5


class SimilarQuery(BaseModel):
    """A single search hit: the matched prompt and its score."""

    prompt: str
    score: float


class QueryResponse(BaseModel):
    """Response body for the `/search` endpoint."""

    # The query exactly as submitted by the client (not the preprocessed form).
    query: str
    similar_queries: List[SimilarQuery]


@app.get("/")
def root():
    """Return a static welcome message pointing clients at `/search`."""
    return {"message": "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."}


def start_api_server() -> None:
    """Run the FastAPI app under uvicorn.

    The bind address and port are read from the ``HOST`` and ``PORT``
    environment variables, falling back to ``0.0.0.0`` and ``7861``.
    """
    # Imported lazily so that importing this module does not require uvicorn.
    import uvicorn

    # NOTE(review): the original comment mentioned Hugging Face Spaces using
    # 7860 by default, yet the fallback here is 7861 — confirm which is intended.
    port = int(os.getenv("PORT", "7861"))
    host = os.getenv("HOST", "0.0.0.0")
    uvicorn.run(app, host=host, port=port, log_level="info")


@app.post("/search", response_model=QueryResponse)
async def search_prompts(query_request: QueryRequest):
    """
    Accepts a query prompt and returns the top n similar prompts.

    Args:
        query_request: JSON input with query prompt and number of results to return.

    Returns:
        A list of top-n similar prompts with similarity scores.

    Raises:
        HTTPException: 400 if the query is blank or n_results is not positive;
            500 if preprocessing, searching or result conversion fails.
    """
    query = query_request.query
    n_results = query_request.n_results

    if not query.strip():
        raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
    if n_results <= 0:
        raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")

    # NOTE(review): preprocess_text and most_similar are called synchronously
    # inside an async handler, so they run on the event loop — confirm they
    # are fast enough not to stall other requests.
    try:
        processed_query = preprocess_text(query)
        logger.debug("Preprocessed query: %s", processed_query)

        results = search_engine.most_similar(processed_query, n=n_results)
        logger.debug("Results: %s", results)

        # Each result is expected to be subscriptable with 'prompt' and
        # 'score' keys; anything else is reported as a 500 below.
        similar_queries = [
            SimilarQuery(prompt=r['prompt'], score=float(r['score']))
            for r in results
        ]
        return QueryResponse(query=query, similar_queries=similar_queries)
    except Exception as exc:
        # Top-level boundary: record the traceback, then surface the failure.
        logger.exception("Search failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc