# huggingface-datasets-search-v2 / load_viewer_data.py

import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data
from utils import get_chroma_client

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
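
# Embedding model pinned to a specific revision, plus the dedicated
# Hugging Face Inference Endpoint that serves it.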
EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)
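

# Set up the ChromaDB client (via the shared helper in utils) and an
# InferenceClient pointed at the embedding endpoint.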
def initialize_clients():
    logger.info("Initializing clients")
    chroma_client = get_chroma_client()
    inference_client = InferenceClient(
        INFERENCE_MODEL_URL,
    )
    return chroma_client, inference_client
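

# Get or create the Chroma collection, configured to embed queries locally
# with the pinned SentenceTransformer model and to use cosine distance for
# its HNSW index.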
def create_collection(chroma_client):
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    logger.info(f"Embedding function: {embedding_function}")
    logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
    logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )
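

# Embed a single document via the remote endpoint. stamina retries up to
# three times on HTTP errors, starting with a 10-second wait.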
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    # Rough character-based truncation, presumably to keep the input within
    # the model's context window.
    text = text[:8192]
    return client.feature_extraction(text)
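

# Embed documents in batches: each batch is embedded concurrently with
# thread_map, then upserted into the collection keyed by dataset_id.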
def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
    logger.info(
        f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
    )
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = []
        documents = []
        for item in batch:
            ids.append(item["dataset_id"])
            documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
        # Embed the batch concurrently; each call hits the inference endpoint.
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
        logger.info(f"Results: {len(results)}")
        collection.upsert(
            ids=ids,
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug(f"Processed batch {i // batch_size + 1}")
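

# End-to-end refresh: prepare the viewer data (async), snapshot it to
# parquet, then embed and upsert everything into Chroma.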
async def refresh_viewer_data(sample_size=200_000, min_likes=2):
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    if df is not None:
        # Only snapshot and upsert once prep_data has actually returned a
        # frame; writing the parquet before this check would raise on None.
        df.write_parquet("viewer_data.parquet")
        logger.info("Data prepared successfully")
        logger.info(f"Data: {df}")
        dataset_rows_and_ids = df.to_dicts()
        logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
        embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
        logger.info("Refresh completed successfully")
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())