rajistics's picture
Update app.py
4ed6a65
import os
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit import components
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
def load_model(model_name):
    """Load a sequence-classification model together with its tokenizer.

    Results are memoised by ``st.cache`` with ``max_entries=1``, so only the
    most recently requested model is kept. ``allow_output_mutation=True``
    lets callers mutate the returned objects (e.g. ``model.eval()``) without
    Streamlit complaining about the cached value changing.

    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    favour of ``st.cache_resource``; switch once the pinned version allows.

    Args:
        model_name: Hugging Face Hub identifier of the model to load.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    tok = AutoTokenizer.from_pretrained(model_name)
    return classifier, tok
# Page title plus sidebar branding (logo and a link to the project).
st.title("Transformers Interpret Demo App")
# Image.open reads lazily and keeps the file handle open; the context manager
# closes it once Streamlit has consumed the image for the sidebar.
with Image.open("./images/tight@1920x_transparent.png") as image:
    st.sidebar.image(image, use_column_width=True)
st.sidebar.markdown(
    "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
)
# Models offered in the sidebar "Choose a classification model" selectbox,
# mapped to a short description that is shown (as JSON) in the model
# description expander. Add an entry, or uncomment the one below, to try
# other sequence-classification models from the Hugging Face Hub.
models = {
    "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
    "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
    "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
    "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
    # "MoritzLaurer/covid-policy-roberta-21": "(Under active development) RoBERTa model finetuned to identify COVID policy measure classes",
    "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
}
# Let the user pick one of the configured models; the cached loader means
# switching back and forth only reloads when the selection changes.
model_name = st.sidebar.selectbox(
    label="Choose a classification model",
    options=list(models),
)
model, tokenizer = load_model(model_name)
# Attributions are computed for inference, so put the model in eval mode.
model.eval()
# Wrap model + tokenizer in the transformers-interpret explainer used below.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
# Embedding layer to attribute with respect to, mapped to the integer code the
# explainer's ``embedding_type`` argument takes. A single table drives both
# the selectbox options and the lookup, so they cannot drift apart.
EMBEDDING_TYPE_IDS = {"word": 0, "position": 1}

if cls_explainer.accepts_position_ids:
    emb_type_name = st.sidebar.selectbox(
        "Choose embedding type for attribution.", list(EMBEDDING_TYPE_IDS)
    )
    emb_type_num = EMBEDDING_TYPE_IDS[emb_type_name]
else:
    # Without position-id support only word-embedding attribution is offered.
    emb_type_num = EMBEDDING_TYPE_IDS["word"]
# Target class for the explanation: the sentinel "predicted" (meaning no
# explicit class is passed to the explainer) followed by every label the
# model's config declares.
explanation_classes = ["predicted", *model.config.label2id]
explanation_class_choice = st.sidebar.selectbox(
    label="Explanation class: The class you would like to explain output with respect to.",
    options=explanation_classes,
)

# Collapsible panel listing every configured model and its description.
my_expander = st.expander("Click here for a description of models and their tasks")
with my_expander:
    st.json(models)
# Cap on the input length accepted by the text area; presumably kept modest
# because attribution cost grows with input length (memory management).
MAX_TEXT_CHARS = 850

text = st.text_area(
    "Enter text to be interpreted",
    "I like you, I love you",
    height=400,
    max_chars=MAX_TEXT_CHARS,
)
# Run the explainer on demand and render both the raw attributions and the
# HTML visualisation produced by transformers-interpret.
if st.button("Interpret Text"):
    st.text("Output")
    if not text.strip():
        # Nothing meaningful to attribute; don't run the explainer on empty input.
        st.warning("Please enter some text to interpret.")
    else:
        # Only pass class_name when the user picked a concrete label; with the
        # "predicted" sentinel the explainer uses its default target class.
        explainer_kwargs = {"embedding_type": emb_type_num, "internal_batch_size": 2}
        if explanation_class_choice != "predicted":
            explainer_kwargs["class_name"] = explanation_class_choice

        with st.spinner("Interpreting your text (This may take some time)"):
            print("Interpreting text")
            word_attributions = cls_explainer(text, **explainer_kwargs)

        if word_attributions:
            print("Word Attributions")
            word_attributions_expander = st.expander(
                "Click here for raw word attributions"
            )
            with word_attributions_expander:
                st.json(word_attributions)
            # visualize() renders the attributions from the call above.
            components.v1.html(
                cls_explainer.visualize()._repr_html_(), scrolling=True, height=350
            )