g8a9 committed on
Commit
fc63ec6
·
1 Parent(s): 9a7f023

refactor: streamline dataset and model handling with helper classes

Browse files
Files changed (4) hide show
  1. app.py +136 -75
  2. config.py +24 -59
  3. parsing.py +20 -25
  4. requirements.txt +2 -1
app.py CHANGED
@@ -1,18 +1,14 @@
1
  import gradio as gr
2
- import pandas as pd
3
- import random
4
  import plotly.express as px
5
  from huggingface_hub import snapshot_download
6
  import os
 
7
  import logging
 
8
 
9
- from config import (
10
- SETUPS,
11
- LOCAL_RESULTS_DIR,
12
- CITATION_BUTTON_TEXT,
13
- CITATION_BUTTON_LABEL,
14
- )
15
- from parsing import read_all_configs, get_common_langs
16
 
17
  # Set up logging
18
  logging.basicConfig(
@@ -57,27 +53,30 @@ We are currently hiding the results of {', '.join(model_markups)} because they d
57
  """
58
 
59
 
60
- def build_components(show_common_langs):
61
- aggregated_df, lang_df, barplot_fig, models_with_nan = _populate_components(
62
- show_common_langs
63
  )
64
  models_with_nan_md = _build_models_with_nan_md(models_with_nan)
65
 
66
  return (
67
  gr.DataFrame(format_dataframe(aggregated_df)),
68
- gr.DataFrame(format_dataframe(lang_df, times_100=True)),
69
- gr.Plot(barplot_fig),
 
 
70
  gr.Markdown(models_with_nan_md, visible=len(models_with_nan) > 0),
71
  )
72
 
73
 
74
- def _populate_components(show_common_langs):
75
- fm = SETUPS[0]
76
- setup = fm["majority_group"] + "_" + fm["minority_group"]
77
- results = read_all_configs(setup)
 
78
 
79
  if show_common_langs:
80
- common_langs = get_common_langs()
81
  logger.info(f"Common langs: {common_langs}")
82
  results = results[results["Language"].isin(common_langs)]
83
 
@@ -96,64 +95,116 @@ def _populate_components(show_common_langs):
96
  logger.info(f"Models with NaN values: {models_with_nan}")
97
  results = results[~results["Model"].isin(models_with_nan)]
98
 
99
- aggregated_df = (
100
- results.pivot_table(
101
- index="Model", values="Gap", aggfunc=lambda x: 100 * x.abs().sum()
102
- )
103
- .reset_index()
104
- .sort_values("Gap")
105
- )
106
- best_model = aggregated_df.iloc[0]["Model"]
107
- top_3_models = aggregated_df["Model"].head(3).tolist()
108
- # main_df = gr.DataFrame(format_dataframe(model_results))
109
-
110
- lang_df = results.pivot_table(
111
- index="Model",
112
- values="Gap",
113
- columns="Language",
114
- ).reset_index()
115
- # lang_df = gr.DataFrame(format_dataframe(lang_results, times_100=True))
116
-
117
- # gr.Plot(fig1)
118
- results["Gap"] = results["Gap"] * 100
119
- barplot_fig = px.bar(
120
- results.loc[results["Model"].isin(top_3_models)],
121
- x="Language",
122
- y="Gap",
123
- color="Model",
124
- title="Gaps by Language and Model (top 3, sorted by the best model)",
125
- labels={
126
- "Gap": "Sum of Absolute Gaps (%)",
127
- "Language": "Language",
128
- "Model": "Model",
129
- },
130
- barmode="group",
131
- )
132
- lang_order = (
133
- lang_df.set_index("Model").loc[best_model].sort_values(ascending=False).index
134
- )
135
- logger.info(f"Lang order: {lang_order}")
136
 
137
- barplot_fig.update_layout(
138
- xaxis={"categoryorder": "array", "categoryarray": lang_order}
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- return aggregated_df, lang_df, barplot_fig, models_with_nan
 
 
 
 
 
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  with gr.Blocks() as fm_interface:
145
- aggregated_df, lang_df, barplot_fig, model_with_nan = _populate_components(
146
- show_common_langs=False
147
  )
148
  model_with_nans_md = gr.Markdown(_build_models_with_nan_md(model_with_nan))
149
 
150
  gr.Markdown("### Sum of Absolute Gaps ⬇️")
151
  aggregated_df_comp = gr.DataFrame(format_dataframe(aggregated_df))
152
 
153
- gr.Markdown("#### F-M gaps by language")
154
- lang_df_comp = gr.DataFrame(format_dataframe(lang_df, times_100=True))
 
155
 
156
- barplot_fig_comp = gr.Plot(barplot_fig)
 
 
157
 
158
  ###################
159
  # LIST MAIN TABS
@@ -179,6 +230,7 @@ banner = """
179
  # MAIN INTERFACE
180
  ###################
181
  with gr.Blocks() as demo:
 
182
  gr.HTML(banner)
183
 
184
  with gr.Row() as config_row:
@@ -186,31 +238,40 @@ with gr.Blocks() as demo:
186
  choices=["Show only common languages"],
187
  label="Main configuration",
188
  )
 
 
