File size: 2,188 Bytes
f4e126c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd4381
 
37a1625
b9972a2
290938c
2cd4381
f4e126c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import sys
sys.path.append(sys.path[0].replace('scripts', ''))
import os
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from scripts.run import config, search_engine
from scripts.preprocess import preprocess_text

# Application object served by uvicorn; title/description/version are the
# API metadata FastAPI shows in its generated documentation.
app = FastAPI(
    version="1.0.0",
    title="Prompt Search API",
    description="A RESTful API to find top-n most similar prompts.",
)

class QueryRequest(BaseModel):
    """Request body for the /search endpoint.

    query: the prompt text to find similar prompts for.
    n_results: how many similar prompts to return (defaults to 5). This model
        does not constrain the value; the /search handler rejects values <= 0.
    """
    query: str
    n_results: int = 5

class SimilarQuery(BaseModel):
    """A single search hit: a stored prompt and its similarity score."""
    prompt: str
    score: float

class QueryResponse(BaseModel):
    """Response body for /search: the submitted query plus its matching prompts."""
    query: str
    similar_queries: List[SimilarQuery]

@app.get("/")
def root():
    """Greet the caller and point them at the /search endpoint."""
    greeting = "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."
    return {"message": greeting}

def start_api_server():
    """Serve ``app`` with uvicorn (blocks until the server stops).

    The bind address is read from the environment:
    HOST defaults to "0.0.0.0" and PORT defaults to 7861.
    """
    import uvicorn

    # NOTE(review): the default is 7861, not 7860 (the port HF Spaces uses by
    # default) -- presumably to avoid clashing with another service on 7860;
    # confirm. A non-numeric PORT makes int() raise ValueError at startup.
    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", 7861))
    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info")

@app.post("/search", response_model=QueryResponse)
async def search_prompts(query_request: QueryRequest):
    """
    Return the stored prompts most similar to the submitted query.

    Args:
        query_request: JSON body carrying the query prompt (``query``) and the
            number of results to return (``n_results``).

    Returns:
        A QueryResponse echoing the original query together with the prompts
        returned by the search engine and their similarity scores.

    Raises:
        HTTPException: 400 if the query is blank or ``n_results`` is not
            positive; 500 if preprocessing, searching, or mapping the
            results fails.
    """
    import logging

    logger = logging.getLogger(__name__)

    query = query_request.query
    n_results = query_request.n_results

    # Validate before searching so bad input is reported as a client error
    # (400) rather than surfacing as a search failure (500).
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
    if n_results <= 0:
        raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")

    # NOTE(review): most_similar is called synchronously inside an async
    # handler; if it is CPU-heavy it blocks the event loop for all requests.
    # Consider a plain `def` handler or fastapi.concurrency.run_in_threadpool.
    try:
        processed_query = preprocess_text(query)
        logger.debug("Preprocessed query: %r", processed_query)
        results = search_engine.most_similar(processed_query, n=n_results)
        logger.debug("Search results: %r", results)
        # Each result is indexed by 'prompt' and 'score'; float() coerces the
        # engine's score into a plain float for the response model.
        similar = [
            SimilarQuery(prompt=r["prompt"], score=float(r["score"]))
            for r in results
        ]
    except Exception as exc:
        # Top-level boundary: record the full traceback server-side and
        # translate any failure into a 500, keeping the original cause chained.
        # NOTE(review): str(exc) exposes internal error text to API clients;
        # consider a generic detail message.
        logger.exception("Prompt search failed for query %r", query)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return QueryResponse(query=query, similar_queries=similar)