import logging
import os
from typing import List, Literal
import sys
import chromadb
from chromadb.utils import embedding_functions
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
import polars as pl

# Enable hf_transfer for faster Hub downloads. huggingface_hub reads this flag
# at import time, so it is set before the library is imported below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from huggingface_hub import HfApi
from transformers import AutoTokenizer
import torch
# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000
CACHE_TTL = "60"
if torch.cuda.is_available():
DEVICE = "cuda"
elif torch.backends.mps.is_available():
DEVICE = "mps"
else:
DEVICE = "cpu"
hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Use a local data directory when developing on macOS; otherwise use the
# Space's persistent /data volume.
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"
# Configure cache
cache.setup("mem://", size_limit="5gb")
# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup
setup_database()
yield
# Cleanup
await cache.close()
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
# NOTE: Starlette's CORSMiddleware matches allow_origins entries exactly, so the
# wildcard Hugging Face domains are expressed via allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # Hugging Face Spaces and domains
    allow_origins=[
        "http://localhost:5500",  # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the embedding function at module level
def get_embedding_function():
logger.info(f"Using device: {DEVICE}")
return embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="nomic-ai/modernbert-embed-base", device=DEVICE
)
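
# The nomic-style embedding model used here is normally queried with a
# "search_query: " prefix (and documents embedded with "search_document: ");
# this is why the /search endpoints below prepend "search_query: ".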
def setup_database():
try:
embedding_function = get_embedding_function()
# Create dataset collection
dataset_collection = client.get_or_create_collection(
embedding_function=embedding_function,
name="dataset_cards",
metadata={"hnsw:space": "cosine"},
)
# Create model collection
model_collection = client.get_or_create_collection(
embedding_function=embedding_function,
name="model_cards",
metadata={"hnsw:space": "cosine"},
)
# Load dataset data
df = pl.scan_parquet(
"hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
)
df = df.filter(
pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
)
row_count = df.select(pl.len()).collect().item()
logger.info(f"Row count of dataset data: {row_count}")
# Check if we need to update the collection
current_count = dataset_collection.count()
logger.info(f"Current dataset collection count: {current_count}")
if current_count < row_count:
logger.info(
f"Updating dataset collection with {row_count - current_count} new records"
)
# Load parquet files and upsert into ChromaDB
df = df.select(
["datasetId", "summary", "likes", "downloads", "last_modified"]
)
df = df.collect()
total_rows = len(df)
for i in range(0, total_rows, BATCH_SIZE):
batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
dataset_collection.upsert(
ids=batch_df.select(["datasetId"]).to_series().to_list(),
documents=batch_df.select(["summary"]).to_series().to_list(),
metadatas=[
{
"likes": int(likes),
"downloads": int(downloads),
"last_modified": str(last_modified),
}
for likes, downloads, last_modified in zip(
batch_df.select(["likes"]).to_series().to_list(),
batch_df.select(["downloads"]).to_series().to_list(),
batch_df.select(["last_modified"]).to_series().to_list(),
)
],
)
logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
logger.info(f"Database initialized with {dataset_collection.count():,} rows")
# Load model data
model_df = pl.scan_parquet(
"hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
)
model_row_count = model_df.select(pl.len()).collect().item()
logger.info(f"Row count of new model data: {model_row_count}")
if model_collection.count() < model_row_count:
model_df = model_df.select(
["modelId", "summary", "likes", "downloads", "last_modified"]
)
model_df = model_df.collect()
total_rows = len(model_df)
for i in range(0, total_rows, BATCH_SIZE):
batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
model_collection.upsert(
ids=batch_df.select(["modelId"]).to_series().to_list(),
documents=batch_df.select(["summary"]).to_series().to_list(),
metadatas=[
{
"likes": int(likes),
"downloads": int(downloads),
"last_modified": str(last_modified),
}
for likes, downloads, last_modified in zip(
batch_df.select(["likes"]).to_series().to_list(),
batch_df.select(["downloads"]).to_series().to_list(),
batch_df.select(["last_modified"]).to_series().to_list(),
)
],
)
logger.info(
f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
)
logger.info(
f"Model database initialized with {model_collection.count():,} rows"
)
except Exception as e:
logger.error(f"Setup error: {e}")
# Run setup at import time as well (the lifespan hook above repeats it at startup)
setup_database()
class QueryResult(BaseModel):
dataset_id: str
similarity: float
summary: str
likes: int
downloads: int
class QueryResponse(BaseModel):
results: List[QueryResult]
class ModelQueryResult(BaseModel):
model_id: str
similarity: float
summary: str
likes: int
downloads: int
class ModelQueryResponse(BaseModel):
results: List[ModelQueryResult]
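
# Illustrative response shape for the endpoints below (values are invented;
# the model endpoints return "model_id" instead of "dataset_id"):
#   {"results": [{"dataset_id": "user/dataset", "similarity": 0.87,
#                 "summary": "...", "likes": 12, "downloads": 3400}]}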
@app.get("/")
async def redirect_to_docs():
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/docs")
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
query: str,
k: int = Query(default=5, ge=1, le=100),
    sort_by: Literal["similarity", "likes", "downloads"] = Query(
        default="similarity"
    ),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
# Get collection with proper embedding function
collection = client.get_collection(
name="dataset_cards", embedding_function=get_embedding_function()
)
# Query ChromaDB
results = collection.query(
query_texts=[f"search_query: {query}"],
n_results=k * 4 if sort_by != "similarity" else k,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
# Process results
query_results = process_search_results(results, "dataset", k, sort_by)
return QueryResponse(results=query_results)
except Exception as e:
logger.error(f"Search error: {str(e)}")
raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
dataset_id: str,
k: int = Query(default=5, ge=1, le=100),
    sort_by: Literal["similarity", "likes", "downloads"] = Query(
        default="similarity"
    ),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection("dataset_cards")
# Get the reference document
results = collection.get(ids=[dataset_id], include=["embeddings"])
if not results["ids"]:
raise HTTPException(
status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
)
# Query using the embedding
results = collection.query(
query_embeddings=[results["embeddings"][0]],
n_results=k * 4
if sort_by != "similarity"
else k + 1, # +1 to account for self-match
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
# Process results (excluding the query dataset itself)
query_results = process_search_results(
results, "dataset", k, sort_by, dataset_id
)
return QueryResponse(results=query_results)
except HTTPException:
raise
except Exception as e:
logger.error(f"Similarity search error: {str(e)}")
raise HTTPException(status_code=500, detail="Similarity search failed")
@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
query: str,
k: int = Query(default=5, ge=1, le=100),
    sort_by: Literal["similarity", "likes", "downloads"] = Query(
        default="similarity"
    ),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection(
name="model_cards", embedding_function=get_embedding_function()
)
results = collection.query(
query_texts=[f"search_query: {query}"],
n_results=k * 4 if sort_by != "similarity" else k,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
query_results = process_search_results(results, "model", k, sort_by)
return ModelQueryResponse(results=query_results)
except Exception as e:
logger.error(f"Model search error: {str(e)}")
raise HTTPException(status_code=500, detail="Model search failed")
@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
model_id: str,
k: int = Query(default=5, ge=1, le=100),
    sort_by: Literal["similarity", "likes", "downloads"] = Query(
        default="similarity"
    ),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection("model_cards")
results = collection.get(ids=[model_id], include=["embeddings"])
if not results["ids"]:
raise HTTPException(
status_code=404, detail=f"Model ID '{model_id}' not found"
)
results = collection.query(
query_embeddings=[results["embeddings"][0]],
n_results=k * 4 if sort_by != "similarity" else k + 1,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
query_results = process_search_results(results, "model", k, sort_by, model_id)
return ModelQueryResponse(results=query_results)
except HTTPException:
raise
except Exception as e:
logger.error(f"Model similarity search error: {str(e)}")
raise HTTPException(status_code=500, detail="Model similarity search failed")
def process_search_results(results, id_field, k, sort_by, exclude_id=None):
"""Process search results into a standardized format."""
query_results = []
for i in range(len(results["ids"][0])):
current_id = results["ids"][0][i]
if exclude_id and current_id == exclude_id:
continue
result = {
f"{id_field}_id": current_id,
"similarity": float(results["distances"][0][i]),
"summary": results["documents"][0][i],
"likes": results["metadatas"][0][i]["likes"],
"downloads": results["metadatas"][0][i]["downloads"],
}
if id_field == "dataset":
query_results.append(QueryResult(**result))
else:
query_results.append(ModelQueryResult(**result))
if sort_by != "similarity":
query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
query_results = query_results[:k]
elif exclude_id: # We fetched extra for similarity + exclude_id case
query_results = query_results[:k]
return query_results
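
# The endpoints above over-fetch (k * 4) when sorting by likes or downloads so
# that re-sorting the nearest neighbours still leaves at least k candidates;
# plain similarity queries fetch k (or k + 1 when a self-match must be dropped).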
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
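
# Minimal client sketch (assumes the server above is running on
# http://localhost:8000 and that the httpx package is installed):
#
#   import httpx
#   resp = httpx.get(
#       "http://localhost:8000/search/datasets",
#       params={"query": "question answering", "k": 3},
#   )
#   for item in resp.json()["results"]:
#       print(item["dataset_id"], round(item["similarity"], 3))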