189
  include_datasets = gr.CheckboxGroup(
190
- choices=["Mozilla CV 17"],
191
  label="Include datasets",
192
- value=["Mozilla CV 17"],
193
  interactive=False,
194
  )
195
 
196
  show_common_langs.input(
197
  build_components,
198
- inputs=[show_common_langs],
199
  outputs=[
200
  aggregated_df_comp,
201
- lang_df_comp,
202
- barplot_fig_comp,
 
 
203
  model_with_nans_md,
204
  ],
205
  )
206
 
207
  gr.TabbedInterface(tabs, titles)
208
 
209
- gr.Textbox(
210
- value=CITATION_BUTTON_TEXT,
211
- label=CITATION_BUTTON_LABEL,
212
- max_lines=6,
213
- show_copy_button=True,
 
 
 
 
 
214
  )
215
 
216
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from typing import List, Tuple
 
3
  import plotly.express as px
4
  from huggingface_hub import snapshot_download
5
  import os
6
+ import pdb
7
  import logging
8
+ import pandas as pd
9
 
10
+ from config import LOCAL_RESULTS_DIR, CITATION_BUTTON_TEXT, DatasetHelper, ModelHelper
11
+ from parsing import read_all_configs
 
 
 
 
 
12
 
13
  # Set up logging
