File size: 3,751 Bytes
81aaa4e
 
 
 
 
 
 
442f97c
 
81aaa4e
442f97c
 
81aaa4e
442f97c
1779f92
81aaa4e
442f97c
 
 
 
 
 
1779f92
 
 
81aaa4e
 
442f97c
1779f92
442f97c
 
81aaa4e
 
 
 
783de92
81aaa4e
 
036b5da
81aaa4e
 
783de92
81aaa4e
 
d911a9d
 
81aaa4e
 
442f97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1779f92
 
 
 
442f97c
1779f92
 
 
 
 
 
 
442f97c
 
 
 
 
 
 
f47c911
 
81aaa4e
 
 
1779f92
 
 
81aaa4e
f47c911
 
 
 
 
 
 
 
 
 
1779f92
f47c911
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import asyncio
import os
import time
from typing import Optional
from uuid import uuid4

from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import create_discussion, comment_discussion

from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
from models import WebhookPayload

# Comparing the webhook header against a configured shared secret is currently
# disabled; the assignment below is kept commented out for reference.
# WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Hugging Face token used when opening discussions and posting comments on
# dataset repositories; None when the environment variable is not set.
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")

app = FastAPI()
# In-memory registry of map-building jobs keyed by task id. Each value is a
# dict holding a 'status' and, once the job has finished, 'url' and
# 'finish_time'. The registry is not persisted anywhere in this module.
# TODO: use task management queue
tasks = {}
# HTML templates (e.g. form.html) are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")

def upload_atlas_task(task_id: str,
                      dataset_name: str,
                      atlas_api_token: str,
                      webhook_payload: Optional[WebhookPayload] = None,
                      webhook_notify: bool = False):
    """Build an Atlas map for a dataset and record the outcome in ``tasks``.

    The dataset is loaded with ``load_dataset_and_metadata`` and uploaded with
    ``upload_dataset_to_atlas``. The registry entry ``tasks[task_id]`` must
    already exist; it is updated with the final state of the job.

    Args:
        task_id: Key of the pre-registered entry in ``tasks``.
        dataset_name: Dataset identifier passed to ``load_dataset_and_metadata``.
        atlas_api_token: Token passed to ``upload_dataset_to_atlas``.
        webhook_payload: Payload of the triggering webhook. Required when
            ``webhook_notify`` is true, because the repository name is read
            from ``webhook_payload.repo.name``.
        webhook_notify: When true, open an "Atlas Maps" discussion on the
            dataset repository and post the map URL as a comment.

    Raises:
        Exception: Any error raised while loading or uploading is re-raised
            after the task entry has been marked ``'failed'``.
    """
    try:
        dataset_dict = load_dataset_and_metadata(dataset_name)
        map_url = upload_dataset_to_atlas(dataset_dict, atlas_api_token)
    except Exception as exc:
        # Without this, a failed build would leave the entry at 'running'
        # forever and clients polling the status would never learn about it.
        # Only the exception type is exposed, since the entry is returned
        # verbatim to clients.
        tasks[task_id].update({
            'status': 'failed',
            'error': type(exc).__name__,
            'finish_time': time.time(),
        })
        raise

    # Publish every result field in a single update so the entry does not
    # pass through an intermediate 'done'-without-'url' state.
    tasks[task_id].update({
        'status': 'done',
        'url': map_url,
        'finish_time': time.time(),
    })

    if webhook_notify:
        # Errors raised here propagate to the caller; the task itself is
        # already recorded as 'done' because the map was built successfully.
        discussion = create_discussion(
            repo_id=webhook_payload.repo.name,
            title="Atlas Maps",
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )
        comment_discussion(
            repo_id=webhook_payload.repo.name,
            discussion_num=discussion.num,
            comment="Atlas Map: " + map_url,
            token=HUGGINGFACE_ACCESS_TOKEN,
            repo_type="dataset"
        )


@app.on_event("startup")
async def startup_event():
    """Launch the periodic ``cleanup_tasks`` loop when the application starts."""
    # The event loop keeps only a weak reference to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes. Keep a
    # strong reference on the application state for the app's lifetime.
    app.state.cleanup_task = asyncio.create_task(cleanup_tasks())

async def cleanup_tasks():
    """Periodically evict finished entries from the ``tasks`` registry.

    Runs forever: on each pass, every entry whose status is ``'done'`` and
    whose ``finish_time`` is more than 30 minutes in the past is removed, then
    the coroutine sleeps for 30 minutes before the next pass.
    """
    while True:
        now = time.time()
        # Collect the ids first so the registry is not mutated while it is
        # being iterated. An entry without 'finish_time' has an age of zero
        # and is therefore kept.
        expired = [
            key
            for key, entry in tasks.items()
            if entry['status'] == 'done'
            and now - entry.get('finish_time', now) > 1800  # 30 minutes
        ]
        for key in expired:
            del tasks[key]
        await asyncio.sleep(1800)  # Wait for 30 minutes

@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    """Serve the submission form by rendering the ``form.html`` template."""
    # The incoming request is the only value placed in the template context.
    context = {"request": request}
    return templates.TemplateResponse("form.html", context)

@app.post("/submit_form")
async def form_post(
    background_tasks: BackgroundTasks,
    dataset_name: str = Form(...),
    atlas_api_token: str = Form(...),
):
    """Register a map-building job for a submitted form and return its id.

    A new entry with status ``'running'`` is added to ``tasks`` and
    ``upload_atlas_task`` is scheduled as a background task with the submitted
    dataset name and Atlas API token. The response is ``{'task_id': <id>}``.
    """
    job_id = str(uuid4())
    # Register the job before scheduling it so the entry already exists when
    # the background task updates it.
    tasks[job_id] = dict(status='running')
    background_tasks.add_task(
        upload_atlas_task,
        job_id,
        dataset_name,
        atlas_api_token,
    )
    return dict(task_id=job_id)

@app.get("/status/{task_id}")
async def read_task(task_id: str):
    """Return the registry entry for ``task_id``.

    Unknown ids yield ``{'status': 'not found'}`` rather than an HTTP error.
    """
    return tasks.get(task_id, {'status': 'not found'})

@app.post("/webhook")
async def post_webhook(
    background_tasks: BackgroundTasks,
    payload: WebhookPayload,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle a repository webhook and schedule a map build when relevant.

    Requests without the webhook-secret header are rejected with HTTP 401.
    Only "update" events whose scope starts with "repo.content" on a repo of
    type "dataset" are processed; anything else returns
    ``{"processed": False}``. For a processed event, a ``'running'`` entry is
    added to ``tasks``, ``upload_atlas_task`` is scheduled with notification
    enabled, and ``{'task_id': <id>}`` is returned.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)

    # NOTE(review): the header value is not compared against a configured
    # secret; it is forwarded below as the Atlas API token for the upload.
    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    job_id = str(uuid4())
    tasks[job_id] = dict(status='running')
    background_tasks.add_task(
        upload_atlas_task,
        job_id,
        payload.repo.name,
        x_webhook_secret,
        payload,
        True,
    )
    return dict(task_id=job_id)