|
import asyncio |
|
import os |
|
import time |
|
from typing import Optional |
|
from uuid import uuid4 |
|
|
|
from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.templating import Jinja2Templates |
|
from huggingface_hub import create_discussion, comment_discussion |
|
|
|
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas |
|
from models import WebhookPayload |
|
|
|
|
|
# Hub token used when posting discussions/comments; None if the env var is unset.
HUGGINGFACE_ACCESS_TOKEN = os.getenv("HUGGINGFACE_ACCESS_TOKEN")

app = FastAPI()

# In-memory job registry: task_id -> dict with a 'status' key (plus result
# fields written by upload_atlas_task). It is a plain module-level dict, so it
# is per-process and lost on restart.
tasks = dict()

templates = Jinja2Templates(directory="templates")
|
|
|
def upload_atlas_task(task_id: str,
                      dataset_name: str,
                      atlas_api_token: str,
                      webhook_payload: Optional[WebhookPayload] = None,
                      webhook_notify: bool = False):
    """Build an Atlas map for a dataset and record the outcome in ``tasks``.

    Scheduled as a FastAPI background task by the form and webhook endpoints.
    Progress is reported through the module-level ``tasks`` registry under
    ``task_id``, which ``/status/{task_id}`` serves.

    Args:
        task_id: Key of an entry already inserted into ``tasks`` by the caller.
        dataset_name: Dataset identifier passed to ``load_dataset_and_metadata``.
        atlas_api_token: Token passed to ``upload_dataset_to_atlas``.
        webhook_payload: The triggering webhook; its ``repo.name`` is where the
            discussion is opened. Required when ``webhook_notify`` is true.
        webhook_notify: If true, open a discussion on the dataset repo and post
            the map URL as a comment.

    Raises:
        ValueError: If ``webhook_notify`` is set without a ``webhook_payload``.
        Exception: Any failure while loading/uploading is recorded in ``tasks``
            (status ``'error'``) and then re-raised so the server logs it.
    """
    try:
        if webhook_notify and webhook_payload is None:
            raise ValueError("webhook_notify=True requires a webhook_payload")
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, atlas_api_token)
    except Exception as exc:
        # Without this the entry would stay 'running' forever: pollers would
        # never see a terminal state and the cleanup sweep (which keys on
        # 'finish_time') would never reclaim it.
        tasks[task_id] = {
            'status': 'error',
            'error': f"{type(exc).__name__}: {exc}",
            'finish_time': time.time(),
        }
        raise

    # Publish all result fields in one assignment: this function may run in a
    # worker thread, and a concurrent /status read must never observe
    # status 'done' without a 'url'.
    tasks[task_id] = {
        'status': 'done',
        'url': map_url,
        'finish_time': time.time(),
    }

    if webhook_notify:
        repo_id = webhook_payload.repo.name
        discussion = create_discussion(
            repo_id=repo_id,
            title="Atlas Maps",
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )
        comment_discussion(
            repo_id=repo_id,
            discussion_num=discussion.num,
            comment="Atlas Map: " + map_url,
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Start the periodic sweep that evicts finished entries from ``tasks``.

    The event loop holds only weak references to tasks created with
    ``asyncio.create_task``. The handle is therefore kept on ``app.state``, so
    the never-ending cleanup coroutine cannot be garbage-collected mid-run.
    """
    app.state.cleanup_task = asyncio.create_task(cleanup_tasks())
|
|
|
async def cleanup_tasks(ttl_seconds: float = 1800, interval_seconds: float = 1800):
    """Periodically evict finished entries from the ``tasks`` registry.

    An entry is evicted once its ``finish_time`` is more than ``ttl_seconds``
    in the past, whatever its status. Failed jobs are therefore reclaimed as
    well as successful ones. Entries without ``finish_time`` (still running)
    are kept. The sweep runs every ``interval_seconds``, so an entry may
    survive up to roughly ``ttl_seconds + interval_seconds``. Never returns.

    Args:
        ttl_seconds: Age after which a finished entry is removed.
        interval_seconds: Delay between sweeps.
    """
    while True:
        now = time.time()
        expired = [
            task_id
            for task_id, task in tasks.items()
            if 'finish_time' in task and now - task['finish_time'] > ttl_seconds
        ]
        for task_id in expired:
            tasks.pop(task_id, None)
        await asyncio.sleep(interval_seconds)
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Render the HTML submission form from the ``form.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("form.html", context)
|
|
|
@app.post("/submit_form")
async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(...), atlas_api_token: str = Form(...)):
    """Queue an Atlas upload for a dataset submitted through the HTML form.

    Registers a new ``'running'`` entry in ``tasks`` and returns its id; the
    client can poll ``/status/{task_id}`` for the result.
    """
    new_id = f"{uuid4()}"
    tasks[new_id] = dict(status='running')
    background_tasks.add_task(upload_atlas_task, new_id, dataset_name, atlas_api_token)
    return dict(task_id=new_id)
|
|
|
@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the registry entry for ``task_id``, or ``{'status': 'not found'}``."""
    return tasks.get(task_id, {'status': 'not found'})
|
|
|
@app.post("/webhook")
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload, x_webhook_secret: Optional[str] = Header(default=None)):
    """Rebuild a dataset's Atlas map when a webhook reports a content update.

    NOTE: the ``X-Webhook-Secret`` header is not compared against any stored
    secret. It is forwarded as the Atlas API token for the upload job, so an
    invalid value only fails later, inside the background task.

    Only ``update`` events on ``repo.content*`` scopes for dataset repos are
    processed. Other events return ``{"processed": False}``. Otherwise a job
    is queued, and its id is returned as ``{"task_id": ...}``.

    Raises:
        HTTPException: 401 if the header is missing or empty.
    """
    # An empty header is as unusable as a missing one; rejecting it here avoids
    # spawning a job that loads the dataset only to fail at the Atlas upload.
    if not x_webhook_secret:
        raise HTTPException(401)

    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    task_id = str(uuid4())
    tasks[task_id] = {'status': 'running'}
    background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name, x_webhook_secret, payload, True)
    return {'task_id': task_id}
|
|