|
"""This page contains all misclassified examples and allows filtering by specific error types.""" |
|
from collections import defaultdict

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example
|
|
|
|
|
class MisclassifiedPage(Page):
    """Page listing every misclassified example, filterable by class confusion.

    It first shows a confusion matrix computed over the examples that contain
    at least one wrongly predicted token, with diagonal and zero cells blanked
    so only errors stand out. The user then picks one ``(true, predicted)``
    label pair, and every example containing that confusion is rendered below.
    """

    name = "Misclassified"
    icon = "x-octagon"

    @staticmethod
    def _error_table(cm: np.ndarray, labels) -> pd.DataFrame:
        """Return ``cm`` as a string table with the diagonal and all zero cells blanked.

        Args:
            cm: Square confusion matrix whose rows/columns follow ``labels``.
            labels: Class labels used as both the row index and the column names.
        """
        # Build the cells on a NumPy array we own. Writing through
        # ``DataFrame.values`` is not guaranteed to reach the frame (it is a
        # read-only view under pandas Copy-on-Write), and ``applymap`` is deprecated.
        cells = np.where(cm == 0, "", cm.astype(str))
        np.fill_diagonal(cells, "")
        return pd.DataFrame(cells, index=labels, columns=labels)

    @staticmethod
    def _confusion_options(cm: np.ndarray, labels) -> list:
        """Return ``((true_label, pred_label), count)`` for each non-zero off-diagonal cell.

        Cells are visited in row-major order (true label first, then predicted label).
        """
        return [
            ((true_label, pred_label), int(cm[i, j]))
            for i, true_label in enumerate(labels)
            for j, pred_label in enumerate(labels)
            if i != j and cm[i, j] != 0
        ]

    def render(self, context: Context):
        """Render the confusion overview, the confusion selector and the matching examples."""
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        # NOTE(review): df_tokens_merged appears to hold one row per token with
        # `labels`/`preds` columns, indexed by example id (so index values repeat).
        # This is inferred from the `.index.unique()` / `.loc[idx]` usage; confirm
        # against where it is built.
        tokens = context.df_tokens_merged
        misclassified_indices = tokens.query("labels != preds").index.unique()
        if misclassified_indices.empty:
            st.info("No misclassified examples.")
            return
        # All token rows of every example that has at least one wrong prediction.
        misclassified_samples = tokens.loc[misclassified_indices]

        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )
        st.dataframe(self._error_table(cm, context.labels))

        confusions = self._confusion_options(cm, context.labels)
        if not confusions:
            # Possible when the wrong labels/preds are not in context.labels:
            # confusion_matrix only counts the labels it is given.
            # An empty radio would return None and crash the lookup below.
            st.info("No class confusions among the configured labels.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=confusions,
            format_func=format_func,
        )
        (true_label, pred_label), _count = conf

        # Use a boolean mask rather than an f-string `query`. With `query`, labels
        # containing quotes break the expression and non-string labels never match.
        is_selected = (misclassified_samples.labels == true_label) & (
            misclassified_samples.preds == pred_label
        )
        # One example can show the same confusion on several tokens; render it once.
        for idx in misclassified_samples.index[is_selected.to_numpy()].unique():
            st.write(
                htmlify_labeled_example(tokens.loc[idx]),
                unsafe_allow_html=True,
            )
            st.write("---")
|
|