import logging
import os
import shutil
from pathlib import Path

from fastapi import BackgroundTasks, Response, status

from datasets import Dataset, disable_caching, load_dataset
from huggingface_hub import WebhookPayload, WebhooksServer
from huggingface_hub.utils import build_hf_headers, get_session

disable_caching()

logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
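
# Source dataset, local staging directory, and target repo for the processed copy.
# The webhook secret must match the one configured for this webhook on the Hub;
# WebhooksServer rejects requests whose X-Webhook-Secret header does not match it.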

DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"
TARGET_REPO = "amaye15/tmp"
WEBHOOK_SECRET = os.getenv("HF_WEBHOOK_SECRET", "my_secret_key")


def get_data():
    """
    Generator function to stream data from the dataset.
    """
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row


def process_and_push_data():
    """
    Function to process and push new data to the target repository.
    """
    p = os.path.join(os.getcwd(), DATA_DIR)

    # Start from an empty local staging directory so stale files are not reused
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub(TARGET_REPO)
    logger.info("Data processed and pushed to the hub.")
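
# Note: `push_to_hub` creates TARGET_REPO on the Hub if it does not exist yet and
# needs write access to it (for example via an HF_TOKEN available in the environment).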


app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)
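

# Routes registered with `add_webhook` are served under the `/webhooks` prefix,
# so this handler is reachable at <server URL>/webhooks/dataset_repo.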
@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """
    Webhook endpoint that triggers data processing when the dataset is updated.
    """
    if not payload.event.scope.startswith("repo"):
        return Response("No task scheduled", status_code=status.HTTP_200_OK)

    # Only react to pushes on the main branch
    try:
        if payload.updatedRefs[0].ref != "refs/heads/main":
            response_content = "No task scheduled: Change not on main branch"
            logger.info(response_content)
            return Response(response_content, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error checking branch: {str(e)}")
        return Response("No task scheduled", status_code=status.HTTP_200_OK)
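
    # Compare the old and new commit SHAs to list the files touched by this push,
    # so that commits that only modify the README can be ignored.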
    try:
        commit_files_url = f"{payload.repo.url.api}/compare/{payload.updatedRefs[0].oldSha}..{payload.updatedRefs[0].newSha}?raw=true"
        response_text = (
            get_session().get(commit_files_url, headers=build_hf_headers()).text
        )
        logger.info(f"Git Compare URL: {commit_files_url}")

        # Each non-empty line describes one changed file; the path is the last tab-separated field
        file_lines = response_text.split("\n")
        changed_files = [line.split("\t")[-1] for line in file_lines if line.strip()]
        logger.info(f"Changed files: {changed_files}")

        # Skip the update if every changed file is the README
        if all("README.md" in file for file in changed_files):
            response_content = "No task scheduled: it's a README-only update."
            logger.info(response_content)
            return Response(response_content, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error checking files: {str(e)}")
        return Response("Unexpected issue", status_code=status.HTTP_501_NOT_IMPLEMENTED)

    logger.info(
        f"Webhook received from {payload.repo.name} indicating a repo {payload.event.action}"
    )
    task_queue.add_task(_process_webhook)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)


def _process_webhook():
    logger.info("Loading new dataset...")
    dataset = load_dataset(DS_NAME)
    logger.info("Loaded new dataset")

    logger.info("Processing and updating dataset...")
    process_and_push_data()
    logger.info("Processing and updating dataset completed!")


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", show_error=True, server_port=7860)
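
# Note: listening on 0.0.0.0:7860 matches the default port expected by Hugging Face
# Spaces, so the same script can run as a Space or locally without changes.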