davanstrien (HF staff) committed
Commit 7cf16e2 · Parent(s): d16e515

switch to chromadb

Files changed (1):
  1. main.py (+221, -124)

main.py CHANGED
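For orientation, a minimal sketch of the ChromaDB pattern this commit moves to (the client, embedding function, collection name, and cosine HNSW setting are taken from the diff below; the storage path and query text are illustrative only):

import chromadb
from chromadb.utils import embedding_functions

# Persistent client on local disk (the app uses f"{DATA_DIR}/chroma")
client = chromadb.PersistentClient(path="data/chroma")

# Same SentenceTransformer model the app used before the switch
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="nomic-ai/modernbert-embed-base"
)

# Cosine-distance HNSW collection, matching the metadata set in setup_database()
collection = client.get_or_create_collection(
    name="dataset_cards",
    embedding_function=embedding_function,
    metadata={"hnsw:space": "cosine"},
)

# Text query; the app prefixes queries with "search_query: " for this model
results = collection.query(
    query_texts=["search_query: legal NER datasets"], n_results=5
)
print(results["ids"][0], results["distances"][0])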
@@ -2,13 +2,41 @@ import logging
 import os
 from typing import List
 import sys
-import duckdb
-from cashews import cache  # Add this import
+import chromadb
+from chromadb.utils import embedding_functions
+from cashews import cache
 from fastapi import FastAPI, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from sentence_transformers import SentenceTransformer
 from contextlib import asynccontextmanager
+import polars as pl
+from huggingface_hub import hf_hub_url, DatasetCard, ModelCard, HfApi
+from datetime import datetime, timedelta
+from typing import Generator
+from huggingface_hub import ModelInfo, DatasetInfo
+import stamina
+import logging
+import polars as pl
+from huggingface_hub import dataset_info
+from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer
+import stamina
+from tqdm.contrib.concurrent import thread_map
+from datasets import Dataset, Value, Sequence
+import datasets
+import os
+from dotenv import load_dotenv
+from huggingface_hub import get_inference_endpoint
+from huggingface_hub import AsyncInferenceClient
+import asyncio
+from typing import List
+
+hf_api = HfApi()
+
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
+)
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER
 # Set up logging
@@ -22,15 +50,20 @@ DATA_DIR = "data" if LOCAL else "/data"
 # Configure cache
 cache.setup("mem://", size_limit="4gb")
 
+# Initialize ChromaDB client
+client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
+
 
 # Initialize FastAPI app
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Startup: nothing special needed here since model and DB are initialized at module level
+    # Setup
+    setup_database()
+
     yield
+
     # Cleanup
     await cache.close()
-    con.close()
 
 
 app = FastAPI(lifespan=lifespan)
@@ -48,62 +81,100 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Initialize model and DuckDB
-model = SentenceTransformer("nomic-ai/modernbert-embed-base", backend="onnx")
-embedding_dim = model.get_sentence_embedding_dimension()
-
-# Database setup with fallback
-db_path = f"{DATA_DIR}/vector_store.db"
-try:
-    # Create directory if it doesn't exist
-    os.makedirs(os.path.dirname(db_path), exist_ok=True)
-    con = duckdb.connect(db_path)
-    logger.info(f"Connected to persistent database at {db_path}")
-except (OSError, PermissionError) as e:
-    logger.warning(
-        f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}"
-    )
-    con = duckdb.connect(":memory:")
 
-# Initialize VSS extension
-con.sql("INSTALL vss; LOAD vss;")
-con.sql("SET hnsw_enable_experimental_persistence=true;")
+# Define the embedding function at module level
+def get_embedding_function():
+    return embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name="nomic-ai/modernbert-embed-base"
+    )
 
 
 def setup_database():
     try:
-        # Create table with properly typed embeddings
-        con.sql(f"""
-            CREATE TABLE IF NOT EXISTS model_cards AS
-            SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float
-            FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet';
-        """)
-
-        # Check if index exists
-        index_exists = (
-            con.sql("""
-                SELECT COUNT(*) as count
-                FROM duckdb_indexes
-                WHERE index_name = 'my_hnsw_index';
-            """).fetchone()[0]
-            > 0
-        )
-
-        if index_exists:
-            # Drop existing index
-            con.sql("DROP INDEX my_hnsw_index;")
-            logger.info("Dropped existing HNSW index")
-
-        # Create/Recreate HNSW index
-        con.sql("""
-            CREATE INDEX my_hnsw_index ON model_cards
-            USING HNSW (embeddings_float) WITH (metric = 'cosine');
-        """)
-        logger.info("Created/Recreated HNSW index")
+        embedding_function = get_embedding_function()
 
-        # Log the number of rows in the database
-        row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0]
-        logger.info(f"Database initialized with {row_count:,} rows")
+        # Create collection with embedding function
+        dataset_collection = client.get_or_create_collection(
+            embedding_function=embedding_function,
+            name="dataset_cards",
+            metadata={"hnsw:space": "cosine"},
+        )
+        # TODO incremental updates
+        df = pl.scan_parquet(
+            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
+        )
+        df = df.filter(
+            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
+        )
+        row_count = df.select(pl.len()).collect().item()
+        logger.info(f"Row count of new data: {row_count}")
+        if dataset_collection.count() < row_count:
+            # Load parquet files and upsert into ChromaDB
+            df = df.select(
+                ["datasetId", "summary", "likes", "downloads", "last_modified"]
+            )
+            df = df.collect()
+            BATCH_SIZE = 1000
+            total_rows = len(df)
+
+            for i in range(0, total_rows, BATCH_SIZE):
+                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
+
+                dataset_collection.upsert(
+                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
+                    documents=batch_df.select(["summary"]).to_series().to_list(),
+                    metadatas=[
+                        {
+                            "likes": int(likes),
+                            "downloads": int(downloads),
+                            "last_modified": str(last_modified),
+                        }
+                        for likes, downloads, last_modified in zip(
+                            batch_df.select(["likes"]).to_series().to_list(),
+                            batch_df.select(["downloads"]).to_series().to_list(),
+                            batch_df.select(["last_modified"]).to_series().to_list(),
+                        )
+                    ],
+                )
+                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
+
+        logger.info(f"Database initialized with {dataset_collection.count():,} rows")
+        # model_collection = client.get_or_create_collection(
+        #     embedding_function=embedding_function,
+        #     name="model_cards",
+        #     metadata={"hnsw:space": "cosine"},
+        # )
+
+        # # If collection is empty, load data from parquet files
+        # if model_collection.count() == 0:
+        #     # Load parquet files and insert into ChromaDB
+        #     df = pl.scan_parquet(
+        #         "hf://datasets/librarian-bots/model_cards_with_metadata/data/train-*.parquet"
+        #     )
+        #     df = df.select(["modelId", "likes", "downloads"])
+        #     df = df.collect()
+        #     df = df.sample(n=1000)  # TODO remove for prod
+        #     # Process in batches of 1000
+        #     BATCH_SIZE = 1000
+        #     total_rows = len(df)
+
+        #     for i in range(0, total_rows, BATCH_SIZE):
+        #         batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
+
+        #         model_collection.add(
+        #             ids=batch_df.select(["modelId"]).to_series().to_list(),
+        #             documents=batch_df.select(["summary"]).to_series().to_list(),
+        #             metadatas=[
+        #                 {"likes": int(likes), "downloads": int(downloads)}
+        #                 for likes, downloads in zip(
+        #                     batch_df.select(["likes"]).to_series().to_list(),
+        #                     batch_df.select(["downloads"]).to_series().to_list(),
+        #                 )
+        #             ],
+        #         )
+        #         logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
+
+        #     logger.info(f"Database initialized with {model_collection.count():,} rows")
 
     except Exception as e:
         logger.error(f"Setup error: {e}")
@@ -134,39 +205,54 @@ async def redirect_to_docs():
 
 @app.get("/search/datasets", response_model=QueryResponse)
 @cache(ttl="10m")
-async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
+async def search_datasets(
+    query: str,
+    k: int = Query(default=5, ge=1, le=100),
+    sort_by: str = Query(
+        default="similarity", enum=["similarity", "likes", "downloads"]
+    ),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
+):
     try:
-        query_embedding = model.encode(f"search_query: {query}").tolist()
-
-        # Updated SQL query to include likes and downloads
-        result = con.sql(f"""
-            SELECT
-                datasetId as dataset_id,
-                1 - array_cosine_distance(
-                    embeddings_float::FLOAT[{embedding_dim}],
-                    {query_embedding}::FLOAT[{embedding_dim}]
-                ) as similarity,
-                summary,
-                likes,
-                downloads
-            FROM model_cards
-            ORDER BY similarity DESC
-            LIMIT {k};
-        """).df()
-
-        # Updated result conversion
-        results = [
-            QueryResult(
-                dataset_id=row["dataset_id"],
-                similarity=float(row["similarity"]),
-                summary=row["summary"],
-                likes=int(row["likes"]),
-                downloads=int(row["downloads"]),
+        # Get collection with proper embedding function
+        collection = client.get_collection(
+            name="dataset_cards", embedding_function=get_embedding_function()
+        )
+
+        # Query ChromaDB
+        results = collection.query(
+            query_texts=[f"search_query: {query}"],
+            n_results=k * 4 if sort_by != "similarity" else k,
+            where={
+                "$and": [
+                    {"likes": {"$gte": min_likes}},
+                    {"downloads": {"$gte": min_downloads}},
+                ]
+            }
+            if min_likes > 0 or min_downloads > 0
+            else None,
+        )
+
+        # Process results
+        query_results = []
+        for i in range(len(results["ids"][0])):
+            query_results.append(
+                QueryResult(
+                    dataset_id=results["ids"][0][i],
+                    similarity=float(results["distances"][0][i]),
+                    summary=results["documents"][0][i],
+                    likes=results["metadatas"][0][i]["likes"],
+                    downloads=results["metadatas"][0][i]["downloads"],
+                )
             )
-            for _, row in result.iterrows()
-        ]
 
-        return QueryResponse(results=results)
+        # Sort results if needed
+        if sort_by != "similarity":
+            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
+            query_results = query_results[:k]
+
+        return QueryResponse(results=query_results)
 
     except Exception as e:
         logger.error(f"Search error: {str(e)}")
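With this change deployed, a request against the new /search/datasets signature might look like the sketch below (the base URL is hypothetical; the parameter names and the fields of the returned results come from the updated code above):

import requests

BASE_URL = "http://localhost:8000"  # assumption: app served locally with uvicorn

resp = requests.get(
    f"{BASE_URL}/search/datasets",
    params={
        "query": "multilingual speech recognition",
        "k": 5,
        "sort_by": "downloads",  # one of: similarity, likes, downloads
        "min_likes": 10,
        "min_downloads": 100,
    },
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(hit["dataset_id"], hit["similarity"], hit["likes"], hit["downloads"])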
@@ -176,52 +262,63 @@ async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
 @app.get("/similarity/datasets", response_model=QueryResponse)
 @cache(ttl="10m")
 async def find_similar_datasets(
-    dataset_id: str, k: int = Query(default=5, ge=1, le=100)
+    dataset_id: str,
+    k: int = Query(default=5, ge=1, le=100),
+    sort_by: str = Query(
+        default="similarity", enum=["similarity", "likes", "downloads"]
+    ),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
 ):
     try:
-        # First, get the embedding for the input dataset_id
-        reference_embedding = con.sql(f"""
-            SELECT embeddings_float
-            FROM model_cards
-            WHERE datasetId = '{dataset_id}'
-            LIMIT 1;
-        """).df()
-
-        if reference_embedding.empty:
+        collection = client.get_collection("dataset_cards")
+
+        # Get the reference document
+        results = collection.get(ids=[dataset_id], include=["embeddings"])
+
+        if not results["ids"]:
             raise HTTPException(
                 status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
             )
 
-        # Updated similarity search query to include likes and downloads
-        result = con.sql(f"""
-            SELECT
-                datasetId as dataset_id,
-                1 - array_cosine_distance(
-                    embeddings_float::FLOAT[{embedding_dim}],
-                    (SELECT embeddings_float FROM model_cards WHERE datasetId = '{dataset_id}' LIMIT 1)
-                ) as similarity,
-                summary,
-                likes,
-                downloads
-            FROM model_cards
-            WHERE datasetId != '{dataset_id}'
-            ORDER BY similarity DESC
-            LIMIT {k};
-        """).df()
-
-        # Updated result conversion
-        results = [
-            QueryResult(
-                dataset_id=row["dataset_id"],
-                similarity=float(row["similarity"]),
-                summary=row["summary"],
-                likes=int(row["likes"]),
-                downloads=int(row["downloads"]),
-            )
-            for _, row in result.iterrows()
-        ]
+        # Query using the embedding
+        results = collection.query(
+            query_embeddings=[results["embeddings"][0]],
+            n_results=k * 4
+            if sort_by != "similarity"
+            else k + 1,  # +1 to account for self-match
+            where={
+                "$and": [
+                    {"likes": {"$gte": min_likes}},
+                    {"downloads": {"$gte": min_downloads}},
+                ]
+            }
+            if min_likes > 0 or min_downloads > 0
+            else None,
+        )
 
-        return QueryResponse(results=results)
+        # Process results (excluding the query dataset itself)
+        query_results = []
+        for i in range(len(results["ids"][0])):
+            if results["ids"][0][i] != dataset_id:
+                query_results.append(
+                    QueryResult(
+                        dataset_id=results["ids"][0][i],
+                        similarity=float(results["distances"][0][i]),
+                        summary=results["documents"][0][i],
+                        likes=results["metadatas"][0][i]["likes"],
+                        downloads=results["metadatas"][0][i]["downloads"],
+                    )
+                )
+
+        # Sort results if needed
+        if sort_by != "similarity":
+            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
+            query_results = query_results[:k]
+        else:
+            query_results = query_results[:k]
+
+        return QueryResponse(results=query_results)
 
     except HTTPException:
         raise
 
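The /similarity/datasets endpoint accepts the same filtering and sorting parameters and drops the query dataset itself from the results. A usage sketch under the same assumptions (hypothetical base URL; the dataset_id value is only a placeholder):

import requests

BASE_URL = "http://localhost:8000"  # assumption: same locally running instance

resp = requests.get(
    f"{BASE_URL}/similarity/datasets",
    params={"dataset_id": "example-org/example-dataset", "k": 5, "min_downloads": 50},
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(hit["dataset_id"], hit["summary"])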