import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

import chromadb
import polars as pl
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 1000
CACHE_TTL = "60"

# Turn on HF_TRANSFER before any Hub downloads so it applies to them
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory when developing (e.g. on macOS); otherwise use the
# Space's persistent /data volume
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="4gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")

# FastAPI lifespan: populate the database on startup, close the cache on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)

# Add CORS middleware. Wildcard subdomains are not matched by allow_origins,
# so Hugging Face domains are allowed via allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # HF Spaces and Hub domains
    allow_origins=[
        "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define the embedding function at module level.
# Queries below are sent with a "search_query: " prefix, following the
# task-prefix convention of the nomic embedding models.
def get_embedding_function():
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL
    )

def setup_database():
    try:
        embedding_function = get_embedding_function()
        # Create dataset collection
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # Create model collection
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # TODO incremental updates
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = df.select(pl.len()).collect().item()
        logger.info(f"Row count of new data: {row_count}")
        if dataset_collection.count() < row_count:
            # Load parquet files and upsert into ChromaDB in batches
            df = df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            )
            df = df.collect()
            total_rows = len(df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                dataset_collection.upsert(
                    ids=batch_df["datasetId"].to_list(),
                    documents=batch_df["summary"].to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df["likes"].to_list(),
                            batch_df["downloads"].to_list(),
                            batch_df["last_modified"].to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
        logger.info(f"Database initialized with {dataset_collection.count():,} rows")
        # Load model data
        model_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new model data: {model_row_count}")
        if model_collection.count() < model_row_count:
            model_df = model_df.select(
                ["modelId", "summary", "likes", "downloads", "last_modified"]
            )
            model_df = model_df.collect()
            total_rows = len(model_df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df["modelId"].to_list(),
                    documents=batch_df["summary"].to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df["likes"].to_list(),
                            batch_df["downloads"].to_list(),
                            batch_df["last_modified"].to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )
        logger.info(
            f"Model database initialized with {model_collection.count():,} rows"
        )
    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()

class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]

# NOTE: the route decorators on this and the following handlers were missing from
# the source; the paths are assumptions inferred from the handler names.
@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")

@app.get("/search/datasets", response_model=QueryResponse)  # path assumed
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        # Get collection with proper embedding function
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # Query ChromaDB; over-fetch when results will be re-sorted by likes/downloads
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results
        query_results = process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")

@app.get("/similarity/datasets", response_model=QueryResponse)  # path assumed
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        # Get the reference document's embedding
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # Query using the embedding
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4
            if sort_by != "similarity"
            else k + 1,  # +1 to account for self-match
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results (excluding the query dataset itself)
        query_results = process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")

@app.get("/search/models", response_model=ModelQueryResponse)  # path assumed
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")

@app.get("/similarity/models", response_model=ModelQueryResponse)  # path assumed
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by, model_id)
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")

def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue
        result = {
            f"{id_field}_id": current_id,
            # ChromaDB returns cosine distance; convert to similarity (1 - distance)
            "similarity": 1 - float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }
        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))
    if sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:  # We fetched extra for the similarity + exclude_id case
        query_results = query_results[:k]
    return query_results

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
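
# Example client usage (a sketch only: the route paths registered above are
# assumptions, and the server is presumed to be running locally on port 8000):
#
#     import httpx
#
#     resp = httpx.get(
#         "http://localhost:8000/search/datasets",
#         params={"query": "medical question answering", "k": 5, "sort_by": "likes"},
#     )
#     for hit in resp.json()["results"]:
#         print(hit["dataset_id"], round(hit["similarity"], 3), hit["likes"])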