davanstrien (HF staff) committed
Commit b5c8d5a · 1 Parent(s): 7cf16e2

chroma and models

Files changed (1):
  1. main.py +190 -94
main.py CHANGED
@@ -10,26 +10,14 @@ from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
 import polars as pl
-from huggingface_hub import hf_hub_url, DatasetCard, ModelCard, HfApi
-from datetime import datetime, timedelta
-from typing import Generator
-from huggingface_hub import ModelInfo, DatasetInfo
-import stamina
-import logging
-import polars as pl
-from huggingface_hub import dataset_info
-from huggingface_hub import InferenceClient
+from huggingface_hub import HfApi
 from transformers import AutoTokenizer
-import stamina
-from tqdm.contrib.concurrent import thread_map
-from datasets import Dataset, Value, Sequence
-import datasets
-import os
-from dotenv import load_dotenv
-from huggingface_hub import get_inference_endpoint
-from huggingface_hub import AsyncInferenceClient
-import asyncio
-from typing import List
+
+# Configuration constants
+MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
+EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
+BATCH_SIZE = 1000
+CACHE_TTL = "30"

 hf_api = HfApi()

@@ -74,7 +62,7 @@ app.add_middleware(
     allow_origins=[
         "https://*.hf.space",  # Allow all Hugging Face Spaces
         "https://*.huggingface.co",  # Allow all Hugging Face domains
-        # "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
+        "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
     ],
     allow_credentials=True,
     allow_methods=["*"],
@@ -93,12 +81,20 @@ def setup_database():
     try:
         embedding_function = get_embedding_function()

-        # Create collection with embedding function
+        # Create dataset collection
         dataset_collection = client.get_or_create_collection(
             embedding_function=embedding_function,
             name="dataset_cards",
             metadata={"hnsw:space": "cosine"},
         )
+
+        # Create model collection
+        model_collection = client.get_or_create_collection(
+            embedding_function=embedding_function,
+            name="model_cards",
+            metadata={"hnsw:space": "cosine"},
+        )
+
         # TODO incremental updates
         df = pl.scan_parquet(
             "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
@@ -139,42 +135,48 @@ def setup_database():
             logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")

         logger.info(f"Database initialized with {dataset_collection.count():,} rows")
-        # model_collection = client.get_or_create_collection(
-        #     embedding_function=embedding_function,
-        #     name="model_cards",
-        #     metadata={"hnsw:space": "cosine"},
-        # )
-
-        # # If collection is empty, load data from parquet files
-        # if model_collection.count() == 0:
-        #     # Load parquet files and insert into ChromaDB
-        #     df = pl.scan_parquet(
-        #         "hf://datasets/librarian-bots/model_cards_with_metadata/data/train-*.parquet"
-        #     )
-        #     df = df.select(["modelId", "likes", "downloads"])
-        #     df = df.collect()
-        #     df = df.sample(n=1000)  # TODO remove for prod
-        #     # Process in batches of 1000
-        #     BATCH_SIZE = 1000
-        #     total_rows = len(df)
-
-        #     for i in range(0, total_rows, BATCH_SIZE):
-        #         batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
-
-        #         model_collection.add(
-        #             ids=batch_df.select(["modelId"]).to_series().to_list(),
-        #             documents=batch_df.select(["summary"]).to_series().to_list(),
-        #             metadatas=[
-        #                 {"likes": int(likes), "downloads": int(downloads)}
-        #                 for likes, downloads in zip(
-        #                     batch_df.select(["likes"]).to_series().to_list(),
-        #                     batch_df.select(["downloads"]).to_series().to_list(),
-        #                 )
-        #             ],
-        #         )
-        #         logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
-
-        # logger.info(f"Database initialized with {model_collection.count():,} rows")
+
+        # Load model data
+        model_df = pl.scan_parquet(
+            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
+        )
+        model_row_count = model_df.select(pl.len()).collect().item()
+        logger.info(f"Row count of new model data: {model_row_count}")
+
+        if model_collection.count() < model_row_count:
+            model_df = model_df.select(
+                ["modelId", "summary", "likes", "downloads", "last_modified"]
+            )
+            model_df = model_df.collect()
+            BATCH_SIZE = 1000
+            total_rows = len(model_df)
+
+            for i in range(0, total_rows, BATCH_SIZE):
+                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
+
+                model_collection.upsert(
+                    ids=batch_df.select(["modelId"]).to_series().to_list(),
+                    documents=batch_df.select(["summary"]).to_series().to_list(),
+                    metadatas=[
+                        {
+                            "likes": int(likes),
+                            "downloads": int(downloads),
+                            "last_modified": str(last_modified),
+                        }
+                        for likes, downloads, last_modified in zip(
+                            batch_df.select(["likes"]).to_series().to_list(),
+                            batch_df.select(["downloads"]).to_series().to_list(),
+                            batch_df.select(["last_modified"]).to_series().to_list(),
+                        )
+                    ],
+                )
+                logger.info(
+                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
+                )
+
+            logger.info(
+                f"Model database initialized with {model_collection.count():,} rows"
+            )

     except Exception as e:
         logger.error(f"Setup error: {e}")
@@ -196,6 +198,18 @@ class QueryResponse(BaseModel):
     results: List[QueryResult]


+class ModelQueryResult(BaseModel):
+    model_id: str
+    similarity: float
+    summary: str
+    likes: int
+    downloads: int
+
+
+class ModelQueryResponse(BaseModel):
+    results: List[ModelQueryResult]
+
+
 @app.get("/")
 async def redirect_to_docs():
     from fastapi.responses import RedirectResponse
@@ -204,7 +218,7 @@ async def redirect_to_docs():


 @app.get("/search/datasets", response_model=QueryResponse)
-@cache(ttl="10m")
+@cache(ttl=CACHE_TTL)
 async def search_datasets(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
@@ -235,22 +249,7 @@ async def search_datasets(
         )

         # Process results
-        query_results = []
-        for i in range(len(results["ids"][0])):
-            query_results.append(
-                QueryResult(
-                    dataset_id=results["ids"][0][i],
-                    similarity=float(results["distances"][0][i]),
-                    summary=results["documents"][0][i],
-                    likes=results["metadatas"][0][i]["likes"],
-                    downloads=results["metadatas"][0][i]["downloads"],
-                )
-            )
-
-        # Sort results if needed
-        if sort_by != "similarity":
-            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
-            query_results = query_results[:k]
+        query_results = process_search_results(results, "dataset", k, sort_by)

         return QueryResponse(results=query_results)

@@ -260,7 +259,7 @@ async def search_datasets(


 @app.get("/similarity/datasets", response_model=QueryResponse)
-@cache(ttl="10m")
+@cache(ttl=CACHE_TTL)
 async def find_similar_datasets(
     dataset_id: str,
     k: int = Query(default=5, ge=1, le=100),
@@ -298,25 +297,9 @@ async def find_similar_datasets(
         )

         # Process results (excluding the query dataset itself)
-        query_results = []
-        for i in range(len(results["ids"][0])):
-            if results["ids"][0][i] != dataset_id:
-                query_results.append(
-                    QueryResult(
-                        dataset_id=results["ids"][0][i],
-                        similarity=float(results["distances"][0][i]),
-                        summary=results["documents"][0][i],
-                        likes=results["metadatas"][0][i]["likes"],
-                        downloads=results["metadatas"][0][i]["downloads"],
-                    )
-                )
-
-        # Sort results if needed
-        if sort_by != "similarity":
-            query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
-            query_results = query_results[:k]
-        else:
-            query_results = query_results[:k]
+        query_results = process_search_results(
+            results, "dataset", k, sort_by, dataset_id
+        )

         return QueryResponse(results=query_results)

@@ -327,6 +310,119 @@ async def find_similar_datasets(
         raise HTTPException(status_code=500, detail="Similarity search failed")


+@app.get("/search/models", response_model=ModelQueryResponse)
+@cache(ttl=CACHE_TTL)
+async def search_models(
+    query: str,
+    k: int = Query(default=5, ge=1, le=100),
+    sort_by: str = Query(
+        default="similarity", enum=["similarity", "likes", "downloads"]
+    ),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
+):
+    try:
+        collection = client.get_collection(
+            name="model_cards", embedding_function=get_embedding_function()
+        )
+
+        results = collection.query(
+            query_texts=[f"search_query: {query}"],
+            n_results=k * 4 if sort_by != "similarity" else k,
+            where={
+                "$and": [
+                    {"likes": {"$gte": min_likes}},
+                    {"downloads": {"$gte": min_downloads}},
+                ]
+            }
+            if min_likes > 0 or min_downloads > 0
+            else None,
+        )
+
+        query_results = process_search_results(results, "model", k, sort_by)
+
+        return ModelQueryResponse(results=query_results)
+
+    except Exception as e:
+        logger.error(f"Model search error: {str(e)}")
+        raise HTTPException(status_code=500, detail="Model search failed")
+
+
+@app.get("/similarity/models", response_model=ModelQueryResponse)
+@cache(ttl=CACHE_TTL)
+async def find_similar_models(
+    model_id: str,
+    k: int = Query(default=5, ge=1, le=100),
+    sort_by: str = Query(
+        default="similarity", enum=["similarity", "likes", "downloads"]
+    ),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
+):
+    try:
+        collection = client.get_collection("model_cards")
+
+        results = collection.get(ids=[model_id], include=["embeddings"])
+
+        if not results["ids"]:
+            raise HTTPException(
+                status_code=404, detail=f"Model ID '{model_id}' not found"
+            )
+
+        results = collection.query(
+            query_embeddings=[results["embeddings"][0]],
+            n_results=k * 4 if sort_by != "similarity" else k + 1,
+            where={
+                "$and": [
+                    {"likes": {"$gte": min_likes}},
+                    {"downloads": {"$gte": min_downloads}},
+                ]
+            }
+            if min_likes > 0 or min_downloads > 0
+            else None,
+        )
+
+        query_results = process_search_results(results, "model", k, sort_by, model_id)
+
+        return ModelQueryResponse(results=query_results)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Model similarity search error: {str(e)}")
+        raise HTTPException(status_code=500, detail="Model similarity search failed")
+
+
+def process_search_results(results, id_field, k, sort_by, exclude_id=None):
+    """Process search results into a standardized format."""
+    query_results = []
+    for i in range(len(results["ids"][0])):
+        current_id = results["ids"][0][i]
+        if exclude_id and current_id == exclude_id:
+            continue
+
+        result = {
+            f"{id_field}_id": current_id,
+            "similarity": float(results["distances"][0][i]),
+            "summary": results["documents"][0][i],
+            "likes": results["metadatas"][0][i]["likes"],
+            "downloads": results["metadatas"][0][i]["downloads"],
+        }
+
+        if id_field == "dataset":
+            query_results.append(QueryResult(**result))
+        else:
+            query_results.append(ModelQueryResult(**result))
+
+    if sort_by != "similarity":
+        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
+        query_results = query_results[:k]
+    elif exclude_id:  # We fetched extra for similarity + exclude_id case
+        query_results = query_results[:k]

+    return query_results
+
+
 if __name__ == "__main__":
     import uvicorn
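
For context, this commit wires two new endpoints, /search/models and /similarity/models, onto the new model_cards Chroma collection. Below is a minimal client-side sketch of how they might be called once the Space is running; the BASE_URL and the example query values are placeholders (not part of this commit), and the parameters simply mirror the Query defaults in the diff.

# Hypothetical client for the new model endpoints; BASE_URL is a placeholder
# for wherever this Space is served, not something defined in main.py.
import requests

BASE_URL = "https://your-space.hf.space"

# Semantic search over model card summaries, filtered to models with >= 100 likes.
resp = requests.get(
    f"{BASE_URL}/search/models",
    params={"query": "summarise text", "k": 5, "sort_by": "similarity", "min_likes": 100},
)
resp.raise_for_status()
for item in resp.json()["results"]:
    print(item["model_id"], item["similarity"], item["likes"])

# Models nearest to a given model id in the model_cards collection.
resp = requests.get(
    f"{BASE_URL}/similarity/models",
    params={"model_id": "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13", "k": 5},
)
resp.raise_for_status()
print(resp.json()["results"])

Note that the similarity field is populated from Chroma's cosine distances, so smaller values mean closer matches, and any sort_by other than "similarity" re-sorts the over-fetched candidates by likes or downloads before truncating to k.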