Shield's picture
Update app.py
4d87157
raw
history blame
1.95 kB
import streamlit as st
import pandas as pd
import os
import torch
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
def getTop95(predictions, threshold=0.95):
    """Return indices of the most probable classes covering ``threshold`` mass.

    Classes are taken in descending order of score and the shortest prefix
    whose scores sum to at least ``threshold`` is kept.

    Args:
        predictions: 1-D tensor of class scores. The caller is expected to
            pass probabilities (e.g. a softmax output).
        threshold: Cumulative score that the returned classes must reach.
            Defaults to 0.95.

    Returns:
        A 1-D integer tensor of class indices ordered from highest to lowest
        score. All indices are returned when the total never reaches
        ``threshold``; an empty tensor is returned for empty input.
    """
    if predictions.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=predictions.device)

    # One sort + running sum replaces calling topk for every candidate k.
    vals, ids = torch.sort(predictions, descending=True)
    running_total = torch.cumsum(vals, dim=0)

    # Positions (0-based) at which the running total has reached the threshold.
    reached = torch.nonzero(running_total >= threshold)
    if reached.numel() == 0:
        # Threshold never reached (e.g. rounding, or scores that do not sum
        # to 1): fall back to every class instead of returning None.
        return ids
    return ids[: int(reached[0].item()) + 1]
# Browser-tab title and icon for the app.
_page_settings = {
    "page_title": "ArXiv classificator",
    "page_icon": ":book:",
}
st.set_page_config(**_page_settings)

# Page heading, followed by a markdown element that contains only a newline.
st.header("Theme classification of ArXiv articles")
st.markdown("\n")
# Input form: `title` and `summary` hold the entered text, `submit` holds the
# value returned by the form's "Analyze" button.
with st.form("input_form"):
    title = st.text_input("Enter title of the article here")
    summary = st.text_area(label="Enter summary of the article here")
    submit = st.form_submit_button("Analyze")
# Classify only after the form was submitted with at least one non-empty field.
if submit and (title or summary):
    with st.spinner(text='Oracul thinks, please wait for his wise prediction'):
        # Class table; its row count defines the number of output labels.
        classes = pd.read_csv('classes.csv')
        tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")

        # Title and summary are fed to the model as one '|'-separated text.
        to_predict = title + '|' + summary
        X = tokenizer(to_predict, truncation=True, padding=True)
        # unsqueeze(0) adds the batch dimension for the single input text.
        tokens = torch.tensor(X['input_ids']).unsqueeze(0)
        mask = torch.tensor(X['attention_mask']).unsqueeze(0)

        # Model weights are loaded from the current working directory.
        model = DistilBertForSequenceClassification.from_pretrained(
            os.getcwd(),
            num_labels=len(classes)
        )
        model.eval()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(tokens, mask)[0][0]
        # Explicit dim avoids the deprecated implicit-dimension softmax.
        predictions = torch.softmax(logits, dim=-1)

        # Plain Python ints: indexing a NumPy array with a one-element tensor
        # can be interpreted as a scalar index and drop a dimension.
        ids = getTop95(predictions).tolist()

    st.markdown("Most likely it is:")
    for tag in classes.to_numpy()[ids[:5]]:
        # tag[1] is the second column of classes.csv -- presumably the
        # human-readable tag name; verify against the CSV.
        st.markdown(f"- {tag[1]}")

    other_ids = ids[5:]
    if other_ids:
        # Only show the secondary list when there is something to list.
        st.markdown("Other possible variants:")
        st.write(', '.join(classes.tag.to_numpy()[other_ids]))
    st.balloons()
# CSS that hides the elements matched by `#MainMenu` and `footer`; the leading
# and trailing empty entries keep the newline at both ends of the string.
hide_streamlit_style = "\n".join(
    [
        "",
        "<style>",
        "#MainMenu {visibility: hidden;}",
        "footer {visibility: hidden;}",
        "</style>",
        "",
    ]
)
# Passed with unsafe_allow_html=True so the <style> tag is sent as raw HTML.
st.markdown(hide_streamlit_style, unsafe_allow_html=True)