# Jokica17's picture
# Added `run.py` script to initialize and run the Prompt Search Engine:
# c19d1fd
# raw
# history blame
# 1.89 kB
import argparse
import uvicorn
from datasets import load_dataset
from app.api import app
from app.engine import PromptSearchEngine
from config import BACKEND_HOST, BACKEND_PORT
def download_and_prepare_dataset(limit: int = 3000) -> list:
    """
    Download the Conceptual Captions dataset and extract its captions.

    Args:
        limit (int): Maximum number of captions to take from the start of the
            "train" split. ``None`` takes every caption; a negative value is
            applied as a plain slice bound (``captions[:limit]``).

    Returns:
        list: The "caption" field of each selected training item.

    Raises:
        Exception: Any error raised while loading or reading the dataset is
            printed and then re-raised unchanged.
    """
    # Local import so this function does not depend on a file-level import.
    from itertools import islice

    print("Downloading and preparing dataset...")
    try:
        ds = load_dataset("google-research-datasets/conceptual_captions", "unlabeled")
        train_split = ds["train"]
        # TODO: Remove limit
        if limit is None or limit >= 0:
            # Stop reading after `limit` rows instead of materializing every
            # caption of the split and slicing the full list afterwards.
            captions = [item["caption"] for item in islice(train_split, limit)]
        else:
            # islice() rejects negative bounds, so keep list-slice semantics.
            captions = [item["caption"] for item in train_split][:limit]
        print(f"Dataset prepared with {len(captions)} captions.")
        return captions
    except Exception as e:
        # Report the failure for the operator, then let the caller handle it.
        print(f"Failed to download or prepare the dataset: {e}")
        raise
def initialize_search_engine() -> PromptSearchEngine:
    """
    Build the Prompt Search Engine from the prepared caption list.

    The captions returned by download_and_prepare_dataset() (called with its
    default limit) are handed to the PromptSearchEngine constructor, which is
    expected to precompute the embeddings.

    Returns:
        PromptSearchEngine: The engine instance built from those captions.
    """
    engine = PromptSearchEngine(download_and_prepare_dataset())
    print("Search engine initialized with precomputed embeddings.")
    return engine
if __name__ == "__main__":
    # Command-line interface: a single optional --dev switch.
    cli = argparse.ArgumentParser(description="Run the Prompt Search Engine server.")
    cli.add_argument(
        "--dev",
        action="store_true",
        help="Run the server in development mode with reload enabled.",
    )
    reload_server = cli.parse_args().dev

    # Build the engine once and expose it to the API through the app state.
    app.state.search_engine = initialize_search_engine()

    # NOTE(review): the server is started from the import string "app.api:app".
    # With reload enabled uvicorn presumably serves the app from a separate
    # process, where the state assigned above may not exist -- confirm that
    # --dev mode still finds `app.state.search_engine`.
    print(f"Starting server at {BACKEND_HOST}:{BACKEND_PORT} (debug: {reload_server})...")
    uvicorn.run(
        "app.api:app",
        host=BACKEND_HOST,
        port=BACKEND_PORT,
        reload=reload_server,
    )