import os
import time
import shutil
import logging
from pathlib import Path

from fastapi import BackgroundTasks, Response, status
from datasets import Dataset, load_dataset, disable_caching
from huggingface_hub import WebhooksServer, WebhookPayload

def clear_huggingface_cache():
    # Path to the Hugging Face datasets cache directory
    cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"

    # Remove the entire datasets cache directory
    if cache_dir.exists() and cache_dir.is_dir():
        shutil.rmtree(cache_dir)
        print(f"Removed cache directory: {cache_dir}")
    else:
        print("Cache directory does not exist.")

# Disable caching globally for Hugging Face datasets
disable_caching()
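# (Together with clear_huggingface_cache above, this is presumably meant to
# ensure each webhook-triggered run re-reads fresh data from the Hub rather
# than reusing a stale cached copy.)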
# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)
# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration (only WEBHOOK_SECRET is read from the environment)
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = Path("data")  # Use pathlib for path handling
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")
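# Note: this secret must match the one configured for the webhook in the
# repo settings on the Hub; WebhooksServer rejects requests whose secret
# header is missing or wrong.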

def get_data():
    """
    Generator function to stream data from the source dataset.

    Uses streaming to avoid loading the entire dataset into memory at once,
    which is useful for handling large datasets.
    """
    ds = load_dataset(
        DS_NAME,
        streaming=True,
    )
    for row in ds["train"]:
        yield row

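# A quick local sanity check (hypothetical, not part of the webhook flow):
# pull a few rows from the stream without materializing the whole dataset.
#
#   for i, row in enumerate(get_data()):
#       print(row.keys())
#       if i == 2:
#           break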

def process_and_push_data():
    """
    Process new data and push it to the target repository.

    Streams the source dataset through `get_data`, materializes it as a
    `Dataset`, and pushes the result to the hub.
    """
    # Build the dataset from the streaming generator and push it to the hub
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")
    logger.info("Data processed and pushed to the hub.")


# Initialize the WebhooksServer (with no custom UI it serves a default
# Gradio landing page)
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)

@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
payload: WebhookPayload, task_queue: BackgroundTasks
):
"""
Webhook endpoint that triggers data processing when the dataset is updated.
Adds a task to the background task queue to process the dataset
asynchronously.
"""
time.sleep(15)
clear_huggingface_cache()
logger.info(
f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
)
task_queue.add_task(_process_webhook)
return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)

def _process_webhook():
    """
    Private function that processes the dataset when a webhook is triggered.

    Dataset loading is handled inside `process_and_push_data`, which streams
    the data, processes it, and pushes the result to the hub.
    """
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")

if __name__ == "__main__":
app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)