# Spaces: Runtime error
import numpy as np | |
import pandas as pd | |
import transformers | |
import torch | |
import tokenizers | |
import streamlit as st | |
# Number of output classes of the classification head.
NUM_LABELS = 15

# Class index -> human-readable category name. The index is the position of
# the category in the classifier output, so the order below must not change.
labels_names = dict(enumerate((
    'Astrophysics',
    'Condensed Matter',
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'General Relativity and Quantum Cosmology',
    'High Energy Physics',
    'Mathematics',
    'Nonlinear Sciences',
    'Nuclear Theory',
    'General Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Quantum Physics',
    'Statistics',
)))
def get_model(model_name, model_path):
    """Build the classifier and its tokenizer and load fine-tuned weights.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the sequence-classification model.
        model_path: Path of the state-dict file whose weights replace the
            pretrained ones.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in evaluation mode.
    """
    text_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=NUM_LABELS,
    )
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here. Tensors are mapped to the CPU on load.
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)
    classifier.load_state_dict(state_dict)
    classifier.eval()
    return classifier, text_tokenizer
def predict(text, tokenizer, model, temperature=1):
    """Return the most likely categories for *text* as a one-column table.

    The text is tokenized, run through the classifier on the CPU, and the
    logits are converted to probabilities with a temperature-scaled softmax.
    Categories are listed from most to least probable until their cumulative
    probability reaches 95%.

    Args:
        text: Input string to classify.
        tokenizer: Object whose ``encode(text, truncation=True)`` returns a
            list of token ids.
        model: Classifier; ``model.cpu()(input_ids)[0]`` must yield the
            logits with one row per input sequence.
        temperature: Positive softmax temperature. Values above 1 flatten the
            distribution, values below 1 sharpen it.

    Returns:
        A ``pandas.DataFrame`` indexed by category name with a single
        ``'Probability'`` column of strings such as ``'87.3%'``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        # Zero would divide the logits by zero and a negative value would
        # invert the ranking, so reject both up front.
        raise ValueError("temperature must be positive")

    # Truncate to the tokenizer's maximum length: the encoder has a fixed
    # number of position embeddings, so an over-long abstract would otherwise
    # make the forward pass fail.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model.cpu()(torch.as_tensor([tokens]))[0]
        # The batch holds a single sequence, so row 0 is the only row.
        probs = torch.softmax(logits[0] / temperature, dim=-1).cpu().numpy()

    preds = []
    pred_probs = []
    percents = 0
    # Walk the classes from most to least probable and stop once the listed
    # ones cover at least 95% of the probability mass.
    for index in np.argsort(probs)[::-1]:
        pred_prob = 100 * probs[index]
        preds.append(labels_names[int(index)])
        pred_probs.append(f"{pred_prob:.1f}%")
        percents += pred_prob
        if percents >= 95:
            break

    return pd.DataFrame({'Probability': pred_probs}, index=preds)
# NOTE(review): Streamlit re-executes the script on every interaction, so the
# model is presumably reloaded each time; consider caching this call.
model, tokenizer = get_model('distilbert-base-cased', 'distilbert-checkpoint-10983.bin')

# Page header.
st.title("Yandex School of Data Analysis. ML course")
st.title("Laboratory work 2: classifier of categories of scientific papers")
st.markdown("<img width=200px src='https://m.media-amazon.com/images/I/71XOMSKx8NL._AC_SL1500_.jpg'>", unsafe_allow_html=True)
st.markdown("\n")
st.markdown("Enter the title of the article and its abstract (although, if you really don't want to, you can do with just the title)")

# Input widgets.
title = st.text_area(label='Title of the article', height=100)
abstract = st.text_area(label='Abstract of the article', height=200)
button = st.button('Go')

if button:
    # Measure only what the user typed. The joined text always contains the
    # 12-character ' [ABSTRACT] ' separator, so checking its length could
    # never reject an empty submission.
    if len(title.strip()) + len(abstract.strip()) > 10:
        text = ' [ABSTRACT] '.join([title, abstract])
        try:
            result = predict(text, tokenizer, model)
        except Exception:
            # Top-level UI boundary: show a friendly message instead of a
            # raw traceback.
            st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")
        else:
            st.subheader('Bumblebee thinks, this paper related to')
            st.write(result)
    else:
        st.error("Enter some more info please")