# g8a9's picture
# refactor: streamline dataset and model handling with helper classes
# fc63ec6
import gradio as gr
from typing import List, Tuple
import plotly.express as px
from huggingface_hub import snapshot_download
import os
import pdb
import logging
import pandas as pd
from config import LOCAL_RESULTS_DIR, CITATION_BUTTON_TEXT, DatasetHelper, ModelHelper
from parsing import read_all_configs
# Logging configuration: timestamped records at INFO level and above are
# sent to a console stream handler.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        # File logging is currently disabled; add
        # logging.FileHandler("app.log") here to also write records to disk.
        logging.StreamHandler(),
    ],
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Download the published results dataset into LOCAL_RESULTS_DIR when the
# module is imported. Files matching "*samples*" or "*transcripts*" are
# skipped, and the access token is read from the TOKEN environment variable
# (None when it is not set).
print("Saving results locally at:", LOCAL_RESULTS_DIR)
try:
    snapshot_download(
        repo_id="g8a9/fair-asr-results",
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    # The app cannot work without the results, so the failure still
    # propagates; it is logged first so the cause shows up in the app logs.
    # A bare `raise` keeps the original traceback intact.
    logger.exception(
        "Could not download the results dataset into %s", LOCAL_RESULTS_DIR
    )
    raise
def format_dataframe(df, times_100=False):
if times_100:
df = df.map(lambda x: (f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x))
else:
df = df.map(lambda x: (f"{x:.4f}" if isinstance(x, (int, float)) else x))
return df
def _build_models_with_nan_md(models_with_nan):
model_markups = [f"*{m}*" for m in models_with_nan]
return f"""
We are currently hiding the results of {', '.join(model_markups)} because they don't support all languages.
"""
def build_components(show_common_langs, selected_datasets: List[str]):
    """Recompute the tables, plots and notice for the current settings.

    Args:
        show_common_langs: Forwarded unchanged to ``_populate_components``.
        selected_datasets: Forwarded unchanged to ``_populate_components``.

    Returns:
        A 6-tuple of Gradio components, in this order: the aggregated table,
        the two per-language tables (formatted as percentages), the two bar
        plots, and the hidden-models Markdown notice. The notice is visible
        only when at least one model is hidden.
    """
    summary_df, per_lang_dfs, figures, hidden_models = _populate_components(
        show_common_langs, selected_datasets
    )
    notice_text = _build_models_with_nan_md(hidden_models)

    summary_table = gr.DataFrame(format_dataframe(summary_df))
    first_lang_table = gr.DataFrame(
        format_dataframe(per_lang_dfs[0], times_100=True)
    )
    second_lang_table = gr.DataFrame(
        format_dataframe(per_lang_dfs[1], times_100=True)
    )
    first_plot = gr.Plot(figures[0])
    second_plot = gr.Plot(figures[1])
    notice = gr.Markdown(notice_text, visible=len(hidden_models) > 0)

    return (
        summary_table,
        first_lang_table,
        second_lang_table,
        first_plot,
        second_plot,
        notice,
    )
def _populate_components(
    show_common_langs: bool, selected_datasets: List[str], contrast_type: str = "F-M"
) -> Tuple[pd.DataFrame, List[pd.DataFrame], List[px.bar], List[str]]:
    """Compute the leaderboard tables and bar plots for one contrast type.

    Args:
        show_common_langs: Any truthy value restricts the results to the
            languages returned by ``model_h.get_common_langs()``.
        selected_datasets: Currently unused; kept for interface
            compatibility with the callers.
        contrast_type: Passed to ``read_all_configs`` and used in the plot's
            y-axis label.

    Returns:
        A 4-tuple ``(aggregated_df, lang_dfs, barplot_figs, models_with_nan)``:

        * ``aggregated_df``: one row per model, one ``Gap (<type>)`` column
          per value of the ``Type`` column (sum of absolute gaps times 100),
          plus an ``Avg`` column; sorted by ``Avg`` ascending.
        * ``lang_dfs``: one model-by-language table of mean gaps per type,
          in ``groupby("Type")`` order.
        * ``barplot_figs``: one grouped bar figure per type, same order.
          NOTE(review): the ``List[px.bar]`` annotation names the plotting
          function, not the figure class it returns.
        * ``models_with_nan``: models dropped because at least one of their
          rows contains a NaN.

    Raises:
        ValueError: Raised by ``pd.concat`` when no rows are left after
            filtering, so there is nothing to concatenate.
    """
    results = read_all_configs(contrast_type)

    if show_common_langs:
        common_langs = model_h.get_common_langs()
        logger.info("Common langs: %s", common_langs)
        results = results[results["Language"].isin(common_langs)]

    # Rows with any NaN identify (model, language) pairs lacking a value.
    incomplete_rows = results[results.isna().any(axis=1)]
    missing_langs = (
        incomplete_rows.groupby("Model")["Language"].apply(list).to_dict()
    )
    for model, langs in missing_langs.items():
        logger.info(
            "Model %s is missing results for languages: %s",
            model,
            ", ".join(langs),
        )

    # A model with any incomplete row is removed entirely.
    models_with_nan = incomplete_rows["Model"].unique().tolist()
    logger.info("Models with NaN values: %s", models_with_nan)
    results = results[~results["Model"].isin(models_with_nan)]

    type_dfs = []
    lang_dfs = []
    barplot_figs = []

    for type_name, type_df in results.groupby("Type"):
        # Aggregate main: sum of absolute gaps per model, as a percentage.
        gap_column = f"Gap ({type_name})"
        aggregated_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            aggfunc=lambda x: 100 * x.abs().sum(),
        ).rename(columns={"Gap": gap_column})
        type_dfs.append(aggregated_df)

        # The pivot table is ordered by model name, so rank explicitly by the
        # aggregated gap (ascending, consistent with the final sort on "Avg")
        # before picking the best model and the top three.
        ranked_models = aggregated_df[gap_column].sort_values(kind="stable").index
        best_model = ranked_models[0]
        top_3_models = ranked_models[:3].tolist()

        # Aggregate by language (values stay as fractions here; callers
        # format them as percentages).
        lang_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            columns="Language",
        ).reset_index()
        lang_dfs.append(lang_df)

        # Create plot from a filtered, rescaled copy so the group's own data
        # is never modified in place.
        plot_df = type_df[type_df["Model"].isin(top_3_models)].assign(
            Gap=lambda frame: frame["Gap"] * 100
        )
        barplot_fig = px.bar(
            plot_df,
            x="Language",
            y="Gap",
            color="Model",
            title=f"{type_name}: Gaps by Language and Model (top 3, sorted by the best model)",
            labels={
                "Gap": f"{contrast_type} Gap (%)",
                "Language": "Language",
                "Model": "Model",
            },
            barmode="group",
        )

        # Order the x axis by the best model's per-language gap, descending.
        lang_order = (
            lang_df.set_index("Model")
            .loc[best_model]
            .sort_values(ascending=False)
            .index.tolist()
        )
        logger.info("Lang order: %s", lang_order)
        barplot_fig.update_layout(
            xaxis={"categoryorder": "array", "categoryarray": lang_order}
        )
        barplot_figs.append(barplot_fig)

    # Keep only models present for every type, then rank by the mean gap.
    aggregated_df = pd.concat(type_dfs, axis=1, join="inner")
    aggregated_df["Avg"] = aggregated_df.mean(axis=1)
    aggregated_df = aggregated_df.sort_values("Avg").reset_index()

    return aggregated_df, lang_dfs, barplot_figs, models_with_nan
