Spaces:
Running
Running
import argparse | |
import uvicorn | |
from datasets import load_dataset | |
from app.api import app | |
from app.engine import PromptSearchEngine | |
from config import BACKEND_HOST, BACKEND_PORT | |
def download_and_prepare_dataset(limit: int = 3000) -> list:
    """
    Download the Conceptual Captions dataset and return its leading captions.

    Args:
        limit (int): Maximum number of captions to return, taken from the
            start of the "train" split. ``None`` returns every caption. A
            negative value keeps plain list-slice semantics (drop the last
            ``abs(limit)`` captions), which requires reading the whole split.

    Returns:
        list: The ``"caption"`` field of the selected rows, in dataset order.

    Raises:
        Exception: Any error raised while loading or reading the dataset is
            printed to stdout and then re-raised unchanged.
    """
    # Imported locally so the function does not depend on edits elsewhere in the file.
    from itertools import islice

    print("Downloading and preparing dataset...")
    try:
        ds = load_dataset("google-research-datasets/conceptual_captions", "unlabeled")
        train_split = ds["train"]
        if limit is None or limit >= 0:
            # Stop iterating after `limit` rows instead of reading every row
            # of the split only to throw most of them away with a slice.
            captions = [
                item["caption"] for item in islice(train_split, limit)
            ]  # TODO: Remove limit
        else:
            # islice rejects negative bounds; fall back to a full read so a
            # negative limit behaves exactly like `[:limit]` on the full list.
            captions = [item["caption"] for item in train_split][:limit]
        print(f"Dataset prepared with {len(captions)} captions.")
        return captions
    except Exception as e:
        print(f"Failed to download or prepare the dataset: {e}")
        raise
def initialize_search_engine() -> PromptSearchEngine:
    """
    Build the Prompt Search Engine from the prepared caption dataset.

    The captions come from download_and_prepare_dataset() called with its
    default arguments and are handed straight to the PromptSearchEngine
    constructor (which presumably computes the embeddings, per the log
    message below -- confirm against app.engine).

    Returns:
        PromptSearchEngine: The engine constructed from the downloaded captions.
    """
    engine = PromptSearchEngine(download_and_prepare_dataset())
    print("Search engine initialized with precomputed embeddings.")
    return engine
if __name__ == "__main__":
    # Command-line interface: a single optional --dev switch.
    cli_parser = argparse.ArgumentParser(description="Run the Prompt Search Engine server.")
    cli_parser.add_argument(
        "--dev",
        action="store_true",
        help="Run the server in development mode with reload enabled.",
    )
    cli_args = cli_parser.parse_args()

    # Arguments are parsed first so that --help or a bad flag exits before the
    # dataset download starts. The engine is then attached to the imported
    # app object's state.
    app.state.search_engine = initialize_search_engine()

    # Start the app
    # NOTE(review): the app is passed as an import string. With --dev, uvicorn
    # serves from a separate reloader worker process, which may not see the
    # search_engine attribute assigned above -- confirm how app.api obtains it.
    use_reload = cli_args.dev
    print(f"Starting server at {BACKEND_HOST}:{BACKEND_PORT} (debug: {use_reload})...")
    uvicorn.run("app.api:app", host=BACKEND_HOST, port=BACKEND_PORT, reload=use_reload)