import os
import asyncio
import shutil
import logging
from pathlib import Path

from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status


def clear_huggingface_cache():
    """Delete the local Hugging Face datasets cache so the next load pulls fresh data."""
    cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
    if cache_dir.exists() and cache_dir.is_dir():
        shutil.rmtree(cache_dir)
        print(f"Removed cache directory: {cache_dir}")
    else:
        print("Cache directory does not exist.")
disable_caching()
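
# Basic console logger shared by the webhook handler and the processing helpers.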
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
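
# Configuration: the source dataset, the target repo for the processed copy, and
# the webhook secret (read from the HF_WEBHOOK_SECRET environment variable) used
# by WebhooksServer to authenticate incoming requests.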
DS_NAME = "amaye15/object-segmentation"
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")


def get_data():
    """
    Generator function to stream data from the dataset.

    Uses streaming to avoid loading the entire dataset into memory at once,
    which is useful for handling large datasets.
    """
    ds = load_dataset(
        DS_NAME,
        streaming=True,
    )
    for row in ds["train"]:
        yield row


def process_and_push_data():
    """
    Process the source dataset and push it to the target repository.

    Builds a new dataset from the streaming generator and pushes it to the
    hub in shards of at most 1 GB.
    """
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")

    logger.info("Data processed and pushed to the hub.")
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.

    Waits briefly for the commit to settle, clears the local cache, then adds
    a task to the background task queue so the dataset is processed
    asynchronously and the webhook can return immediately.
    """
    # Give the Hub a moment to finish the commit; asyncio.sleep keeps the
    # event loop free for other requests while waiting.
    await asyncio.sleep(15)
    clear_huggingface_cache()
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    """
    Background task that reprocesses the dataset when a webhook is triggered.

    Streams the updated dataset, processes it, and pushes the processed data
    to the hub.
    """
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)
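
# To trigger this endpoint, create a webhook in your Hub settings
# (huggingface.co/settings/webhooks) that watches the amaye15/object-segmentation
# dataset, points at <this server's URL>/webhooks/dataset_repo, and uses the
# same secret as HF_WEBHOOK_SECRET.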