artnitolog's picture
add app.py
3e4f0f9
raw
history blame
1.99 kB
import json
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
import gradio as gr
# Directory holding both the fine-tuned tokenizer and classification model.
CHECKPOINT_DIR = 'checkpoints/checkpoint-5000/'

# Tag -> human-readable name mapping used for the output labels. Decode as
# UTF-8 explicitly so non-ASCII names load identically on every platform
# (the default text encoding of open() is platform-dependent).
with open('tag_to_name.json', 'r', encoding='utf-8') as fin:
    tag_to_name = json.load(fin)
# Model class index -> category name, following the JSON's key order.
# NOTE(review): assumes that order matches the checkpoint's label order — confirm.
id_to_name = dict(enumerate(tag_to_name.values()))

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Title used when the user leaves the Title box empty; also shown as its placeholder.
TITLE_DEFAULT = "Attention Is All You Need"
@torch.no_grad()
def predict(title, abstract, top_p, top_k):
    """Return the most likely categories for a paper, filtered by top-p and top-k.

    The title (plus the abstract, when non-empty) is tokenized and run through
    the module-level ``model``. The softmax over its logits is sorted from most
    to least likely. Categories are kept while their cumulative probability
    stays within ``top_p``, plus the first category that crosses that
    threshold. The result is then capped at ``top_k`` entries. The single most
    likely category is always returned.

    Args:
        title: Paper title.
        abstract: Paper abstract; when empty, only the title is classified.
        top_p: Cumulative-probability threshold.
        top_k: Maximum number of categories to return. UI sliders may pass
            this as a float, so it is converted to ``int`` before slicing.

    Returns:
        dict mapping category name to probability, in descending probability order.
    """
    # Title and abstract are joined with a literal "[SEP]" marker.
    # NOTE(review): presumably this matches the training-time input format — confirm.
    text = title + "[SEP]" + abstract if abstract else title
    tokenized = tokenizer(text, truncation=True,
                          return_tensors='pt', max_length=512)
    # Batch of one: take row 0 and normalize over the class dimension.
    probs = model(**tokenized).logits[0].softmax(0)
    top_probs, top_inds = probs.sort(descending=True)

    # cumsum of sorted probabilities is non-decreasing, so this mask is a
    # prefix of True values: the classes whose running total is within top_p.
    mask = top_probs.cumsum(0) <= top_p
    if not mask.all():
        # Also keep the first class that pushes the total past top_p.
        mask[mask.sum()] = True
    # Tensor slicing needs an integer bound; a float here raises TypeError.
    mask[int(top_k):] = False
    # Always return at least the top class, even for top_p == 0.
    mask[0] = True

    predicted_ids = top_inds[mask].tolist()
    predicted_probs = top_probs[mask].tolist()
    return {id_to_name[id_]: prob
            for id_, prob in zip(predicted_ids, predicted_probs)}
def inference(title, abstract, top_p, top_k):
    """Gradio callback: classify a paper, falling back to a default title.

    A blank title is replaced by ``TITLE_DEFAULT``, the same text shown as the
    Title placeholder, so an empty form still produces a prediction. All other
    arguments are forwarded unchanged to ``predict``.
    """
    effective_title = title or TITLE_DEFAULT
    return predict(effective_title, abstract, top_p, top_k)
# Web UI: title/abstract text boxes plus the two filtering knobs that map
# positionally onto inference(title, abstract, top_p, top_k).
g = gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Title", placeholder=TITLE_DEFAULT
        ),
        gr.components.Textbox(lines=4, label="Abstract", placeholder=""),
        gr.components.Slider(minimum=0, maximum=1, value=0.95, label="Top p"),
        # Default to 10 categories, but never start above the slider maximum.
        gr.components.Slider(minimum=1, maximum=len(tag_to_name),
                             step=1, value=min(10, len(tag_to_name)),
                             label="Top n"),
    ],
    # gr.outputs.* is the legacy namespace (removed in Gradio 4); use the
    # same gr.components namespace as the inputs.
    outputs=gr.components.Label(label="Predicted categories"),
    # Restored from mojibake ("πŸͺ„" was the UTF-8 bytes of 🪄 mis-decoded).
    title="🪄 arXiv classifier 🪄",
)
g.launch()