import os
import gc
import shutil
import logging
from pathlib import Path
from huggingface_hub import WebhooksServer, WebhookPayload
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status
from huggingface_hub.utils import build_hf_headers, get_session

disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration (only the webhook secret is read from the environment)
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
TARGET_REPO = "amaye15/tmp"
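# The same secret must be configured on the webhook in the Hub settings; the Hub
# sends it back in the X-Webhook-Secret header and WebhooksServer verifies it.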
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET", "my_secret_key")


def get_data():
    """
    Generator function to stream data from the dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
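    # Nudge the garbage collector once streaming finishes so memory is released promptly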
    gc.collect()


def process_and_push_data():
    """
    Function to process and push new data to the target repository.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)

    if os.path.exists(p):
        shutil.rmtree(p)

    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)
    logger.info("Data processed and pushed to the hub.")
    gc.collect()


# Initialize the WebhooksServer (it runs on top of a Gradio app)
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.
    """
    if not payload.event.scope.startswith("repo"):
        return Response("No task scheduled", status_code=status.HTTP_200_OK)

    # Only run if change is on main branch
    try:
        if payload.updatedRefs[0].ref != "refs/heads/main":
            response_content = "No task scheduled: Change not on main branch"
            logger.info(response_content)
            return Response(response_content, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error checking branch: {str(e)}")
        return Response("No task scheduled", status_code=status.HTTP_200_OK)

    # No need to run for README-only updates
    try:
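        # The raw compare endpoint is assumed to return a git-style name-status listing:
        # one changed file per line, tab-separated, with the file path in the last column.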
        commit_files_url = f"{payload.repo.url.api}/compare/{payload.updatedRefs[0].oldSha}..{payload.updatedRefs[0].newSha}?raw=true"
        response_text = (
            get_session().get(commit_files_url, headers=build_hf_headers()).text
        )
        logger.info(f"Git Compare URL: {commit_files_url}")

        # Splitting the output into lines
        file_lines = response_text.split("\n")

        # Filtering the lines to find file changes
        changed_files = [line.split("\t")[-1] for line in file_lines if line.strip()]
        logger.info(f"Changed files: {changed_files}")

        # Skip processing when the only file that changed is README.md
        if changed_files and all("README.md" in file for file in changed_files):
            response_content = "No task scheduled: it's a README-only update."
            logger.info(response_content)
            return Response(response_content, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error checking files: {str(e)}")
        return Response("Unexpected issue", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)
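
# A rough local test (a sketch; assumes the default secret above and a hypothetical
# sample_payload.json containing a Hub webhook payload):
#
#   curl -X POST http://localhost:7860/webhooks/dataset_repo \
#        -H "Content-Type: application/json" \
#        -H "X-Webhook-Secret: my_secret_key" \
#        -d @sample_payload.json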