File size: 3,751 Bytes
81aaa4e 442f97c 81aaa4e 442f97c 81aaa4e 442f97c 1779f92 81aaa4e 442f97c 1779f92 81aaa4e 442f97c 1779f92 442f97c 81aaa4e 783de92 81aaa4e 036b5da 81aaa4e 783de92 81aaa4e d911a9d 81aaa4e 442f97c 1779f92 442f97c 1779f92 442f97c f47c911 81aaa4e 1779f92 81aaa4e f47c911 1779f92 f47c911 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import asyncio
import os
import time
from typing import Optional
from uuid import uuid4
from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import create_discussion, comment_discussion
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
from models import WebhookPayload
# Loading of a dedicated webhook shared secret is currently disabled.
# WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")

# Hugging Face token passed to create_discussion / comment_discussion;
# None when the environment variable is not set.
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")

app = FastAPI()

# TODO: use task management queue
# In-memory task registry keyed by task id. Each entry is a dict that starts
# as {'status': 'running'} and later gains 'url' and 'finish_time' once the
# upload completes. Contents are lost when the process restarts.
tasks = {}

# Jinja2 templates are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")
def upload_atlas_task(task_id: str,
                      dataset_name: str,
                      atlas_api_token: str,
                      webhook_payload: Optional[WebhookPayload] = None,
                      webhook_notify: bool = False):
    """Build an Atlas map for a dataset and record the outcome in ``tasks``.

    On success the task entry receives ``'url'``, ``'finish_time'`` and a
    ``'status'`` of ``'done'``. On failure it receives ``'error'`` (the
    exception message), ``'finish_time'`` and a ``'status'`` of ``'error'``,
    and the exception is re-raised so the caller still sees it.

    Args:
        task_id: Key of an existing entry in ``tasks`` to update.
        dataset_name: Dataset identifier handed to ``load_dataset_and_metadata``.
        atlas_api_token: Token handed to ``upload_dataset_to_atlas``.
        webhook_payload: Payload of the triggering webhook; its ``repo.name``
            is used as the target repo when ``webhook_notify`` is true.
        webhook_notify: When true, open an "Atlas Maps" discussion on the
            dataset repo and post the map URL as a comment.

    Raises:
        Exception: Whatever the dataset load, the Atlas upload or the
            Hugging Face discussion calls raise.
    """
    try:
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, atlas_api_token)
    except Exception as exc:
        # Without this the entry would stay 'running' forever and a client
        # polling the status endpoint would never learn that the job failed.
        failed_entry = tasks[task_id]
        failed_entry['error'] = str(exc)
        failed_entry['finish_time'] = time.time()
        failed_entry['status'] = 'error'
        raise

    entry = tasks[task_id]
    entry['url'] = map_url
    entry['finish_time'] = time.time()
    # Publish the status last so a reader that observes 'done' also finds
    # the URL already in place.
    entry['status'] = 'done'

    if webhook_notify:
        discussion = create_discussion(
            repo_id=webhook_payload.repo.name,
            title="Atlas Maps",
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset",
        )
        comment_discussion(
            repo_id=webhook_payload.repo.name,
            discussion_num=discussion.num,
            comment=f"Atlas Map: {map_url}",
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset",
        )
@app.on_event("startup")
async def startup_event():
    """Launch the periodic task-registry cleanup loop when the app starts."""
    # The event loop holds only weak references to tasks, so a task whose
    # handle is discarded may be garbage-collected before it finishes.
    # Keeping the handle on the application state pins it for the app's
    # lifetime.
    app.state.cleanup_task = asyncio.create_task(cleanup_tasks())
async def cleanup_tasks():
    """Periodically drop completed entries from the ``tasks`` registry.

    Runs forever: each pass removes every entry whose status is ``'done'``
    and whose ``'finish_time'`` lies more than 1800 seconds in the past,
    then sleeps for 1800 seconds before the next pass.
    """
    while True:
        now = time.time()
        # An entry without 'finish_time' is treated as having age zero,
        # so it is never selected for removal.
        expired_ids = [
            key
            for key, entry in tasks.items()
            if entry['status'] == 'done'
            and now - entry.get('finish_time', now) > 1800  # 30 minutes
        ]
        for key in expired_ids:
            del tasks[key]
        await asyncio.sleep(1800)  # Wait for 30 minutes
@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Serve the submission form rendered from the ``form.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("form.html", context)
@app.post("/submit_form")
async def form_post(
    background_tasks: BackgroundTasks,
    dataset_name: str = Form(...),
    atlas_api_token: str = Form(...),
):
    """Register a new upload task from the form and schedule it.

    A fresh task id is stored in ``tasks`` with status ``'running'``, the
    upload is queued on ``background_tasks``, and the id is returned so the
    caller can poll for the result.
    """
    new_task_id = str(uuid4())
    tasks[new_task_id] = {'status': 'running'}
    background_tasks.add_task(
        upload_atlas_task,
        new_task_id,
        dataset_name,
        atlas_api_token,
    )
    return {'task_id': new_task_id}
@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the registry entry for ``task_id``.

    Unknown ids yield ``{'status': 'not found'}`` instead of an HTTP error.
    """
    try:
        return tasks[task_id]
    except KeyError:
        return {'status': 'not found'}
@app.post("/webhook")
async def post_webhook(
    background_tasks: BackgroundTasks,
    payload: WebhookPayload,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle a webhook call and schedule an upload for dataset content updates.

    Requests without the ``X-Webhook-Secret`` header are rejected with 401.
    The header value is not compared against a stored secret; it is passed
    on to ``upload_atlas_task`` in the ``atlas_api_token`` position.

    Events other than an ``"update"`` whose scope starts with
    ``"repo.content"`` on a ``"dataset"`` repo are acknowledged with
    ``{"processed": False}`` and otherwise ignored.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)

    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    new_task_id = str(uuid4())
    tasks[new_task_id] = {'status': 'running'}
    background_tasks.add_task(
        upload_atlas_task,
        new_task_id,
        payload.repo.name,
        x_webhook_secret,
        payload,
        True,
    )
    return {'task_id': new_task_id}
|