Richard Guo
commited on
Commit
·
81aaa4e
1
Parent(s):
f47c911
huggingface cli requirement and webhook route
Browse files- main.py +38 -11
- requirements.txt +1 -0
main.py
CHANGED
@@ -1,25 +1,48 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from fastapi.responses import HTMLResponse
|
3 |
from fastapi.templating import Jinja2Templates
|
4 |
-
|
5 |
-
from uuid import uuid4
|
6 |
-
import time
|
7 |
-
import asyncio
|
8 |
|
9 |
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
|
10 |
-
from models import WebhookPayload
|
11 |
|
|
|
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
# TODO: use task management queue
|
15 |
tasks = {}
|
16 |
templates = Jinja2Templates(directory="templates")
|
17 |
|
18 |
-
def upload_atlas_task(task_id,
|
|
|
|
|
|
|
19 |
dataset_dict = load_dataset_and_metadata(dataset_name)
|
20 |
-
map_url = upload_dataset_to_atlas(dataset_dict
|
21 |
tasks[task_id]['status'] = 'done'
|
22 |
tasks[task_id]['url'] = map_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
@app.on_event("startup")
|
25 |
async def startup_event():
|
@@ -47,7 +70,6 @@ async def form_post(background_tasks: BackgroundTasks, dataset_name: str = Form(
|
|
47 |
tasks[task_id] = {'status': 'running'}
|
48 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
49 |
background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
|
50 |
-
|
51 |
return {'task_id': task_id}
|
52 |
|
53 |
@app.get("/status/{task_id}")
|
@@ -58,7 +80,12 @@ async def read_task(task_id: str):
|
|
58 |
return tasks[task_id]
|
59 |
|
60 |
@app.post("/webhook")
|
61 |
-
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload):
|
|
|
|
|
|
|
|
|
|
|
62 |
if not (
|
63 |
payload.event.action == "update"
|
64 |
and payload.event.scope.startswith("repo.content")
|
@@ -69,5 +96,5 @@ async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayloa
|
|
69 |
task_id = str(uuid4())
|
70 |
tasks[task_id] = {'status': 'running'}
|
71 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
72 |
-
background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name)
|
73 |
return {'task_id': task_id}
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import Optional
|
5 |
+
from uuid import uuid4
|
6 |
+
|
7 |
+
from fastapi import FastAPI, Form, Header, HTTPException, Request, BackgroundTasks
|
8 |
from fastapi.responses import HTMLResponse
|
9 |
from fastapi.templating import Jinja2Templates
|
10 |
+
from huggingface_hub import create_discussion, comment_discussion
|
|
|
|
|
|
|
11 |
|
12 |
from build_map import load_dataset_and_metadata, upload_dataset_to_atlas
|
13 |
+
from models import WebhookPayload
|
14 |
|
15 |
+
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
|
16 |
+
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
|
17 |
|
18 |
app = FastAPI()
|
19 |
# TODO: use task management queue
|
20 |
tasks = {}
|
21 |
templates = Jinja2Templates(directory="templates")
|
22 |
|
23 |
+
def upload_atlas_task(task_id,
|
24 |
+
dataset_name,
|
25 |
+
webhook_payload: WebhookPayload = None,
|
26 |
+
webhook_notify: bool = False):
|
27 |
dataset_dict = load_dataset_and_metadata(dataset_name)
|
28 |
+
map_url = upload_dataset_to_atlas(dataset_dict)
|
29 |
tasks[task_id]['status'] = 'done'
|
30 |
tasks[task_id]['url'] = map_url
|
31 |
+
tasks[task_id]['finish_time'] = time.time()
|
32 |
+
|
33 |
+
if webhook_notify:
|
34 |
+
discussion = create_discussion(
|
35 |
+
repo_id=webhook_payload.repo.id,
|
36 |
+
title="Atlas Maps",
|
37 |
+
token=HUGGINGFACE_ACCESS_TOKEN,
|
38 |
+
)
|
39 |
+
comment_discussion(
|
40 |
+
repo_id=webhook_payload.repo.id,
|
41 |
+
discussion_num=discussion.num,
|
42 |
+
comment="Atlas Map: " + map_url,
|
43 |
+
token=HUGGINGFACE_ACCESS_TOKEN
|
44 |
+
)
|
45 |
+
|
46 |
|
47 |
@app.on_event("startup")
|
48 |
async def startup_event():
|
|
|
70 |
tasks[task_id] = {'status': 'running'}
|
71 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
72 |
background_tasks.add_task(upload_atlas_task, task_id, dataset_name)
|
|
|
73 |
return {'task_id': task_id}
|
74 |
|
75 |
@app.get("/status/{task_id}")
|
|
|
80 |
return tasks[task_id]
|
81 |
|
82 |
@app.post("/webhook")
|
83 |
+
async def post_webhook(background_tasks: BackgroundTasks, payload: WebhookPayload, x_webhook_secret: Optional[str] = Header(default=None)):
|
84 |
+
if x_webhook_secret is None:
|
85 |
+
raise HTTPException(401)
|
86 |
+
if x_webhook_secret != WEBHOOK_SECRET:
|
87 |
+
raise HTTPException(403)
|
88 |
+
|
89 |
if not (
|
90 |
payload.event.action == "update"
|
91 |
and payload.event.scope.startswith("repo.content")
|
|
|
96 |
task_id = str(uuid4())
|
97 |
tasks[task_id] = {'status': 'running'}
|
98 |
#form_data = DatasetForm(dataset_name=dataset_name)
|
99 |
+
background_tasks.add_task(upload_atlas_task, task_id, payload.repo.name, payload, True)
|
100 |
return {'task_id': task_id}
|
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
datasets==2.13.0
|
2 |
fastapi[all]
|
|
|
3 |
nomic==2.0.3
|
4 |
pandas==1.5.3
|
5 |
pyarrow==12.0.1
|
|
|
1 |
datasets==2.13.0
|
2 |
fastapi[all]
|
3 |
+
huggingface-hub==0.16.4
|
4 |
nomic==2.0.3
|
5 |
pandas==1.5.3
|
6 |
pyarrow==12.0.1
|