# Hugging Face Spaces page header captured with the source; Space status: "Runtime error".
import os

import pandas as pd
import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertTokenizerFast
def getTop95(predictions):
    """Return indices of the most probable classes covering at least 95% mass.

    The scores are ordered from highest to lowest and the shortest prefix
    whose cumulative sum reaches 0.95 is selected.

    Args:
        predictions: 1-D tensor of class scores (the caller passes softmax
            probabilities).

    Returns:
        A 1-D tensor of indices into ``predictions``, ordered by descending
        score. If the cumulative sum never reaches 0.95, every index is
        returned, so the result is never ``None``.
    """
    # Sort once instead of calling topk for every candidate k.
    sorted_vals, sorted_ids = torch.sort(predictions, descending=True)
    cumulative = torch.cumsum(sorted_vals, dim=0)

    # Positions where the running total has already reached the threshold.
    reached = torch.nonzero(cumulative >= 0.95)
    if reached.numel() == 0:
        # Threshold never reached: fall back to all classes.
        return sorted_ids

    # The first such position is inclusive, hence the +1 on the slice end.
    count = int(reached[0].item()) + 1
    return sorted_ids[:count]
# Page-level configuration: browser tab title and icon.
st.set_page_config(
    page_title="ArXiv classificator",
    page_icon=":book:"
)
st.header("Theme classification of ArXiv articles")
# Markdown element whose content is only a newline.
st.markdown("""
""")
# Input form: widgets are grouped so the script reacts only when the
# "Analyze" button is pressed. The submit button must live inside the form.
with st.form(key='input_form'):
    title = st.text_input(label='Enter title of the article here')
    summary = st.text_area("Enter summary of the article here")
    submit = st.form_submit_button(label='Analyze')
# Classify only after the form was submitted with at least one non-empty field.
if submit and (title or summary):
    with st.spinner(text='Oracul thinks, please wait for his wise prediction'):
        # Label table; the code below reads its second column and its `tag` column.
        classes = pd.read_csv('classes.csv')
        tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")

        # Title and summary are joined with '|' into one input string
        # (presumably the format used during fine-tuning; TODO confirm).
        to_predict = title + '|' + summary
        X = tokenizer(to_predict, truncation=True, padding=True)
        # unsqueeze(0) adds the batch dimension for this single example.
        tokens = torch.tensor(X['input_ids']).unsqueeze(0)
        mask = torch.tensor(X['attention_mask']).unsqueeze(0)

        # Model weights and config are loaded from the current working directory.
        model = DistilBertForSequenceClassification.from_pretrained(
            os.getcwd(),
            num_labels=len(classes)
        )
        model.eval()

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = model(tokens, mask)[0][0]
        # Explicit dim avoids the deprecated implicit-dimension Softmax.
        predictions = torch.softmax(logits, dim=-1)

        # Plain Python ints so NumPy indexing does not depend on tensor coercion.
        ids = getTop95(predictions).tolist()

    st.markdown("Most likely it is:")
    for tag in classes.to_numpy()[ids[:5]]:
        st.markdown(f"- {tag[1]}")

    # Show the secondary list only when labels remain after the first five.
    other_ids = ids[5:]
    if other_ids:
        st.markdown("Other possible variants:")
        st.write(', '.join(classes.tag.to_numpy()[other_ids]))
    st.balloons()
# CSS injected into the page to hide the elements matched by `#MainMenu`
# and `footer`.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
# unsafe_allow_html is required for the raw <style> tag to be rendered.
st.markdown(hide_streamlit_style, unsafe_allow_html=True)