ferret / corpus.py
g8a9's picture
Update corpus.py (#2)
bb4d707
raw
history blame
5.94 kB
from ctypes import DEFAULT_MODE
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
from torch.nn.functional import softmax
from copy import deepcopy
# Hugging Face Hub checkpoint pre-filled in the "HF Model" text box.
DEFAULT_MODEL: str = "Hate-speech-CNERG/bert-base-uncased-hatexplain"
# Comma-separated dataset indices pre-filled in the "List of samples" text box.
DEFAULT_SAMPLES: str = "3,5,8,13,15,17,18,25,27,28"
# allow_output_mutation=True: without it, the legacy st.cache re-hashes the
# returned object on every cache hit to detect mutation. For a transformer that
# means hashing all of its weights on each rerun. It also flags a mutation
# whenever the model's state changes between runs (presumably e.g. gradients or
# eval mode set while explaining samples).
@st.cache(allow_output_mutation=True)
def get_model(model_name):
    """Load the sequence-classification model *model_name*, memoized by Streamlit.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        The ``AutoModelForSequenceClassification`` instance. The same object
        is returned on later calls with the same name.
    """
    return AutoModelForSequenceClassification.from_pretrained(model_name)
@st.cache()
def get_config(model_name):
    """Return the ``AutoConfig`` of *model_name*, memoized with ``st.cache``.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
    """
    model_config = AutoConfig.from_pretrained(model_name)
    return model_config
def get_tokenizer(tokenizer_name):
    """Return the fast tokenizer registered under *tokenizer_name*.

    Unlike the model and config loaders, this function is not wrapped in
    ``st.cache``, so Streamlit does not memoize its result.

    Args:
        tokenizer_name: Hugging Face Hub id or local path of the tokenizer.
    """
    loader_options = {"use_fast": True}
    return AutoTokenizer.from_pretrained(tokenizer_name, **loader_options)
def _parse_samples(samples_string):
    """Parse a comma-separated list of dataset indices such as ``"3, 5,8"``.

    Whitespace around each index is ignored. Empty entries (e.g. from a
    trailing comma, or an empty input) are skipped, so ``""`` yields ``[]``.

    Raises:
        ValueError: if a non-empty entry is not an integer literal.
    """
    return [int(token) for token in samples_string.split(",") if token.strip()]


def _render_legend():
    """Explain the faithfulness and plausibility metrics shown in the table."""
    st.markdown(
        """
**Legend**

**Faithfulness**

- **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the explanation captures all the tokens needed to make the prediction. Higher is better.
- **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., if the relevant tokens in the explanation are sufficient to make the prediction. Lower is better.
- **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation coefficient τ between the explanation and leave-one-out importances. Closer to 1 is better.

**Plausibility**

- **AUPRC plausibility** (auprc_plau) is the area under the precision-recall curve (AUPRC) of the explanation and the rationale as ground truth. Higher is better.
- **Intersection-Over-Union (IOU)** (token_iou_plau) is the size of the overlap of the most relevant tokens of the explanation and the human rationale divided by the size of their union. Higher is better.
- **Token-level F1 score** (token_f1_plau) measures the F1 score between the most relevant tokens and the human rationale. Higher is better.

See the paper for details.
"""
    )


def _render_code_snippet(model_name, samples, target_id):
    """Show the ferret code that reproduces the evaluation displayed above."""
    st.markdown("**In code, it would be as simple as**")
    st.code(
        f"""
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from ferret import Benchmark

model = AutoModelForSequenceClassification.from_pretrained("{model_name}")
tokenizer = AutoTokenizer.from_pretrained("{model_name}")

bench = Benchmark(model, tokenizer)
data = bench.load_dataset("hatexplain")
evaluations = bench.evaluate_samples(data, {samples}, target={target_id})
bench.show_samples_evaluation_table(evaluations)
""",
        language="python",
    )


def body():
    """Render the "Evaluate explanations on dataset samples" page.

    The page takes three inputs:

    - a Hugging Face sequence-classification checkpoint;
    - a target class label;
    - a comma-separated list of HateXPlain sample indices.

    When "Run" is pressed, ferret's ``Benchmark`` evaluates explanations on
    those samples. The page then shows the resulting metrics table, a legend
    for the metrics, and an equivalent code snippet.

    Invalid inputs (empty or unloadable model name, malformed sample list)
    are reported in the page instead of raising.
    """
    st.title("Evaluate explanations on dataset samples")

    st.markdown(
        """
Let's test how our built-in explainers behave on state-of-the-art datasets for explainability.

*ferret* exposes an extensible Dataset API. We currently implement [MovieReviews](https://huggingface.co/datasets/movie_rationales) and [HateXPlain](https://huggingface.co/datasets/hatexplain).
In this demo, we let you experiment with HateXPlain.

You just need to choose a prediction model and a set of samples to test.
We will trigger *ferret* to:

1. download the model;
2. explain every sample you chose;
3. average all faithfulness and plausibility metrics we support 📊
"""
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL).strip()

    # Without a checkpoint there are no class labels to offer. Stop here
    # rather than letting the config loader fail on an empty name.
    if not model_name:
        st.info("Enter the name of a Hugging Face model to get started.")
        return

    try:
        # Use the st.cache-backed loader so the config is not re-read on
        # every rerun of the script.
        config = get_config(model_name)
    except (OSError, ValueError) as err:
        st.error(f"Could not load the configuration of '{model_name}': {err}")
        return

    class_labels = list(config.id2label.values())
    # Resolve the selected label back through id2label itself, so the target
    # id does not rely on the position of the label in the option list.
    label_to_id = {label: class_id for class_id, label in config.id2label.items()}

    with col2:
        target = st.selectbox(
            "Target",
            options=class_labels,
            index=0,
            help="Class label you want to explain.",
        )
    target_id = label_to_id[target]

    samples_string = st.text_input(
        "List of samples",
        DEFAULT_SAMPLES,
        help="List of indices in the dataset, comma-separated.",
    )
    compute = st.button("Run")

    try:
        samples = _parse_samples(samples_string)
    except ValueError:
        st.error(
            "Sample indices must be integers separated by commas, "
            f"e.g. {DEFAULT_SAMPLES}."
        )
        return

    if not compute:
        return

    if not samples:
        st.warning("Enter at least one sample index.")
        return

    with st.spinner("Preparing the magic. Hang in there..."):
        model = get_model(model_name)
        tokenizer = get_tokenizer(model_name)
        bench = Benchmark(model, tokenizer)

    with st.spinner("Explaining sample (this might take a while)..."):

        # NOTE(review): `bench` is captured from the enclosing scope rather than
        # passed as an argument. Confirm that st.cache's hashing of referenced
        # objects keys the cached table on the selected model.
        @st.cache(allow_output_mutation=True)
        def compute_table(samples, target_id):
            data = bench.load_dataset("hatexplain")
            sample_evaluations = bench.evaluate_samples(
                data, samples, target=target_id
            )
            return bench.show_samples_evaluation_table(sample_evaluations).format(
                "{:.2f}"
            )

        table = compute_table(samples, target_id)

    st.markdown("### Averaged metrics")
    st.dataframe(table)
    st.caption("Darker colors mean better performance.")

    _render_legend()
    _render_code_snippet(model_name, samples, target_id)