File size: 1,889 Bytes
c19d1fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
from itertools import islice

import uvicorn
from datasets import load_dataset

from app.api import app
from app.engine import PromptSearchEngine
from config import BACKEND_HOST, BACKEND_PORT


def download_and_prepare_dataset(limit: int = 3000) -> list:
    """
    Download the Conceptual Captions dataset and return its leading captions.
    Args:
        limit (int): Maximum number of captions to return, taken in order from the
            start of the "train" split. Passing None returns every caption.
    Returns:
        list: The "caption" field of each selected row (at most `limit` items).
    Raises:
        ValueError: If `limit` is negative.
        Exception: Any error raised while loading or reading the dataset is
            printed and then re-raised unchanged.
    """
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")

    print("Downloading and preparing dataset...")
    try:
        ds = load_dataset("google-research-datasets/conceptual_captions", "unlabeled")
        # Stop reading after `limit` rows rather than extracting the caption of
        # every row in the split and slicing the full list afterwards.
        captions = [
            item["caption"] for item in islice(ds["train"], limit)
        ]  # TODO: Remove limit
    except Exception as e:
        print(f"Failed to download or prepare the dataset: {e}")
        raise
    else:
        print(f"Dataset prepared with {len(captions)} captions.")
        return captions


def initialize_search_engine() -> PromptSearchEngine:
    """
    Build a PromptSearchEngine from the downloaded captions.

    The captions come from download_and_prepare_dataset() called with its
    default limit and are handed to the PromptSearchEngine constructor as-is.
    Returns:
        PromptSearchEngine: The engine constructed from the caption list.
    """
    # NOTE(review): the message below says embeddings are precomputed; that work
    # would happen inside PromptSearchEngine.__init__ — confirm in app.engine.
    engine = PromptSearchEngine(download_and_prepare_dataset())
    print("Search engine initialized with precomputed embeddings.")
    return engine


if __name__ == "__main__":
    # Command-line interface: a single optional --dev switch.
    cli = argparse.ArgumentParser(description="Run the Prompt Search Engine server.")
    cli.add_argument(
        "--dev",
        action="store_true",
        help="Run the server in development mode with reload enabled.",
    )
    cli_args = cli.parse_args()

    # Build the engine before the server starts and attach it to the app's
    # shared state object.
    app.state.search_engine = initialize_search_engine()

    # --dev maps directly onto uvicorn's auto-reload switch.
    # NOTE(review): with reload=True uvicorn serves the app from a separate
    # process that imports "app.api:app" itself — confirm the search_engine
    # state attached above is actually available there.
    reload_server = cli_args.dev
    print(f"Starting server at {BACKEND_HOST}:{BACKEND_PORT} (debug: {reload_server})...")
    uvicorn.run(
        "app.api:app",
        host=BACKEND_HOST,
        port=BACKEND_PORT,
        reload=reload_server,
    )