# Helper objects shared by the functions and interfaces in this module.
dataset_h = DatasetHelper()
model_h = ModelHelper()

# F-M tab: rendered once at import time with every language and dataset
# included. The components are bound to names so the event handler wired up
# in the main interface can update them.
with gr.Blocks() as fm_interface:
    aggregated_df, lang_dfs, barplot_figs, model_with_nan = _populate_components(
        show_common_langs=False, selected_datasets=dataset_h.get_dataset_names()
    )

    # Show the notice only when some model is actually hidden, matching what
    # build_components does on later updates; otherwise the sentence would
    # render with an empty list of models.
    model_with_nans_md = gr.Markdown(
        _build_models_with_nan_md(model_with_nan),
        visible=len(model_with_nan) > 0,
    )

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))

    # lang_dfs / barplot_figs hold one entry per "Type" value in groupby
    # order. The headings assume index 0 is "Read" and index 1 is
    # "Spontaneous" -- TODO confirm against the values in the results.
    gr.Markdown("#### Read: gaps by language")
    lang_df_comp_0 = gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True))
    barplot_fig_comp_0 = gr.Plot(barplot_figs[0])

    gr.Markdown("#### Spontaneous: gaps by language")
    lang_df_comp_1 = gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True))
    barplot_fig_comp_1 = gr.Plot(barplot_figs[1])
###################
# LIST MAIN TABS
###################
# Parallel lists: titles[i] is the label of tabs[i]. Both are handed to
# gr.TabbedInterface in the main interface.
titles = ["F-M Setup"]
tabs = [fm_interface]

# HTML snippet shown at the top of the page: a stylesheet rule plus an image
# that spans the full width of its container.
# NOTE(review): the image host is written as "huggingface.co." with a
# trailing dot -- confirm this is intended.
banner = """
<style>
.full-width-image {
width: 100%;
height: auto;
margin: 0;
padding: 0;
}
</style>
<div>
<img src="https://huggingface.co./spaces/g8a9/fair-asr-leaderboard/raw/main/twists_banner.png" alt="Twists Banner" class="full-width-image">
</div>
"""
###################
# MAIN INTERFACE
###################
# Page layout, top to bottom: banner, configuration row, tabbed results,
# citation text.
with gr.Blocks() as demo:
    gr.HTML(banner)

    # NOTE(review): the two checkbox groups are assumed to sit inside this
    # row -- confirm the nesting against the deployed layout.
    with gr.Row() as config_row:
        show_common_langs = gr.CheckboxGroup(
            label="Main configuration",
            choices=["Show only common languages"],
        )

        # Every dataset is pre-selected and the group is read-only.
        datasets_names = dataset_h.get_dataset_names()
        include_datasets = gr.CheckboxGroup(
            label="Include datasets",
            choices=datasets_names,
            value=datasets_names,
            interactive=False,
        )

    # When the user toggles the checkbox, rebuild every component of the
    # F-M tab. The output order matches the tuple returned by
    # build_components.
    show_common_langs.input(
        build_components,
        inputs=[show_common_langs, include_datasets],
        outputs=[
            aggregated_df_comp,
            lang_df_comp_0,
            lang_df_comp_1,
            barplot_fig_comp_0,
            barplot_fig_comp_1,
            model_with_nans_md,
        ],
    )

    gr.TabbedInterface(tabs, titles)

    gr.Markdown(
        "\n"
        "### Citation\n"
        "If you find these results useful, please cite the following paper:\n"
    )
    # The citation is placed after an opening code fence so it renders as
    # preformatted text.
    gr.Markdown(f"```\n{CITATION_BUTTON_TEXT}")

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()