Jokica17 commited on
Commit
c19d1fd
·
1 Parent(s): 15b0ad3

Added `run.py` script to initialize and run the Prompt Search Engine:

Browse files

- integrated dataset download and preparation logic with `load_dataset` for captions
- initialized `PromptSearchEngine` with precomputed embeddings
- configured server startup with `uvicorn`, including support for development mode (`--dev`) with hot-reloading

Files changed (1) hide show
  1. run.py +53 -0
run.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import uvicorn
3
+ from datasets import load_dataset
4
+ from app.api import app
5
+ from app.engine import PromptSearchEngine
6
+ from config import BACKEND_HOST, BACKEND_PORT
7
+
8
+
9
+ def download_and_prepare_dataset(limit: int = 3000) -> list:
10
+ """
11
+ Download and prepare the dataset.
12
+ Args:
13
+ limit (int): Number of captions to include in the dataset.
14
+ Returns:
15
+ list: A list of captions.
16
+ """
17
+ print("Downloading and preparing dataset...")
18
+ try:
19
+ ds = load_dataset("google-research-datasets/conceptual_captions", "unlabeled")
20
+ captions = [item["caption"] for item in ds["train"]][:limit] # TODO: Remove limit
21
+ print(f"Dataset prepared with {len(captions)} captions.")
22
+ return captions
23
+ except Exception as e:
24
+ print(f"Failed to download or prepare the dataset: {e}")
25
+ raise
26
+
27
+
28
+ def initialize_search_engine() -> PromptSearchEngine:
29
+ """
30
+ Initialize the Prompt Search Engine by loading data and precomputing embeddings.
31
+ Returns:
32
+ PromptSearchEngine: An instance of the search engine initialized with the dataset.
33
+ """
34
+ prompts = download_and_prepare_dataset()
35
+ search_engine = PromptSearchEngine(prompts)
36
+ print("Search engine initialized with precomputed embeddings.")
37
+ return search_engine
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser(description="Run the Prompt Search Engine server.")
42
+ parser.add_argument(
43
+ "--dev", action="store_true", help="Run the server in development mode with reload enabled."
44
+ )
45
+ args = parser.parse_args()
46
+
47
+ engine = initialize_search_engine()
48
+ app.state.search_engine = engine
49
+
50
+ # Start the app
51
+ reload_server = args.dev
52
+ print(f"Starting server at {BACKEND_HOST}:{BACKEND_PORT} (debug: {reload_server})...")
53
+ uvicorn.run("app.api:app", host=BACKEND_HOST, port=BACKEND_PORT, reload=reload_server)