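"""Webhooks server that listens for changes to a dataset repo on the Hugging Face
Hub, then re-streams the source dataset and pushes the result to a target repo."""
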
import os
import gc
import shutil
import logging

from huggingface_hub import WebhooksServer, WebhookPayload
from huggingface_hub.utils import build_hf_headers, get_session
from datasets import Dataset, load_dataset, disable_caching
from fastapi import BackgroundTasks, Response, status
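
# Disable datasets caching so each run re-streams the latest data from the Hub
# instead of reusing stale local cache files.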
disable_caching()

# Set up the logger
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Set up the console handler with a simple format
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Configuration constants
DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
TARGET_REPO = "amaye15/tmp"
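# The secret must match the one configured for this webhook in the Hub settings;
# it can be overridden through the HF_WEBHOOK_SECRET environment variable.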
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET", "my_secret_key")


def get_data():
    """
    Generator function to stream data from the dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row
        gc.collect()
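
# Dataset.from_generator() consumes the get_data() stream row by row, writing the
# processed copy incrementally rather than materializing it in memory first.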
def process_and_push_data():
    """
    Function to process and push new data to the target repository.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)
    logger.info("Data processed and pushed to the hub.")
    gc.collect()
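
# Note: WebhooksServer serves registered routes under the /webhooks prefix
# (POST /webhooks/dataset_repo here) and checks the secret before dispatching.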

# Initialize the WebhooksServer (it also serves a small Gradio landing page)
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
"""
Webhook endpoint that triggers data processing when the dataset is updated.
"""
if not payload.event.scope.startswith("repo"):
return Response("No task scheduled", status_code=status.HTTP_200_OK)
# Only run if change is on main branch
try:
if payload.updatedRefs[0].ref != "refs/heads/main":
response_content = "No task scheduled: Change not on main branch"
logger.info(response_content)
return Response(response_content, status_code=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error checking branch: {str(e)}")
return Response("No task scheduled", status_code=status.HTTP_200_OK)

    # No need to run for README-only updates
    try:
        commit_files_url = f"{payload.repo.url.api}/compare/{payload.updatedRefs[0].oldSha}..{payload.updatedRefs[0].newSha}?raw=true"
        response_text = (
            get_session().get(commit_files_url, headers=build_hf_headers()).text
        )
        logger.info(f"Git Compare URL: {commit_files_url}")

        # Split the raw diff output into lines
        file_lines = response_text.split("\n")

        # Each non-empty line ends with the path of a changed file
        changed_files = [line.split("\t")[-1] for line in file_lines if line.strip()]
        logger.info(f"Changed files: {changed_files}")

        # Check whether only README.md has been changed
        if all("README.md" in file for file in changed_files):
            response_content = "No task scheduled: It's a README-only update."
            logger.info(response_content)
            return Response(response_content, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error checking files: {str(e)}")
        return Response("Unexpected issue", status_code=status.HTTP_501_NOT_IMPLEMENTED)

    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)
def _process_webhook():
    logger.info("Loading new dataset...")
    # dataset = load_dataset(DS_NAME)
    logger.info("Loaded new dataset")
    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)