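"""Webhook server for reprocessing a Hugging Face dataset.

Listens for repository-change events from the Hub and, for each event,
streams the source dataset (amaye15/object-segmentation) and pushes a
rebuilt copy to the target repository (amaye15/object-segmentation-processed).
"""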
import os
import shutil
import asyncio
import logging
from pathlib import Path
from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status


def clear_huggingface_cache():
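    """
    Delete the local Hugging Face datasets cache.

    Ensures the next load_dataset call fetches fresh data from the Hub
    instead of reusing a stale local copy.
    """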
    # Path to the Hugging Face cache directory
    cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"

    # Remove the entire datasets directory
    if cache_dir.exists() and cache_dir.is_dir():
        shutil.rmtree(cache_dir)
        print(f"Removed cache directory: {cache_dir}")
    else:
        print("Cache directory does not exist.")


# Disable caching globally for Hugging Face datasets
disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration (only the webhook secret is read from the environment)
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = Path("data")  # Currently unused in this script
TARGET_REPO = "amaye15/object-segmentation-processed"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET")


def get_data():
    """
    Generator function to stream data from the dataset.

    Uses streaming to avoid loading the entire dataset into memory at once,
    which is useful for handling large datasets.
    """
    ds = load_dataset(
        DS_NAME,
        streaming=True,
    )
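    # Iterate over the streaming "train" split, yielding one row at a time.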
    for row in ds["train"]:
        yield row


def process_and_push_data():
    """
    Function to process and push new data to the target repository.

    Removes existing data directory if it exists, recreates it, processes
    the dataset, and pushes the processed dataset to the hub.
    """

    # Process data using the generator and push it to the hub
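    # Dataset.from_generator consumes the streaming generator row by row;
    # max_shard_size keeps each pushed shard at roughly 1GB.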
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")

    logger.info("Data processed and pushed to the hub.")


# Initialize the WebhooksServer (served through a Gradio app) with the shared webhook secret
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)
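# A webhook must be configured on the Hub to POST repo events to this server's
# "/dataset_repo" endpoint, using the same secret as HF_WEBHOOK_SECRET.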


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.

    Adds a task to the background task queue to process the dataset
    asynchronously.
    """
    await asyncio.sleep(15)
    clear_huggingface_cache()
    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    """
    Private function to handle the processing of the dataset when a webhook
    is triggered.

    Loads the dataset, processes it, and pushes the processed data to the hub.
    """
    logger.info("Loading new dataset...")
    # Dataset loading is handled inside process_and_push_data, no need to load here
    logger.info("Loaded new dataset")

    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)