Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import asyncio | |
import logging | |
import chromadb | |
import requests | |
import stamina | |
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
from huggingface_hub import InferenceClient | |
from tqdm.auto import tqdm | |
from tqdm.contrib.concurrent import thread_map | |
from prep_viewer_data import prep_data | |
from utils import get_chroma_client | |
# Module logger. Its level is pinned to INFO so progress messages are not
# filtered by this logger; a handler still has to be configured elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# SentenceTransformer model attached to the Chroma collection as its embedding
# function, pinned to an exact commit via EMBEDDING_MODEL_REVISION.
EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"

# Hugging Face Inference Endpoint used to embed documents before upsert.
# NOTE(review): presumably serves the same model as EMBEDDING_MODEL_NAME so
# stored vectors match query-time vectors — confirm.
INFERENCE_MODEL_URL = "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
def initialize_clients():
    """Build the clients needed for a refresh run.

    Returns:
        A ``(chroma_client, inference_client)`` tuple: the Chroma client from
        ``get_chroma_client()`` and an ``InferenceClient`` bound to
        ``INFERENCE_MODEL_URL``.
    """
    logger.info("Initializing clients")
    chroma = get_chroma_client()
    return chroma, InferenceClient(INFERENCE_MODEL_URL)
def create_collection(chroma_client):
    """Fetch the ``dataset-viewer-descriptions`` collection, creating it if absent.

    The collection is configured with a SentenceTransformer embedding function
    (model and revision pinned by the module constants) and cosine distance
    for its HNSW index.

    Args:
        chroma_client: Chroma client on which the collection is created/fetched.

    Returns:
        The collection object returned by ``chroma_client.create_collection``.
    """
    logger.info("Creating or getting collection")
    # trust_remote_code lets the model repo run its own code; the pinned
    # revision restricts that to one known commit.
    model_options = dict(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    embedder = SentenceTransformerEmbeddingFunction(**model_options)
    for message in (
        f"Embedding function: {embedder}",
        f"Embedding model name: {EMBEDDING_MODEL_NAME}",
        f"Embedding model revision: {EMBEDDING_MODEL_REVISION}",
    ):
        logger.info(message)
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedder,
        metadata={"hnsw:space": "cosine"},
    )
def embed_card(text, client, *, max_chars=8192):
    """Embed ``text`` using the client's feature-extraction task.

    Args:
        text: Document text to embed.
        client: An ``InferenceClient`` (anything exposing
            ``feature_extraction(text)`` works).
        max_chars: Maximum number of characters sent to the endpoint; longer
            input is truncated. Defaults to 8192, the previous hard-coded
            limit. Pass ``None`` to send the text untruncated.

    Returns:
        The value returned by ``client.feature_extraction`` (presumably a
        NumPy array — confirm against the huggingface_hub version in use).
    """
    # NOTE(review): this module imports `stamina` and `requests` but never uses
    # them; a retry around this network call may have been intended.
    if max_chars is not None:
        # Character-based cap, presumably to stay within the endpoint's input
        # limit — TODO confirm whether that limit is actually in tokens.
        text = text[:max_chars]
    return client.feature_extraction(text)
def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
    """Embed dataset preview prompts in batches and upsert them into Chroma.

    Each row must provide ``dataset_id`` (used as the Chroma id) and
    ``formatted_prompt`` (prefixed with ``"HUB_DATASET_PREVIEW: "`` before
    embedding). Within a batch, rows are embedded concurrently via
    ``thread_map``. Only ids and the precomputed vectors are upserted; the
    document text itself is not stored.

    Args:
        dataset_rows_and_ids: Rows to embed.
        collection: Target Chroma collection.
        inference_client: Client passed to ``embed_card`` for each document.
        batch_size: Number of rows embedded and upserted per batch.
    """
    total = len(dataset_rows_and_ids)
    logger.info(f"Embedding and upserting {total} datasets for viewer data")

    def _embed(document):
        return embed_card(document, inference_client)

    for start in tqdm(range(0, total, batch_size)):
        rows = dataset_rows_and_ids[start : start + batch_size]
        # Pull id and prompt row by row so a missing key fails on the first
        # offending row.
        pairs = [
            (row["dataset_id"], f"HUB_DATASET_PREVIEW: {row['formatted_prompt']}")
            for row in rows
        ]
        batch_ids = [dataset_id for dataset_id, _ in pairs]
        batch_docs = [document for _, document in pairs]

        vectors = thread_map(_embed, batch_docs, leave=False)
        logger.info(f"Results: {len(vectors)}")
        # Each result is converted with .tolist() and its first row is taken
        # as the vector — assumes a (1, dim)-shaped result; confirm.
        collection.upsert(
            ids=batch_ids,
            embeddings=[vector.tolist()[0] for vector in vectors],
        )
        logger.debug(f"Processed batch {start // batch_size + 1}")
async def refresh_viewer_data(sample_size=200_000, min_likes=2):
    """Rebuild the dataset-viewer embeddings in Chroma.

    Steps:
    1. Create the clients.
    2. Get or create the target collection.
    3. Prepare data with ``prep_data``.
    4. Write the prepared frame to ``viewer_data.parquet`` in the current
       working directory.
    5. Embed every row and upsert the vectors.

    If ``prep_data`` returns ``None``, a warning is logged and nothing is
    written or upserted. Previously the parquet write ran before the
    ``None`` check and crashed with ``AttributeError``.

    Args:
        sample_size: Forwarded to ``prep_data`` (presumably the maximum number
            of datasets to sample — confirm in ``prep_viewer_data``).
        min_likes: Forwarded to ``prep_data`` (presumably a minimum like-count
            filter — confirm).
    """
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")

    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    # Guard BEFORE touching df: writing parquet on None would raise.
    if df is None:
        logger.warning("prep_data returned no data; skipping parquet write and upsert")
        return

    df.write_parquet("viewer_data.parquet")
    logger.info("Data prepared successfully")
    logger.info(f"Data: {df}")
    dataset_rows_and_ids = df.to_dicts()
    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")
if __name__ == "__main__":
    # Configure root logging so INFO-level progress messages are printed.
    logging.basicConfig(level=logging.INFO)
    refresh_job = refresh_viewer_data()
    asyncio.run(refresh_job)