14
  logging.basicConfig(
 
53
  """
54
 
55
 
56
def build_components(show_common_langs, selected_datasets: List[str]):
    """Rebuild every result component for the current UI configuration.

    Parameters
    ----------
    show_common_langs :
        Truthy when only languages supported by all models should be shown.
    selected_datasets : List[str]
        Dataset names currently ticked in the UI checkbox group.

    Returns
    -------
    tuple
        Gradio components in the exact order expected by the ``outputs=``
        list wired to the ``show_common_langs.input`` event.
    """
    agg_df, per_lang_dfs, figs, nan_models = _populate_components(
        show_common_langs, selected_datasets
    )
    nan_note = _build_models_with_nan_md(nan_models)

    # NOTE(review): the [0]/[1] indexing assumes exactly two speaking-condition
    # groups in the results (e.g. "Read"/"Spontaneous") — confirm upstream.
    return (
        gr.DataFrame(format_dataframe(agg_df)),
        gr.DataFrame(format_dataframe(per_lang_dfs[0], times_100=True)),
        gr.DataFrame(format_dataframe(per_lang_dfs[1], times_100=True)),
        gr.Plot(figs[0]),
        gr.Plot(figs[1]),
        gr.Markdown(nan_note, visible=len(nan_models) > 0),
    )
70
 
71
 
72
+ def _populate_components(
73
+ show_common_langs: bool, selected_datasets: List[str], contrast_type: str = "F-M"
74
+ ) -> Tuple[pd.DataFrame, List[pd.DataFrame], List[px.bar], List[str]]:
75
+
76
+ results = read_all_configs(contrast_type)
77
 
78
  if show_common_langs:
79
+ common_langs = model_h.get_common_langs()
80
  logger.info(f"Common langs: {common_langs}")
81
  results = results[results["Language"].isin(common_langs)]
82
 
 
95
  logger.info(f"Models with NaN values: {models_with_nan}")
96
  results = results[~results["Model"].isin(models_with_nan)]
97
 
98
+ type_dfs = list()
99
+ lang_dfs = list()
100
+ barplot_figs = list()
101
+ for type, type_df in results.groupby("Type"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ # Aggregate main
104
+ aggregated_df = type_df.pivot_table(
105
+ index="Model",
106
+ values="Gap",
107
+ aggfunc=lambda x: 100 * x.abs().sum(),
108
+ )
109
+ aggregated_df = aggregated_df.rename(columns={"Gap": f"Gap ({type})"})
110
+ type_dfs.append(aggregated_df)
111
+
112
+ best_model = aggregated_df.index[0]
113
+ top_3_models = aggregated_df.index[:3].tolist()
114
+
115
+ # Aggregate by language
116
+ lang_df = type_df.pivot_table(
117
+ index="Model",
118
+ values="Gap",
119
+ columns="Language",
120
+ ).reset_index()
121
+ lang_dfs.append(lang_df)
122
+
123
+ # Create plot
124
+ type_df["Gap"] = type_df["Gap"] * 100
125
+ barplot_fig = px.bar(
126
+ type_df.loc[results["Model"].isin(top_3_models)],
127
+ x="Language",
128
+ y="Gap",
129
+ color="Model",
130
+ title=f"{type}: Gaps by Language and Model (top 3, sorted by the best model)",
131
+ labels={
132
+ "Gap": f"{contrast_type} Gap (%)",
133
+ "Language": "Language",
134
+ "Model": "Model",
135
+ },
136
+ barmode="group",
137
+ )
138
 
139
+ lang_order = (
140
+ lang_df.set_index("Model")
141
+ .loc[best_model]
142
+ .sort_values(ascending=False)
143
+ .index
144
+ )
145
+ logger.info(f"Lang order: {lang_order}")
146
 
147
+ barplot_fig.update_layout(
148
+ xaxis={"categoryorder": "array", "categoryarray": lang_order}
149
+ )
150
+ barplot_figs.append(barplot_fig)
151
+
152
+ # pdb.set_trace()
153
+ aggregated_df = pd.concat(type_dfs, axis=1, join="inner")
154
+ aggregated_df["Avg"] = aggregated_df.mean(axis=1)
155
+ aggregated_df = aggregated_df.sort_values("Avg").reset_index()
156
+
157
+ # lang_df = results.pivot_table(
158
+ # index="Model",
159
+ # values="Gap",
160
+ # columns="Language",
161
+ # ).reset_index()
162
+
163
+ # results["Gap"] = results["Gap"] * 100
164
+ # barplot_fig = px.bar(
165
+ # results.loc[results["Model"].isin(top_3_models)],
166
+ # x="Language",
167
+ # y="Gap",
168
+ # color="Model",
169
+ # title="Gaps by Language and Model (top 3, sorted by the best model)",
170
+ # labels={
171
+ # "Gap": "Sum of Absolute Gaps (%)",
172
+ # "Language": "Language",
173
+ # "Model": "Model",
174
+ # },
175
+ # barmode="group",
176
+ # )
177
+ # lang_order = (
178
+ # lang_df.set_index("Model").loc[best_model].sort_values(ascending=False).index
179
+ # )
180
+ # logger.info(f"Lang order: {lang_order}")
181
+
182
+ # barplot_fig.update_layout(
183
+ # xaxis={"categoryorder": "array", "categoryarray": lang_order}
184
+ # )
185
+
186
+ return aggregated_df, lang_dfs, barplot_figs, models_with_nan
187
+
188
+
189
# Module-level helpers shared by the UI construction code and callbacks below.
dataset_h = DatasetHelper()
model_h = ModelHelper()
191
 
192
with gr.Blocks() as fm_interface:
    # Initial render: every dataset included, no language filtering.
    agg_df, per_lang_dfs, figs, nan_models = _populate_components(
        show_common_langs=False, selected_datasets=dataset_h.get_dataset_names()
    )
    model_with_nans_md = gr.Markdown(_build_models_with_nan_md(nan_models))

    gr.Markdown("### Sum of Absolute Gaps ⬇️")
    aggregated_df_comp = gr.DataFrame(format_dataframe(agg_df))

    # One table + one bar plot per speaking condition.
    gr.Markdown("#### Read: gaps by language")
    lang_df_comp_0 = gr.DataFrame(format_dataframe(per_lang_dfs[0], times_100=True))
    barplot_fig_comp_0 = gr.Plot(figs[0])

    gr.Markdown("#### Spontaneous: gaps by language")
    lang_df_comp_1 = gr.DataFrame(format_dataframe(per_lang_dfs[1], times_100=True))
    barplot_fig_comp_1 = gr.Plot(figs[1])
208
 
209
  ###################
210
  # LIST MAIN TABS
 
230
  # MAIN INTERFACE
231
  ###################
232
  with gr.Blocks() as demo:
233
+
234
  gr.HTML(banner)
235
 
236
  with gr.Row() as config_row:
 
238
  choices=["Show only common languages"],
239
  label="Main configuration",
240
  )
241
+
242
+ datasets_names = dataset_h.get_dataset_names()
243
  include_datasets = gr.CheckboxGroup(
244
+ choices=datasets_names,
245
  label="Include datasets",
246
+ value=datasets_names,
247
  interactive=False,
248
  )
249
 
250
  show_common_langs.input(
251
  build_components,
252
+ inputs=[show_common_langs, include_datasets],
253
  outputs=[
254
  aggregated_df_comp,
255
+ lang_df_comp_0,
256
+ lang_df_comp_1,
257
+ barplot_fig_comp_0,
258
+ barplot_fig_comp_1,
259
  model_with_nans_md,
260
  ],
261
  )
262
 
263
  gr.TabbedInterface(tabs, titles)
264
 
265
+ gr.Markdown(
266
+ """
267
+ ### Citation
268
+ If you find these results useful, please cite the following paper:
269
+ """
270
+ )
271
+
272
+ gr.Markdown(
273
+ f"""```
274
+ {CITATION_BUTTON_TEXT}"""
275
  )
276
 
277
  if __name__ == "__main__":
config.py CHANGED
@@ -4,78 +4,43 @@ to use for a particular dataset, or which language a model should be
4
  evaluated on.
5
  """
6
 
 
 
7
  LOCAL_RESULTS_DIR = "fair-asr-results"
8
  SETUPS = [{"majority_group": "male_masculine", "minority_group": "female_feminine"}]
9
 
10
 
11
- class CVInfo:
12
- dataset_id: str = "cv_17"
13
- full_name: str = "Mozilla Common Voice v17"
14
-
15
- # fmt: off
16
- langs = [
17
- "de", "en", "nl", # Germanic
18
- "ru", "sr", "cs", "sk", # Slavic
19
- "it", "fr", "es", "ca", "pt", "ro", # Romance
20
- "sw", # Bantu
21
- "yo", # Niger-Congo
22
- "ja", # Japonic
23
- "hu", "fi", # Uralic
24
- "ar" # Semitic
25
- ]
26
- # fmt: on
27
-
28
-
29
- dataset2info = {"cv_17": CVInfo}
30
-
31
 
32
- class WhisperInfo:
33
- # fmt: off
34
- langs = [
35
- "de", "en", "nl", # Germanic
36
- "ru", "sr", "cs", "sk", # Slavic
37
- "it", "fr", "es", "ca", "pt", "ro", # Romance
38
- "sw", # Bantu
39
- "yo", # Niger-Congo
40
- "ja", # Japonic
41
- "hu", "fi", # Uralic
42
- "ar" # Semitic
43
- ]
44
- # fmt: on
45
 
 
 
46
 
47
- class SeamlessInfo:
48
- # fmt: off
49
- langs = [
50
- "de", "en", "nl", # Germanic
51
- "ru", "sr", "cs", "sk", # Slavic
52
- "it", "fr", "es", "ca", "pt", "ro", # Romance
53
- "sw", # Bantu
54
- "yo", # Niger-Congo
55
- "ja", # Japonic
56
- "hu", "fi", # Uralic
57
- "ar" # Semitic
58
- ]
59
- # fmt: on
60
 
61
 
62
- class CanaryInfo:
63
- # fmt: off
64
- langs = [
65
- "en", "es", "de", "fr",
66
- ]
67
- # fmt: on
68
 
 
 
 
69
 
70
- model2info = {
71
- "openai--whisper-large-v3": WhisperInfo,
72
- "openai--whisper-large-v3-turbo": WhisperInfo,
73
- "facebook--seamless-m4t-v2-large": SeamlessInfo,
74
- "nvidia--canary-1b": CanaryInfo,
75
- }
76
 
77
 
78
- CITATION_BUTTON_LABEL = "Please use this bibtex to cite these results"
79
  CITATION_BUTTON_TEXT = r"""@inproceedings{attanasio-etal-2024-twists,
80
  title = "Twists, Humps, and Pebbles: Multilingual Speech Recognition Models Exhibit Gender Performance Gaps",
81
  author = "Attanasio, Giuseppe and
 
4
  evaluated on.
5
  """
6
 
7
+ from fair_asr_code.config import ALL_DATASET_CONFIGS, MODEL2LANG_SUPPORT
8
+
9
  LOCAL_RESULTS_DIR = "fair-asr-results"
10
  SETUPS = [{"majority_group": "male_masculine", "minority_group": "female_feminine"}]
11
 
12
 
13
class DatasetHelper:
    """Thin facade over the dataset configurations shipped with fair-asr-code."""

    def __init__(self):
        # All dataset configs known to the benchmark code, in their declared order.
        self.dataset_configs = ALL_DATASET_CONFIGS

    def get_dataset_names(self):
        """Return the human-readable dataset names, in config order."""
        return [c.name for c in self.dataset_configs]

    def get_dataset_ids(self):
        """Return the raw dataset identifiers, in config order."""
        return [c.dataset_id for c in self.dataset_configs]

    @property
    def sanitized_dataset_ids(self):
        """Filesystem-safe dataset identifiers, in config order."""
        return [c.sanitized_id() for c in self.dataset_configs]
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
class ModelHelper:
    """Helpers around the set of models evaluated by the benchmark."""

    def __init__(self):
        # Hub-style model ids ("org/name"), in MODEL2LANG_SUPPORT order.
        self.models = list(MODEL2LANG_SUPPORT.keys())

    @property
    def sanitized_model_ids(self):
        """Model ids with "/" replaced by "--" (filesystem-safe)."""
        return [m.replace("/", "--") for m in self.models]

    def get_common_langs(self):
        """Return the languages supported by every model (order unspecified)."""
        common = set(MODEL2LANG_SUPPORT[self.models[0]])
        for model in self.models[1:]:
            common &= set(MODEL2LANG_SUPPORT[model])
        return list(common)
 
41
 
42
 
43
+ # CITATION_BUTTON_LABEL = "Please use this bibtex to cite these results"
44
  CITATION_BUTTON_TEXT = r"""@inproceedings{attanasio-etal-2024-twists,
45
  title = "Twists, Humps, and Pebbles: Multilingual Speech Recognition Models Exhibit Gender Performance Gaps",
46
  author = "Attanasio, Giuseppe and
parsing.py CHANGED
@@ -2,14 +2,14 @@ import pandas as pd
2
  from typing import List
3
  from os.path import join as opj
4
  import json
5
- from config import dataset2info, model2info, LOCAL_RESULTS_DIR
6
  import logging
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
11
  def load_language_results(
12
- model_id: str, dataset_id: str, lang_ids: List[str], setup: str
13
  ):
14
  lang_gaps = dict()
15
  for lang in lang_ids:
@@ -20,7 +20,7 @@ def load_language_results(
20
  LOCAL_RESULTS_DIR,
21
  "evaluation",
22
  dataset_id,
23
- f"results_{model_id}_{dataset_id}_devtest_{lang}_gender_{setup}.json",
24
  )
25
  ) as fp:
26
  data = json.load(fp)
@@ -34,26 +34,33 @@ def load_language_results(
34
  return lang_gaps
35
 
36
 
37
- def read_all_configs(setup: str):
38
-
39
- all_datasets = dataset2info.keys()
40
- print("Parsing results datasets:", all_datasets)
41
- all_models = model2info.keys()
42
- print("Parsing results models:", all_models)
43
 
44
  rows = list()
45
- for dataset_id in all_datasets:
46
- for model_id in all_models:
 
 
 
 
 
 
47
  lang_gaps = load_language_results(
48
- model_id, dataset_id, dataset2info[dataset_id].langs, setup
 
 
 
49
  )
50
 
51
  rows.extend(
52
  [
53
  {
54
  "Model": model_id,
55
- "Dataset": dataset_id,
56
  "Language": lang,
 
57
  "Gap": lang_gaps[lang],
58
  }
59
  for lang in lang_gaps
@@ -61,16 +68,4 @@ def read_all_configs(setup: str):
61
  )
62
 
63
  results_df = pd.DataFrame(rows)
64
- # results_df = results_df.drop(columns=["Dataset"])
65
- # results_df = results_df.sort_values(by="Mean Gap", ascending=True)
66
-
67
  return results_df
68
-
69
-
70
- def get_common_langs():
71
- """Return a list of langs that are support by all models"""
72
- common_langs = set(model2info[list(model2info.keys())[0]].langs)
73
- for model_id in model2info.keys():
74
- common_langs = common_langs.intersection(model2info[model_id].langs)
75
-
76
- return list(common_langs)
 
2
  from typing import List
3
  from os.path import join as opj
4
  import json
 
5
  import logging
6
+ from config import DatasetHelper, ModelHelper, LOCAL_RESULTS_DIR
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
11
  def load_language_results(
12
+ model_id: str, dataset_id: str, lang_ids: List[str], contrast_string: str
13
  ):
14
  lang_gaps = dict()
15
  for lang in lang_ids:
 
20
  LOCAL_RESULTS_DIR,
21
  "evaluation",
22
  dataset_id,
23
+ f"results_{model_id}_{dataset_id}_devtest_{lang}_gender_{contrast_string}.json",
24
  )
25
  ) as fp:
26
  data = json.load(fp)
 
34
  return lang_gaps
35
 
36
 
37
def read_all_configs(contrast_type: str) -> pd.DataFrame:
    """Load per-language gap results for every (dataset, model) pair.

    Parameters
    ----------
    contrast_type : str
        Key into each dataset config's ``group_contrasts`` mapping
        (e.g. ``"F-M"``) selecting which majority/minority contrast to read.

    Returns
    -------
    pd.DataFrame
        One row per (model, dataset, language) with columns
        ``Model``, ``Dataset``, ``Language``, ``Type``, ``Gap``.

    Raises
    ------
    KeyError
        If ``contrast_type`` is not defined for one of the datasets.
    """
    dataset_h = DatasetHelper()
    model_h = ModelHelper()

    rows = list()
    for dataset_config in dataset_h.dataset_configs:
        # These depend only on the dataset, so compute them once per dataset
        # instead of once per (dataset, model) pair as before.
        contrast_info = dataset_config.group_contrasts[contrast_type]
        contrast_string = (
            f"{contrast_info['majority_group']}_{contrast_info['minority_group']}"
        )
        dataset_id = dataset_config.sanitized_id()
        speaking_type = dataset_config.speaking_condition.capitalize()

        for model_id in model_h.sanitized_model_ids:
            lang_gaps = load_language_results(
                model_id,
                dataset_id,
                dataset_config.langs,
                contrast_string,
            )

            rows.extend(
                {
                    "Model": model_id,
                    "Dataset": dataset_id,
                    "Language": lang,
                    "Type": speaking_type,
                    "Gap": lang_gaps[lang],
                }
                for lang in lang_gaps
            )

    return pd.DataFrame(rows)
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio
2
  plotly
3
- pandas
 
 
1
  gradio
2
  plotly
3
+ pandas
4
+ -e git+https://github.com/g8a9/fair-asr-code#egg=fair-asr-code