refactor: streamline dataset and model handling with helper classes
Changed files:
- app.py (+136, -75)
- config.py (+24, -59)
- parsing.py (+20, -25)
- requirements.txt (+2, -1)
app.py
CHANGED

@@ -1,18 +1,14 @@
 import gradio as gr
-import
-import random
+from typing import List, Tuple
 import plotly.express as px
 from huggingface_hub import snapshot_download
 import os
+import pdb
 import logging
+import pandas as pd

-from config import
-
-    LOCAL_RESULTS_DIR,
-    CITATION_BUTTON_TEXT,
-    CITATION_BUTTON_LABEL,
-)
-from parsing import read_all_configs, get_common_langs
+from config import LOCAL_RESULTS_DIR, CITATION_BUTTON_TEXT, DatasetHelper, ModelHelper
+from parsing import read_all_configs

 # Set up logging
 logging.basicConfig(

@@ -57,27 +53,30 @@ We are currently hiding the results of {', '.join(model_markups)} because they d
 """


-def build_components(show_common_langs):
-    aggregated_df,
-        show_common_langs
+def build_components(show_common_langs, selected_datasets: List[str]):
+    aggregated_df, lang_dfs, barplot_figs, models_with_nan = _populate_components(
+        show_common_langs, selected_datasets
     )
     models_with_nan_md = _build_models_with_nan_md(models_with_nan)

     return (
         gr.DataFrame(format_dataframe(aggregated_df)),
-        gr.DataFrame(format_dataframe(
-        gr.
+        gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True)),
+        gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True)),
+        gr.Plot(barplot_figs[0]),
+        gr.Plot(barplot_figs[1]),
         gr.Markdown(models_with_nan_md, visible=len(models_with_nan) > 0),
     )


-def _populate_components(
-
-
-
+def _populate_components(
+    show_common_langs: bool, selected_datasets: List[str], contrast_type: str = "F-M"
+) -> Tuple[pd.DataFrame, List[pd.DataFrame], List[px.bar], List[str]]:
+
+    results = read_all_configs(contrast_type)

     if show_common_langs:
-        common_langs = get_common_langs()
+        common_langs = model_h.get_common_langs()
         logger.info(f"Common langs: {common_langs}")
         results = results[results["Language"].isin(common_langs)]

@@ -96,64 +95,116 @@ def _populate_components(show_common_langs)
     logger.info(f"Models with NaN values: {models_with_nan}")
     results = results[~results["Model"].isin(models_with_nan)]

-
-
-
-
-        .reset_index()
-        .sort_values("Gap")
-    )
-    best_model = aggregated_df.iloc[0]["Model"]
-    top_3_models = aggregated_df["Model"].head(3).tolist()
-    # main_df = gr.DataFrame(format_dataframe(model_results))
-
-    lang_df = results.pivot_table(
-        index="Model",
-        values="Gap",
-        columns="Language",
-    ).reset_index()
-    # lang_df = gr.DataFrame(format_dataframe(lang_results, times_100=True))
-
-    # gr.Plot(fig1)
-    results["Gap"] = results["Gap"] * 100
-    barplot_fig = px.bar(
-        results.loc[results["Model"].isin(top_3_models)],
-        x="Language",
-        y="Gap",
-        color="Model",
-        title="Gaps by Language and Model (top 3, sorted by the best model)",
-        labels={
-            "Gap": "Sum of Absolute Gaps (%)",
-            "Language": "Language",
-            "Model": "Model",
-        },
-        barmode="group",
-    )
-    lang_order = (
-        lang_df.set_index("Model").loc[best_model].sort_values(ascending=False).index
-    )
-    logger.info(f"Lang order: {lang_order}")
+    type_dfs = list()
+    lang_dfs = list()
+    barplot_figs = list()
+    for type, type_df in results.groupby("Type"):

-
-
-
+        # Aggregate main
+        aggregated_df = type_df.pivot_table(
+            index="Model",
+            values="Gap",
+            aggfunc=lambda x: 100 * x.abs().sum(),
+        )
+        aggregated_df = aggregated_df.rename(columns={"Gap": f"Gap ({type})"})
+        type_dfs.append(aggregated_df)
+
+        best_model = aggregated_df.index[0]
+        top_3_models = aggregated_df.index[:3].tolist()
+
+        # Aggregate by language
+        lang_df = type_df.pivot_table(
+            index="Model",
+            values="Gap",
+            columns="Language",
+        ).reset_index()
+        lang_dfs.append(lang_df)
+
+        # Create plot
+        type_df["Gap"] = type_df["Gap"] * 100
+        barplot_fig = px.bar(
+            type_df.loc[results["Model"].isin(top_3_models)],
+            x="Language",
+            y="Gap",
+            color="Model",
+            title=f"{type}: Gaps by Language and Model (top 3, sorted by the best model)",
+            labels={
+                "Gap": f"{contrast_type} Gap (%)",
+                "Language": "Language",
+                "Model": "Model",
+            },
+            barmode="group",
+        )

-
+        lang_order = (
+            lang_df.set_index("Model")
+            .loc[best_model]
+            .sort_values(ascending=False)
+            .index
+        )
+        logger.info(f"Lang order: {lang_order}")

+        barplot_fig.update_layout(
+            xaxis={"categoryorder": "array", "categoryarray": lang_order}
+        )
+        barplot_figs.append(barplot_fig)
+
+    # pdb.set_trace()
+    aggregated_df = pd.concat(type_dfs, axis=1, join="inner")
+    aggregated_df["Avg"] = aggregated_df.mean(axis=1)
+    aggregated_df = aggregated_df.sort_values("Avg").reset_index()
+
+    # lang_df = results.pivot_table(
+    #     index="Model",
+    #     values="Gap",
+    #     columns="Language",
+    # ).reset_index()
+
+    # results["Gap"] = results["Gap"] * 100
+    # barplot_fig = px.bar(
+    #     results.loc[results["Model"].isin(top_3_models)],
+    #     x="Language",
+    #     y="Gap",
+    #     color="Model",
+    #     title="Gaps by Language and Model (top 3, sorted by the best model)",
+    #     labels={
+    #         "Gap": "Sum of Absolute Gaps (%)",
+    #         "Language": "Language",
+    #         "Model": "Model",
+    #     },
+    #     barmode="group",
+    # )
+    # lang_order = (
+    #     lang_df.set_index("Model").loc[best_model].sort_values(ascending=False).index
+    # )
+    # logger.info(f"Lang order: {lang_order}")
+
+    # barplot_fig.update_layout(
+    #     xaxis={"categoryorder": "array", "categoryarray": lang_order}
+    # )
+
+    return aggregated_df, lang_dfs, barplot_figs, models_with_nan
+
+
+dataset_h = DatasetHelper()
+model_h = ModelHelper()

 with gr.Blocks() as fm_interface:
-    aggregated_df,
-        show_common_langs=False
+    aggregated_df, lang_dfs, barplot_figs, model_with_nan = _populate_components(
+        show_common_langs=False, selected_datasets=dataset_h.get_dataset_names()
     )
     model_with_nans_md = gr.Markdown(_build_models_with_nan_md(model_with_nan))

     gr.Markdown("### Sum of Absolute Gaps ⬇️")
     aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))

-    gr.Markdown("####
-
+    gr.Markdown("#### Read: gaps by language")
+    lang_df_comp_0 = gr.DataFrame(format_dataframe(lang_dfs[0], times_100=True))
+    barplot_fig_comp_0 = gr.Plot(barplot_figs[0])

-
+    gr.Markdown("#### Spontaneous: gaps by language")
+    lang_df_comp_1 = gr.DataFrame(format_dataframe(lang_dfs[1], times_100=True))
+    barplot_fig_comp_1 = gr.Plot(barplot_figs[1])

 ###################
 # LIST MAIN TABS

@@ -179,6 +230,7 @@ banner = """
 # MAIN INTERFACE
 ###################
 with gr.Blocks() as demo:
+
     gr.HTML(banner)

     with gr.Row() as config_row:

@@ -186,31 +238,40 @@ with gr.Blocks() as demo:
             choices=["Show only common languages"],
             label="Main configuration",
         )
+
+        datasets_names = dataset_h.get_dataset_names()
         include_datasets = gr.CheckboxGroup(
-            choices=
+            choices=datasets_names,
             label="Include datasets",
-            value=
+            value=datasets_names,
             interactive=False,
         )

    show_common_langs.input(
        build_components,
-        inputs=[show_common_langs],
+        inputs=[show_common_langs, include_datasets],
        outputs=[
            aggregated_df_comp,
-
-
+            lang_df_comp_0,
+            lang_df_comp_1,
+            barplot_fig_comp_0,
+            barplot_fig_comp_1,
            model_with_nans_md,
        ],
    )

    gr.TabbedInterface(tabs, titles)

-    gr.
-
-
-
-
+    gr.Markdown(
+        """
+        ### Citation
+        If you find these results useful, please cite the following paper:
+        """
+    )
+
+    gr.Markdown(
+        f"""```
+        {CITATION_BUTTON_TEXT}"""
    )

if __name__ == "__main__":
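As an illustration of the per-speaking-condition aggregation that the reworked _populate_components performs, here is a minimal sketch on a toy DataFrame; the model names and gap values are invented for the example, and the loop variable is renamed to avoid shadowing the built-in `type`:

import pandas as pd

# Toy stand-in for the frame returned by read_all_configs(); all values are invented.
results = pd.DataFrame(
    {
        "Model": ["model-a", "model-a", "model-b", "model-b"] * 2,
        "Language": ["de", "fr"] * 4,
        "Type": ["Read"] * 4 + ["Spontaneous"] * 4,
        "Gap": [0.02, -0.01, 0.05, 0.03, 0.01, -0.02, 0.04, 0.02],
    }
)

type_dfs = []
for type_, type_df in results.groupby("Type"):
    # Sum of absolute gaps per model, scaled to percentage points (as in the diff).
    agg = type_df.pivot_table(
        index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
    )
    type_dfs.append(agg.rename(columns={"Gap": f"Gap ({type_})"}))

# One column per speaking condition, plus the average used for the final ranking.
aggregated = pd.concat(type_dfs, axis=1, join="inner")
aggregated["Avg"] = aggregated.mean(axis=1)
print(aggregated.sort_values("Avg").reset_index())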
config.py
CHANGED

@@ -4,78 +4,43 @@ to use for a particular datasetm or which language a model should be
 evaluated on.
 """

+from fair_asr_code.config import ALL_DATASET_CONFIGS, MODEL2LANG_SUPPORT
+
 LOCAL_RESULTS_DIR = "fair-asr-results"
 SETUPS = [{"majority_group": "male_masculine", "minority_group": "female_feminine"}]


-class
-
-
-
-    # fmt: off
-    langs = [
-        "de", "en", "nl",  # Germanic
-        "ru", "sr", "cs", "sk",  # Slavic
-        "it", "fr", "es", "ca", "pt", "ro",  # Romance
-        "sw",  # Bantu
-        "yo",  # Niger-Congo
-        "ja",  # Japonic
-        "hu", "fi",  # Uralic
-        "ar"  # Semitic
-    ]
-    # fmt: on
-
-
-dataset2info = {"cv_17": CVInfo}
-
+class DatasetHelper:
+    def __init__(self):
+        self.dataset_configs = ALL_DATASET_CONFIGS

-
-
-    langs = [
-        "de", "en", "nl",  # Germanic
-        "ru", "sr", "cs", "sk",  # Slavic
-        "it", "fr", "es", "ca", "pt", "ro",  # Romance
-        "sw",  # Bantu
-        "yo",  # Niger-Congo
-        "ja",  # Japonic
-        "hu", "fi",  # Uralic
-        "ar"  # Semitic
-    ]
-    # fmt: on
+    def get_dataset_names(self):
+        return [config.name for config in self.dataset_configs]

+    def get_dataset_ids(self):
+        return [config.dataset_id for config in self.dataset_configs]

-
-
-
-        "de", "en", "nl",  # Germanic
-        "ru", "sr", "cs", "sk",  # Slavic
-        "it", "fr", "es", "ca", "pt", "ro",  # Romance
-        "sw",  # Bantu
-        "yo",  # Niger-Congo
-        "ja",  # Japonic
-        "hu", "fi",  # Uralic
-        "ar"  # Semitic
-    ]
-    # fmt: on
+    @property
+    def sanitized_dataset_ids(self):
+        return [config.sanitized_id() for config in self.dataset_configs]


-class
-
-
-        "en", "es", "de", "fr",
-    ]
-    # fmt: on
+class ModelHelper:
+    def __init__(self):
+        self.models = list(MODEL2LANG_SUPPORT.keys())

+    @property
+    def sanitized_model_ids(self):
+        return [model.replace("/", "--") for model in self.models]

-
-
-
-
-
-    }
+    def get_common_langs(self):
+        common_langs = set(MODEL2LANG_SUPPORT[self.models[0]])
+        for model in self.models:
+            common_langs = common_langs.intersection(set(MODEL2LANG_SUPPORT[model]))
+        return list(common_langs)


-CITATION_BUTTON_LABEL = "Please use this bibtex to cite these results"
+# CITATION_BUTTON_LABEL = "Please use this bibtex to cite these results"
 CITATION_BUTTON_TEXT = r"""@inproceedings{attanasio-etal-2024-twists,
     title = "Twists, Humps, and Pebbles: Multilingual Speech Recognition Models Exhibit Gender Performance Gaps",
     author = "Attanasio, Giuseppe and
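A minimal usage sketch of the new helpers; it assumes fair-asr-code is installed as pinned in requirements.txt, and the printed values depend on its ALL_DATASET_CONFIGS and MODEL2LANG_SUPPORT:

from config import DatasetHelper, ModelHelper

dataset_h = DatasetHelper()
model_h = ModelHelper()

# Dataset names drive the "Include datasets" checkbox group in app.py.
print(dataset_h.get_dataset_names())
print(dataset_h.sanitized_dataset_ids)  # property, no call needed

# Languages supported by every model, used when "Show only common languages" is ticked.
print(model_h.get_common_langs())
print(model_h.sanitized_model_ids)  # model ids with "/" replaced by "--"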
parsing.py
CHANGED

@@ -2,14 +2,14 @@ import pandas as pd
 from typing import List
 from os.path import join as opj
 import json
-from config import dataset2info, model2info, LOCAL_RESULTS_DIR
 import logging
+from config import DatasetHelper, ModelHelper, LOCAL_RESULTS_DIR

 logger = logging.getLogger(__name__)


 def load_language_results(
-    model_id: str, dataset_id: str, lang_ids: List[str],
+    model_id: str, dataset_id: str, lang_ids: List[str], contrast_string: str
 ):
     lang_gaps = dict()
     for lang in lang_ids:

@@ -20,7 +20,7 @@ def load_language_results(
                 LOCAL_RESULTS_DIR,
                 "evaluation",
                 dataset_id,
-                f"results_{model_id}_{dataset_id}_devtest_{lang}_gender_{
+                f"results_{model_id}_{dataset_id}_devtest_{lang}_gender_{contrast_string}.json",
             )
         ) as fp:
             data = json.load(fp)

@@ -34,26 +34,33 @@ def load_language_results(
     return lang_gaps


-def read_all_configs(
-
-
-    print("Parsing results datasets:", all_datasets)
-    all_models = model2info.keys()
-    print("Parsing results models:", all_models)
+def read_all_configs(contrast_type: str):
+    dataset_h = DatasetHelper()
+    model_h = ModelHelper()

     rows = list()
-    for
-        for model_id in
+    for dataset_config in dataset_h.dataset_configs:
+        for model_id in model_h.sanitized_model_ids:
+
+            contrast_info = dataset_config.group_contrasts[contrast_type]
+            contrast_string = (
+                f"{contrast_info['majority_group']}_{contrast_info['minority_group']}"
+            )
+
             lang_gaps = load_language_results(
-                model_id,
+                model_id,
+                dataset_config.sanitized_id(),
+                dataset_config.langs,
+                contrast_string,
             )

             rows.extend(
                 [
                     {
                         "Model": model_id,
-                        "Dataset":
+                        "Dataset": dataset_config.sanitized_id(),
                         "Language": lang,
+                        "Type": dataset_config.speaking_condition.capitalize(),
                         "Gap": lang_gaps[lang],
                     }
                     for lang in lang_gaps

@@ -61,16 +68,4 @@ def read_all_configs(setup: str):
             )

     results_df = pd.DataFrame(rows)
-    # results_df = results_df.drop(columns=["Dataset"])
-    # results_df = results_df.sort_values(by="Mean Gap", ascending=True)
-
     return results_df
-
-
-def get_common_langs():
-    """Return a list of langs that are support by all models"""
-    common_langs = set(model2info[list(model2info.keys())[0]].langs)
-    for model_id in model2info.keys():
-        common_langs = common_langs.intersection(model2info[model_id].langs)
-
-    return list(common_langs)
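For reference, a small sketch of how read_all_configs and load_language_results now assemble the result file path from a dataset config's group contrast; the model and dataset ids below are hypothetical placeholders, while the contrast values mirror SETUPS in config.py:

from os.path import join as opj

# Stand-in for one entry of dataset_config.group_contrasts.
contrast_info = {"majority_group": "male_masculine", "minority_group": "female_feminine"}
contrast_string = f"{contrast_info['majority_group']}_{contrast_info['minority_group']}"

model_id = "some--model"     # hypothetical sanitized model id
dataset_id = "some_dataset"  # hypothetical sanitized dataset id
lang = "de"

# Same composition as in load_language_results, with LOCAL_RESULTS_DIR inlined.
path = opj(
    "fair-asr-results",
    "evaluation",
    dataset_id,
    f"results_{model_id}_{dataset_id}_devtest_{lang}_gender_{contrast_string}.json",
)
print(path)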
requirements.txt
CHANGED

@@ -1,3 +1,4 @@
 gradio
 plotly
-pandas
+pandas
+-e git+https://github.com/g8a9/fair-asr-code#egg=fair-asr-code