import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

# hf_transfer must be enabled before huggingface_hub (or transformers) is imported,
# otherwise the setting is read too late and ignored.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import chromadb
import polars as pl
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from pydantic import BaseModel
from transformers import AutoTokenizer
# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 1000
CACHE_TTL = "60"

# HF API client and tokenizer for the TL;DR summary model (not referenced directly below)
hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Use a local data directory when developing on macOS; the Space keeps
# persistent data under /data
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"
# Configure cache
cache.setup("mem://", size_limit="4gb")
# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
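# The persistent client stores collections on disk under DATA_DIR/chroma,
# so embeddings survive restarts of the app.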
# FastAPI lifespan: run database setup on startup and close the cache on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    setup_database()
    yield
    # Shutdown
    await cache.close()

app = FastAPI(lifespan=lifespan)
# Add CORS middleware.
# Starlette's CORSMiddleware matches allow_origins literally, so wildcard entries
# such as "https://*.hf.space" never match; subdomain wildcards need allow_origin_regex.
app.add_middleware(
    CORSMiddleware,
    # Allow all Hugging Face Spaces and Hugging Face domains
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",
    allow_origins=[
        "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the embedding function at module level
def get_embedding_function():
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL
    )
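# Note: user queries are embedded with a "search_query: " prefix (see the search
# endpoints below), matching the task-prefix convention of the nomic embedding models.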
def setup_database():
try:
embedding_function = get_embedding_function()
# Create dataset collection
dataset_collection = client.get_or_create_collection(
embedding_function=embedding_function,
name="dataset_cards",
metadata={"hnsw:space": "cosine"},
)
# Create model collection
model_collection = client.get_or_create_collection(
embedding_function=embedding_function,
name="model_cards",
metadata={"hnsw:space": "cosine"},
)
# TODO incremental updates
df = pl.scan_parquet(
"hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
)
df = df.filter(
pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
)
row_count = df.select(pl.len()).collect().item()
logger.info(f"Row count of new data: {row_count}")
if dataset_collection.count() < row_count:
# Load parquet files and upsert into ChromaDB
df = df.select(
["datasetId", "summary", "likes", "downloads", "last_modified"]
)
df = df.collect()
total_rows = len(df)
for i in range(0, total_rows, BATCH_SIZE):
batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
dataset_collection.upsert(
ids=batch_df.select(["datasetId"]).to_series().to_list(),
documents=batch_df.select(["summary"]).to_series().to_list(),
metadatas=[
{
"likes": int(likes),
"downloads": int(downloads),
"last_modified": str(last_modified),
}
for likes, downloads, last_modified in zip(
batch_df.select(["likes"]).to_series().to_list(),
batch_df.select(["downloads"]).to_series().to_list(),
batch_df.select(["last_modified"]).to_series().to_list(),
)
],
)
logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
logger.info(f"Database initialized with {dataset_collection.count():,} rows")
# Load model data
model_df = pl.scan_parquet(
"hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
)
model_row_count = model_df.select(pl.len()).collect().item()
logger.info(f"Row count of new model data: {model_row_count}")
if model_collection.count() < model_row_count:
model_df = model_df.select(
["modelId", "summary", "likes", "downloads", "last_modified"]
)
model_df = model_df.collect()
total_rows = len(model_df)
for i in range(0, total_rows, BATCH_SIZE):
batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
model_collection.upsert(
ids=batch_df.select(["modelId"]).to_series().to_list(),
documents=batch_df.select(["summary"]).to_series().to_list(),
metadatas=[
{
"likes": int(likes),
"downloads": int(downloads),
"last_modified": str(last_modified),
}
for likes, downloads, last_modified in zip(
batch_df.select(["likes"]).to_series().to_list(),
batch_df.select(["downloads"]).to_series().to_list(),
batch_df.select(["last_modified"]).to_series().to_list(),
)
],
)
logger.info(
f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
)
logger.info(
f"Model database initialized with {model_collection.count():,} rows"
)
except Exception as e:
logger.error(f"Setup error: {e}")
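# Response schemas for the search and similarity endpoints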
class QueryResult(BaseModel):
dataset_id: str
similarity: float
summary: str
likes: int
downloads: int
class QueryResponse(BaseModel):
results: List[QueryResult]
class ModelQueryResult(BaseModel):
model_id: str
similarity: float
summary: str
likes: int
downloads: int
class ModelQueryResponse(BaseModel):
results: List[ModelQueryResult]
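# Example /search/datasets response body (illustrative values only):
# {"results": [{"dataset_id": "example/dataset", "similarity": 0.21,
#               "summary": "...", "likes": 42, "downloads": 1234}]}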
@app.get("/")
async def redirect_to_docs():
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/docs")
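# Example request (hypothetical values): GET /search/datasets?query=french%20ocr&k=10&sort_by=downloads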
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
query: str,
k: int = Query(default=5, ge=1, le=100),
sort_by: str = Query(
default="similarity", enum=["similarity", "likes", "downloads"]
),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
# Get collection with proper embedding function
collection = client.get_collection(
name="dataset_cards", embedding_function=get_embedding_function()
)
# Query ChromaDB
results = collection.query(
query_texts=[f"search_query: {query}"],
n_results=k * 4 if sort_by != "similarity" else k,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
# Process results
query_results = process_search_results(results, "dataset", k, sort_by)
return QueryResponse(results=query_results)
except Exception as e:
logger.error(f"Search error: {str(e)}")
raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
dataset_id: str,
k: int = Query(default=5, ge=1, le=100),
sort_by: str = Query(
default="similarity", enum=["similarity", "likes", "downloads"]
),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection("dataset_cards")
# Get the reference document
results = collection.get(ids=[dataset_id], include=["embeddings"])
if not results["ids"]:
raise HTTPException(
status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
)
# Query using the embedding
results = collection.query(
query_embeddings=[results["embeddings"][0]],
n_results=k * 4
if sort_by != "similarity"
else k + 1, # +1 to account for self-match
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
# Process results (excluding the query dataset itself)
query_results = process_search_results(
results, "dataset", k, sort_by, dataset_id
)
return QueryResponse(results=query_results)
except HTTPException:
raise
except Exception as e:
logger.error(f"Similarity search error: {str(e)}")
raise HTTPException(status_code=500, detail="Similarity search failed")
@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
query: str,
k: int = Query(default=5, ge=1, le=100),
sort_by: str = Query(
default="similarity", enum=["similarity", "likes", "downloads"]
),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection(
name="model_cards", embedding_function=get_embedding_function()
)
results = collection.query(
query_texts=[f"search_query: {query}"],
n_results=k * 4 if sort_by != "similarity" else k,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
query_results = process_search_results(results, "model", k, sort_by)
return ModelQueryResponse(results=query_results)
except Exception as e:
logger.error(f"Model search error: {str(e)}")
raise HTTPException(status_code=500, detail="Model search failed")
@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
model_id: str,
k: int = Query(default=5, ge=1, le=100),
sort_by: str = Query(
default="similarity", enum=["similarity", "likes", "downloads"]
),
min_likes: int = Query(default=0, ge=0),
min_downloads: int = Query(default=0, ge=0),
):
try:
collection = client.get_collection("model_cards")
results = collection.get(ids=[model_id], include=["embeddings"])
if not results["ids"]:
raise HTTPException(
status_code=404, detail=f"Model ID '{model_id}' not found"
)
results = collection.query(
query_embeddings=[results["embeddings"][0]],
n_results=k * 4 if sort_by != "similarity" else k + 1,
where={
"$and": [
{"likes": {"$gte": min_likes}},
{"downloads": {"$gte": min_downloads}},
]
}
if min_likes > 0 or min_downloads > 0
else None,
)
query_results = process_search_results(results, "model", k, sort_by, model_id)
return ModelQueryResponse(results=query_results)
except HTTPException:
raise
except Exception as e:
logger.error(f"Model similarity search error: {str(e)}")
raise HTTPException(status_code=500, detail="Model similarity search failed")
def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format.

    Note: the "similarity" field carries the cosine distance returned by ChromaDB,
    so lower values indicate closer matches.
    """
    query_results = []
for i in range(len(results["ids"][0])):
current_id = results["ids"][0][i]
if exclude_id and current_id == exclude_id:
continue
result = {
f"{id_field}_id": current_id,
"similarity": float(results["distances"][0][i]),
"summary": results["documents"][0][i],
"likes": results["metadatas"][0][i]["likes"],
"downloads": results["metadatas"][0][i]["downloads"],
}
if id_field == "dataset":
query_results.append(QueryResult(**result))
else:
query_results.append(ModelQueryResult(**result))
if sort_by != "similarity":
query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
query_results = query_results[:k]
elif exclude_id: # We fetched extra for similarity + exclude_id case
query_results = query_results[:k]
return query_results
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
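# Alternatively (module name assumed to be main.py; adjust the path if different):
#   uvicorn main:app --host 0.0.0.0 --port 8000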