# amaye15
# define size of push
# e8afc94
# raw
# history blame contribute delete
# No virus
# 3.5 kB
import asyncio
import gc
import logging
import os
import shutil
import time
from pathlib import Path

from datasets import Dataset, disable_caching, load_dataset
from fastapi import BackgroundTasks, Response, status
from huggingface_hub import WebhookPayload, WebhooksServer
def _default_datasets_cache_dir():
    """Return the default ``datasets`` cache directory, honoring HF env vars."""
    # Empty values are treated as unset: Path("") would resolve to the current
    # working directory, and rmtree-ing that would be catastrophic.
    explicit = os.getenv("HF_DATASETS_CACHE")
    if explicit:
        return Path(explicit).expanduser()
    hf_home = os.getenv("HF_HOME")
    if not hf_home:
        xdg_cache = os.getenv("XDG_CACHE_HOME") or "~/.cache"
        hf_home = os.path.join(xdg_cache, "huggingface")
    return Path(hf_home).expanduser() / "datasets"


def clear_huggingface_cache(cache_dir=None):
    """Delete the local Hugging Face ``datasets`` cache directory.

    By default the directory is resolved from the environment:
    ``$HF_DATASETS_CACHE`` if set, otherwise ``$HF_HOME/datasets``, otherwise
    ``$XDG_CACHE_HOME/huggingface/datasets``, falling back to
    ``~/.cache/huggingface/datasets`` when none of them is set.

    Args:
        cache_dir: Optional explicit directory to remove instead of the
            resolved default.

    Returns:
        True if a directory was removed, False if there was no directory.
    """
    target = _default_datasets_cache_dir() if cache_dir is None else Path(cache_dir).expanduser()

    # is_dir() is False for missing paths, so no separate exists() check.
    if target.is_dir():
        shutil.rmtree(target)
        print(f"Removed cache directory: {target}")
        return True
    print("Cache directory does not exist.")
    return False
# Disable caching globally for Hugging Face datasets.
disable_caching()

# Module logger: INFO and above, written to the console with a timestamped
# format.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
# getLogger returns the same object every time, so if this module is executed
# again in the same interpreter (e.g. a reload), attaching unconditionally
# would add a second handler and print every record twice.
if not logger.handlers:
    logger.addHandler(console_handler)
# Hub repositories (hard-coded): the source dataset that is streamed, and the
# repo the processed copy is pushed to.
DS_NAME: str = "amaye15/object-segmentation"
TARGET_REPO: str = "amaye15/object-segmentation-processed"

# Local working directory; not referenced elsewhere in this module as shown.
DATA_DIR: Path = Path("data")

# The only value read from the environment. It is None when HF_WEBHOOK_SECRET
# is unset, and it is passed to WebhooksServer as ``webhook_secret``.
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
def get_data(split="train"):
    """Yield the rows of one split of ``DS_NAME``, one at a time.

    The dataset is opened with ``streaming=True``, so rows are fetched lazily
    instead of loading the whole dataset into memory first. This is the
    generator handed to ``Dataset.from_generator``.

    Args:
        split: Name of the split to stream. Defaults to ``"train"``.

    Yields:
        Each example of the split, exactly as produced by the ``datasets``
        streaming iterator.
    """
    ds = load_dataset(
        DS_NAME,
        streaming=True,
    )
    yield from ds[split]
def process_and_push_data(target_repo=None, max_shard_size="1GB"):
    """Build a ``Dataset`` from the streamed source rows and push it to the Hub.

    Rows produced by ``get_data`` are collected with ``Dataset.from_generator``.
    The result is then uploaded with ``push_to_hub``. This function does not
    create, remove, or use ``DATA_DIR``.

    Args:
        target_repo: Hub repo id to push to. ``None`` means ``TARGET_REPO``.
        max_shard_size: Maximum size of each uploaded shard, forwarded to
            ``push_to_hub``. Defaults to ``"1GB"``.
    """
    repo_id = TARGET_REPO if target_repo is None else target_repo
    # NOTE(review): Dataset.from_generator may reuse a previously cached build
    # for the same generator. The webhook path clears the datasets cache before
    # this runs, presumably for that reason. Confirm before removing that step.
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(repo_id, max_shard_size=max_shard_size)
    logger.info("Data processed and pushed to the hub.")
# Webhook server. No Gradio UI is passed, so the `ui` argument stays at its
# default. The secret is None when HF_WEBHOOK_SECRET is unset.
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """Schedule reprocessing of the dataset when a repository webhook arrives.

    Waits 15 seconds and logs the event. It then queues two background jobs,
    which run in order after the 202 response is sent: clearing the local
    datasets cache, then ``_process_webhook``.
    """
    # NOTE(review): the 15 s delay presumably gives the Hub time to settle
    # after the update before the dataset is re-read. Confirm. It must be an
    # awaited sleep: time.sleep() here would freeze the event loop, and every
    # other request, for the full 15 s.
    await asyncio.sleep(15)
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    # Blocking filesystem/network work is kept off the event loop. Starlette
    # runs a request's background tasks sequentially, and plain functions run
    # in its thread pool. So the cache is always cleared before processing.
    # NOTE(review): overlapping webhooks can still clear the cache while an
    # earlier request's processing is running.
    task_queue.add_task(clear_huggingface_cache)
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)
def _process_webhook():
    """Background job run after a webhook: regenerate and push the dataset.

    Delegates all loading, processing, and pushing to ``process_and_push_data``
    and logs start and finish. A failure is logged with its traceback through
    this module's logger and then re-raised, so it is not swallowed.
    """
    logger.info("Processing and updating dataset...")
    try:
        process_and_push_data()
    except Exception:
        logger.exception("Processing and updating dataset failed!")
        raise
    logger.info("Processing and updating dataset completed!")
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, with Gradio's show_error enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "show_error": True,
        "server_port": 7860,
    }
    app.launch(**launch_options)