artnitolog committed on
Commit
3e4f0f9
·
1 Parent(s): 34f19e6

add app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ from transformers import AutoModelForSequenceClassification
5
+ import gradio as gr
6
+
7
# Load the tag -> human-readable name mapping. JSON is UTF-8 by specification,
# so read it with an explicit encoding rather than the platform default.
with open('tag_to_name.json', 'r', encoding='utf-8') as fin:
    tag_to_name = json.load(fin)

# Map class index -> category name, in the order the names appear in the JSON.
# NOTE(review): assumes this order matches the label ids the checkpoint was
# trained with -- confirm against the training code.
id_to_name = dict(enumerate(tag_to_name.values()))

# Tokenizer and classifier are both restored from the same checkpoint folder.
CHECKPOINT_DIR = 'checkpoints/checkpoint-5000/'
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
model.eval()

# Title used when the user submits the form without one; also shown as the
# placeholder of the title textbox.
TITLE_DEFAULT = "Attention Is All You Need"
17
+
18
+
19
@torch.no_grad()
def predict(title, abstract, top_p, top_k):
    """Classify a paper and return its most probable categories.

    Args:
        title: Paper title.
        abstract: Paper abstract; when empty, only the title is classified.
        top_p: Probability-mass budget. Categories are taken in descending
            probability order until their cumulative probability exceeds
            this value (the category that crosses the budget is included).
        top_k: Upper bound on the number of returned categories. Coerced to
            an integer and clamped to at least 1.

    Returns:
        A dict mapping category name to probability, ordered from the most
        to the least probable category. At least one category is always
        returned.
    """
    # NOTE(review): the literal "[SEP]" is assumed to be recognised by the
    # checkpoint's tokenizer as its separator token -- confirm.
    text = title + "[SEP]" + abstract if abstract else title
    tokenized = tokenizer(text, truncation=True,
                          return_tensors='pt', max_length=512)
    probs = model(**tokenized).logits[0].softmax(0)
    top_probs, top_inds = probs.sort(descending=True)

    # The cumulative sum is non-decreasing, so the classes that fit within
    # top_p form a prefix of the sorted list. Keep that prefix plus the one
    # class that crosses the budget, when such a class exists.
    within_budget = int((top_probs.cumsum(0) <= top_p).sum())
    n_keep = min(within_budget + 1, top_probs.numel())

    # UI sliders / API callers may deliver a float or an out-of-range value;
    # a float would break slicing and a negative bound would trim the wrong
    # end, so normalise before applying the cap.
    limit = max(int(top_k), 1)
    n_keep = max(min(n_keep, limit), 1)

    predicted_ids = top_inds[:n_keep].tolist()
    predicted_probs = top_probs[:n_keep].tolist()
    predicted_names = [id_to_name[id_] for id_ in predicted_ids]
    return dict(zip(predicted_names, predicted_probs))
38
+
39
+
40
def inference(title, abstract, top_p, top_k):
    """Gradio entry point: run `predict`, substituting the default title.

    An empty (falsy) title is replaced by ``TITLE_DEFAULT`` before the
    request is forwarded; all other arguments are passed through unchanged.
    """
    effective_title = title or TITLE_DEFAULT
    return predict(effective_title, abstract, top_p, top_k)
49
+
50
+
51
# Web UI: two text inputs (title, abstract) and two sliders that control how
# many categories are shown; the result is rendered as a label widget with
# per-category confidences.
g = gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Title", placeholder=TITLE_DEFAULT
        ),
        gr.components.Textbox(lines=4, label="Abstract", placeholder=""),
        gr.components.Slider(minimum=0, maximum=1, value=0.95, label="Top p"),
        # Upper bound is the number of known categories.
        gr.components.Slider(minimum=1, maximum=len(tag_to_name),
                             step=1, value=10, label="Top n"),
    ],
    # Use the same `gr.components` namespace as the inputs; `gr.outputs` is
    # the deprecated legacy namespace.
    outputs=gr.components.Label(label="Predicted categories"),
    title="🪄 arXiv classifier 🪄",
)
g.launch()