"""Webhook server that re-processes and republishes a Hugging Face dataset.

Listens for repository-change webhooks from the Hub, clears the local
datasets cache, streams the source dataset, and pushes the processed
copy to a target repository.
"""

import asyncio
import gc  # NOTE(review): unused — kept in case another chunk of this project relies on it
import logging
import os
import shutil
import time  # NOTE(review): unused since the blocking-sleep fix below
from pathlib import Path

from datasets import Dataset, disable_caching, load_dataset
from fastapi import BackgroundTasks, Response, status
from huggingface_hub import WebhookPayload, WebhooksServer

# Disable caching globally for Hugging Face datasets so each webhook run
# re-reads fresh data instead of a stale cached copy.
disable_caching()

# Module logger with a simple console handler.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(console_handler)

# Configuration
DS_NAME = "amaye15/object-segmentation"  # source dataset repo on the Hub
DATA_DIR = Path("data")  # NOTE(review): currently unused — confirm before removing
TARGET_REPO = "amaye15/object-segmentation-processed"  # destination dataset repo
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")  # may be None if the env var is unset


def clear_huggingface_cache() -> None:
    """Delete the local Hugging Face datasets cache directory, if present."""
    cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
    if cache_dir.exists() and cache_dir.is_dir():
        shutil.rmtree(cache_dir)
        # Use the module logger (not print) so output goes through one channel.
        logger.info("Removed cache directory: %s", cache_dir)
    else:
        logger.info("Cache directory does not exist.")


def get_data():
    """Yield rows of the source dataset's train split.

    Streams the dataset so the entire thing is never loaded into memory
    at once, which matters for large datasets.
    """
    ds = load_dataset(
        DS_NAME,
        streaming=True,
    )
    for row in ds["train"]:
        yield row


def process_and_push_data() -> None:
    """Build a dataset from the streaming generator and push it to the hub.

    Materializes the streamed rows via ``Dataset.from_generator`` and
    uploads the result to ``TARGET_REPO`` in shards of at most 1 GB.
    """
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")
    logger.info("Data processed and pushed to the hub.")


# Webhook server; incoming requests are validated against WEBHOOK_SECRET.
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """Webhook endpoint triggered when the watched dataset repo changes.

    Waits briefly (presumably for the Hub to finish committing — TODO
    confirm why 15 s), clears the local cache, then schedules the
    blocking re-processing as a background task so the webhook can
    respond immediately with 202.
    """
    # BUG FIX: time.sleep(15) blocked the event loop for 15 s, stalling every
    # concurrent request handled by this server; asyncio.sleep yields control
    # back to the loop while waiting.
    await asyncio.sleep(15)
    clear_huggingface_cache()
    logger.info(
        "Webhook received from %s indicating a repo %s",
        payload.repo.name,
        payload.event.action,
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook() -> None:
    """Background task: re-process the dataset and push the result.

    Dataset loading happens inside process_and_push_data (via the
    streaming generator), so nothing is loaded here directly.
    """
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)