import logging
import os
from typing import List, Tuple

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from huggingface_hub import snapshot_download

from config import LOCAL_RESULTS_DIR, CITATION_BUTTON_TEXT, DatasetHelper, ModelHelper
from parsing import read_all_configs
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        # logging.FileHandler("app.log"),
        logging.StreamHandler()
    ],
)
logger = logging.getLogger(__name__)
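
# Evaluation results are synced from the Hugging Face dataset repo
# "g8a9/fair-asr-results" into LOCAL_RESULTS_DIR; per-sample files and raw
# transcripts are excluded via ignore_patterns to keep the download small.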
try:
    print("Saving results locally at:", LOCAL_RESULTS_DIR)
    snapshot_download(
        repo_id="g8a9/fair-asr-results",
        local_dir=LOCAL_RESULTS_DIR,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        ignore_patterns=["*samples*", "*transcripts*"],
        token=os.environ.get("TOKEN"),
    )
except Exception:
    logger.exception("Failed to download results from the Hub")
    raise


def format_dataframe(df, times_100=False):
    """Format numeric cells for display: percentages (3 decimals) or plain 4-decimal floats."""
    if times_100:
        df = df.map(lambda x: f"{x * 100:.3f}%" if isinstance(x, (int, float)) else x)
    else:
        df = df.map(lambda x: f"{x:.4f}" if isinstance(x, (int, float)) else x)
    return df
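
# Illustrative note (assumed values, not taken from the app): with times_100=True a
# raw gap of 0.01234 renders as "1.234%"; with the default branch it renders as "0.0123".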


def _build_models_with_nan_md(models_with_nan):
    """Build the markdown note listing models hidden because of missing language coverage."""
    model_markups = [f"*{m}*" for m in models_with_nan]
    return f"""
We are currently hiding the results of {', '.join(model_markups)} because they don't support all languages.
"""


def build_components(show_common_langs, selected_datasets: List[str]):
    """Recompute all leaderboard components when the configuration changes."""
    aggregated_df, lang_dfs, barplot_figs, models_with_nan = _populate_components(
        show_common_langs, selected_datasets
    )
    models_with_nan_md = _build_models_with_nan_md(models_with_nan)
    return (
        gr.DataFrame(format_dataframe(aggregated_df)),
        gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True)),
        gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True)),
        gr.Plot(barplot_figs[0]),
        gr.Plot(barplot_figs[1]),
        gr.Markdown(models_with_nan_md, visible=len(models_with_nan) > 0),
    )
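
# Note: build_components returns its components in the same order as the `outputs`
# list wired to `show_common_langs.input(...)` in the main interface below; keep
# the two in sync.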


def _populate_components(
    show_common_langs: bool, selected_datasets: List[str], contrast_type: str = "F-M"
) -> Tuple[pd.DataFrame, List[pd.DataFrame], List[go.Figure], List[str]]:
    """Load results and build the aggregated table, per-language tables, and bar plots.

    Note: `selected_datasets` is currently unused; the dataset checkboxes in the UI
    are not interactive.
    """
    results = read_all_configs(contrast_type)

    if show_common_langs:
        common_langs = model_h.get_common_langs()
        logger.info(f"Common langs: {common_langs}")
        results = results[results["Language"].isin(common_langs)]

    missing_langs = (
        results[results.isna().any(axis=1)]
        .groupby("Model")["Language"]
        .apply(list)
        .to_dict()
    )
    for model, langs in missing_langs.items():
        logger.info(
            f"Model {model} is missing results for languages: {', '.join(langs)}"
        )

    # Hide models that have missing (NaN) results for at least one language
    models_with_nan = results[results.isna().any(axis=1)]["Model"].unique().tolist()
    logger.info(f"Models with NaN values: {models_with_nan}")
    results = results[~results["Model"].isin(models_with_nan)]

    type_dfs = []
    lang_dfs = []
    barplot_figs = []
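
    # For each speech type (the UI headers below suggest "Read" and "Spontaneous"),
    # build an aggregated gap table, a per-language gap table, and a grouped bar
    # plot restricted to the top-3 models.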
    for speech_type, type_df in results.groupby("Type"):
        # Aggregate main: sum of absolute gaps (in %) per model
        aggregated_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            aggfunc=lambda x: 100 * x.abs().sum(),
        )
        aggregated_df = aggregated_df.rename(columns={"Gap": f"Gap ({speech_type})"})
        # Sort ascending so index[0] is the model with the smallest summed gap
        aggregated_df = aggregated_df.sort_values(f"Gap ({speech_type})")
        type_dfs.append(aggregated_df)

        best_model = aggregated_df.index[0]
        top_3_models = aggregated_df.index[:3].tolist()

        # Aggregate by language
        lang_df = type_df.pivot_table(
            index="Model",
            values="Gap",
            columns="Language",
        ).reset_index()
        lang_dfs.append(lang_df)

        # Create plot (work on a copy to avoid mutating the groupby slice)
        type_df = type_df.copy()
        type_df["Gap"] = type_df["Gap"] * 100
        barplot_fig = px.bar(
            type_df.loc[type_df["Model"].isin(top_3_models)],
            x="Language",
            y="Gap",
            color="Model",
            title=f"{speech_type}: Gaps by Language and Model (top 3, sorted by the best model)",
            labels={
                "Gap": f"{contrast_type} Gap (%)",
                "Language": "Language",
                "Model": "Model",
            },
            barmode="group",
        )
        lang_order = (
            lang_df.set_index("Model")
            .loc[best_model]
            .sort_values(ascending=False)
            .index
        )
        logger.info(f"Lang order: {lang_order}")
        barplot_fig.update_layout(
            xaxis={"categoryorder": "array", "categoryarray": lang_order}
        )
        barplot_figs.append(barplot_fig)
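
    # Combine the per-type tables into a single leaderboard: keep models present
    # for every type (inner join) and rank by the mean gap across types, sorted
    # ascending so smaller average gaps come first.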
    aggregated_df = pd.concat(type_dfs, axis=1, join="inner")
    aggregated_df["Avg"] = aggregated_df.mean(axis=1)
    aggregated_df = aggregated_df.sort_values("Avg").reset_index()
    return aggregated_df, lang_dfs, barplot_figs, models_with_nan
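

# Helper objects defined in config; assumed to expose the dataset names and the
# set of languages covered by every model (used for the "common languages" filter).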
dataset_h = DatasetHelper()
model_h = ModelHelper()
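
# Build the leaderboard tab for the F-M contrast (the "F-M Setup" tab) once at
# import time, with all languages and all datasets included by default.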
with gr.Blocks() as fm_interface:
    aggregated_df, lang_dfs, barplot_figs, models_with_nan = _populate_components(
        show_common_langs=False, selected_datasets=dataset_h.get_dataset_names()
    )
    models_with_nan_md = gr.Markdown(
        _build_models_with_nan_md(models_with_nan),
        visible=len(models_with_nan) > 0,
    )

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))

    gr.Markdown("#### Read: gaps by language")
    lang_df_comp_0 = gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True))
    barplot_fig_comp_0 = gr.Plot(barplot_figs[0])

    gr.Markdown("#### Spontaneous: gaps by language")
    lang_df_comp_1 = gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True))
    barplot_fig_comp_1 = gr.Plot(barplot_figs[1])

###################
# LIST MAIN TABS
###################
tabs = [fm_interface]
titles = ["F-M Setup"]

banner = """
<style>
.full-width-image {
    width: 100%;
    height: auto;
    margin: 0;
    padding: 0;
}
</style>
<div>
    <img src="https://huggingface.co./spaces/g8a9/fair-asr-leaderboard/raw/main/twists_banner.png" alt="Twists Banner" class="full-width-image">
</div>
"""

###################
# MAIN INTERFACE
###################
with gr.Blocks() as demo:
    gr.HTML(banner)

    with gr.Row() as config_row:
        show_common_langs = gr.CheckboxGroup(
            choices=["Show only common languages"],
            label="Main configuration",
        )
        datasets_names = dataset_h.get_dataset_names()
        include_datasets = gr.CheckboxGroup(
            choices=datasets_names,
            label="Include datasets",
            value=datasets_names,
            interactive=False,
        )
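
    # Refresh every table and plot whenever the "common languages" checkbox
    # changes; the outputs below must match the tuple order returned by
    # build_components.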
    show_common_langs.input(
        build_components,
        inputs=[show_common_langs, include_datasets],
        outputs=[
            aggregated_df_comp,
            lang_df_comp_0,
            lang_df_comp_1,
            barplot_fig_comp_0,
            barplot_fig_comp_1,
            models_with_nan_md,
        ],
    )

    gr.TabbedInterface(tabs, titles)

    gr.Markdown(
        """
### Citation
If you find these results useful, please cite the following paper:
"""
    )
    gr.Markdown(
        f"""```
{CITATION_BUTTON_TEXT}
```"""
    )


if __name__ == "__main__":
    demo.launch()