AIR-Bench

nan committed · Commit b671337 · 1 Parent(s): 6b461df

test: add unit tests for utils

Files changed (3):
  1. src/utils.py +20 -13
  2. tests/src/test_utils.py +151 -20
  3. tests/test_utils.py +1 -12
src/utils.py CHANGED
@@ -98,14 +98,7 @@ def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) ->
     return cols, types


-def select_columns(
-    df: pd.DataFrame,
-    domain_query: list,
-    language_query: list,
-    task: TaskType = TaskType.qa,
-    reset_ranking: bool = True,
-    version_slug: str = None,
-) -> pd.DataFrame:
+def get_selected_cols(task, version_slug, domains, languages):
     cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:
@@ -115,21 +108,35 @@ def select_columns(
             eval_col = LongDocBenchmarks[version_slug].value[c].value
         else:
             raise NotImplementedError
-        if eval_col.domain not in domain_query:
+        if eval_col.domain not in domains:
             continue
-        if eval_col.lang not in language_query:
+        if eval_col.lang not in languages:
             continue
         selected_cols.append(c)
     # We use COLS to maintain sorting
+    return selected_cols
+
+
+def select_columns(
+    df: pd.DataFrame,
+    domains: list,
+    languages: list,
+    task: TaskType = TaskType.qa,
+    reset_ranking: bool = True,
+    version_slug: str = None,
+) -> pd.DataFrame:
+    selected_cols = get_selected_cols(
+        task, version_slug, domains, languages)
     fixed_cols, _ = get_fixed_col_names_and_types()
     filtered_df = df[fixed_cols + selected_cols]
     filtered_df.replace({"": pd.NA}, inplace=True)
     if reset_ranking:
-        filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
-        filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
+        filtered_df[COL_NAME_AVG] = \
+            filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
+        filtered_df.sort_values(
+            by=[COL_NAME_AVG], ascending=False, inplace=True)
         filtered_df.reset_index(inplace=True, drop=True)
         filtered_df = reset_rank(filtered_df)
-
     return filtered_df
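Note (not part of this commit): the refactor splits column selection into two steps, get_selected_cols resolving which benchmark columns match the domain/language filters and select_columns slicing the DataFrame and recomputing the average and rank. A minimal sketch of how the two compose, using an illustrative toy leaderboard and the "air_bench_2404" slug that the tests below also use:

# Illustration only (not from the commit): compose the two helpers the way the
# refactored select_columns does internally. Column names and values mirror the
# toy fixture in tests/src/test_utils.py; they are not real leaderboard data.
import pandas as pd

from src.models import TaskType
from src.utils import get_selected_cols, select_columns

toy = pd.DataFrame({
    "Rank 🏆": [1, 2],
    "Retrieval Method": ["bge-m3", "jina-embeddings-v2-base"],
    "Reranking Model": ["NoReranker", "NoReranker"],
    "Revision": ["", ""],
    "Submission Date": ["", ""],
    "Average ⬆️": [0.5, 0.3],
    "news_zh": [0.5, 0.3],
})

# Step 1: which benchmark columns survive the domain/language filter.
cols = get_selected_cols(TaskType.qa, "air_bench_2404", ["news"], ["zh"])

# Step 2: keep the fixed + selected columns and recompute Average/Rank.
filtered = select_columns(toy, ["news"], ["zh"], task=TaskType.qa,
                          version_slug="air_bench_2404")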
tests/src/test_utils.py CHANGED
@@ -1,26 +1,157 @@
-from src.display.utils import (
-    COLS_LONG_DOC,
-    COLS_QA,
-    TYPES_LONG_DOC,
-    TYPES_QA,
-    AutoEvalColumnQA,
-    fields,
-    get_default_auto_eval_column_dict,
-)
+import pytest
+import pandas as pd
+
+from src.utils import remove_html, calculate_mean, filter_models, filter_queries, get_default_cols, select_columns, get_selected_cols
+from src.models import model_hyperlink, TaskType
+from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
+
+
+NUM_QA_BENCHMARKS_24_05 = 53
+NUM_DOC_BENCHMARKS_24_05 = 11
+NUM_QA_BENCHMARKS_24_04 = 13
+NUM_DOC_BENCHMARKS_24_04 = 15
+
+
+@pytest.fixture
+def toy_df():
+    return pd.DataFrame(
+        {
+            "Retrieval Method": [
+                "bge-m3",
+                "bge-m3",
+                "jina-embeddings-v2-base",
+                "jina-embeddings-v2-base"
+            ],
+            "Reranking Model": [
+                "bge-reranker-v2-m3",
+                "NoReranker",
+                "bge-reranker-v2-m3",
+                "NoReranker"
+            ],
+            "Rank 🏆": [1, 2, 3, 4],
+            "Revision": ["", "", "", ""],
+            "Submission Date": ["", "", "", ""],
+            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
+            "wiki_en": [0.8, 0.7, 0.2, 0.1],
+            "wiki_zh": [0.4, 0.1, 0.4, 0.3],
+            "news_en": [0.8, 0.7, 0.2, 0.1],
+            "news_zh": [0.4, 0.1, 0.4, 0.3],
+        }
+    )
+
+
+def test_remove_html():
+    model_name = "jina-embeddings-v3"
+    html_str = model_hyperlink(
+        "https://jina.ai", model_name)
+    output_str = remove_html(html_str)
+    assert output_str == model_name


-def test_fields():
-    for c in fields(AutoEvalColumnQA):
-        print(c)
+def test_calculate_mean():
+    valid_row = [1, 3]
+    invalid_row = [2, pd.NA]
+    df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
+    result = list(df.apply(calculate_mean, axis=1))
+    assert result[0] == sum(valid_row) / 2
+    assert result[1] == -1


-def test_macro_variables():
-    print(f"COLS_QA: {COLS_QA}")
-    print(f"COLS_LONG_DOC: {COLS_LONG_DOC}")
-    print(f"TYPES_QA: {TYPES_QA}")
-    print(f"TYPES_LONG_DOC: {TYPES_LONG_DOC}")
+@pytest.mark.parametrize("models, expected", [
+    (["model1", "model3"], 2),
+    (["model1", "model_missing"], 1),
+    (["model1", "model2", "model3"], 3),
+    (["model1", ], 1),
+    ([], 3),
+])
+def test_filter_models(models, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RERANKING_MODEL: ["model1", "model2", "model3", ],
+            "col2": [1, 2, 3],
+        }
+    )
+    output_df = filter_models(df, models)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize("query, expected", [
+    ("model1;model3", 2),
+    ("model1;model4", 1),
+    ("model1;model2;model3", 3),
+    ("model1", 1),
+    ("", 3),
+])
+def test_filter_queries(query, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RETRIEVAL_MODEL: ["model1", "model2", "model3", ],
+            COL_NAME_RERANKING_MODEL: ["model4", "model5", "model6", ],
+        }
+    )
+    output_df = filter_queries(query, df)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, slug, expected",
+    [
+        (TaskType.qa, "air_bench_2404", NUM_QA_BENCHMARKS_24_04),
+        (TaskType.long_doc, "air_bench_2404", NUM_DOC_BENCHMARKS_24_04),
+        (TaskType.qa, "air_bench_2405", NUM_QA_BENCHMARKS_24_05),
+        (TaskType.long_doc, "air_bench_2405", NUM_DOC_BENCHMARKS_24_05),
+    ]
+)
+def test_get_default_cols(task_type, slug, expected):
+    attr_cols = ['Rank 🏆', 'Retrieval Method', 'Reranking Model', 'Revision', 'Submission Date', 'Average ⬆️']
+    cols, types = get_default_cols(task_type, slug)
+    benchmark_cols = list(frozenset(cols).difference(frozenset(attr_cols)))
+    assert len(benchmark_cols) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, domains, languages, expected",
+    [
+        (TaskType.qa, ["wiki", "news"], ["zh",], ["wiki_zh", "news_zh"]),
+        (TaskType.qa, ["law",], ["zh", "en"], ["law_en"]),
+        (
+            TaskType.long_doc,
+            ["healthcare"],
+            ["zh", "en"],
+            [
+                'healthcare_en_pubmed_100k_200k_1',
+                'healthcare_en_pubmed_100k_200k_2',
+                'healthcare_en_pubmed_100k_200k_3',
+                'healthcare_en_pubmed_40k_50k_5_merged',
+                'healthcare_en_pubmed_30k_40k_10_merged'
+            ]
+        )
+    ]
+)
+def test_get_selected_cols(task_type, domains, languages, expected):
+    slug = "air_bench_2404"
+    cols = get_selected_cols(task_type, slug, domains, languages)
+    assert sorted(cols) == sorted(expected)


-def test_get_default_auto_eval_column_dict():
-    auto_eval_column_dict_list = get_default_auto_eval_column_dict()
-    assert len(auto_eval_column_dict_list) == 9
+def test_select_columns(toy_df):
+    expected = [
+        'Rank 🏆',
+        'Retrieval Method',
+        'Reranking Model',
+        'Revision',
+        'Submission Date',
+        'Average ⬆️',
+        'news_zh']
+    df_result = select_columns(
+        toy_df,
+        [
+            "news",
+        ],
+        [
+            "zh",
+        ],
+        version_slug="air_bench_2404",
+    )
+    assert len(df_result.columns) == len(expected)
+    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
tests/test_utils.py CHANGED
@@ -75,18 +75,7 @@ def test_filter_queries(toy_df):
     assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"


-def test_select_columns(toy_df):
-    df_result = select_columns(
-        toy_df,
-        [
-            "news",
-        ],
-        [
-            "zh",
-        ],
-    )
-    assert len(df_result.columns) == 4
-    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
+


 def test_update_table_long_doc(toy_df_long_doc):
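Note (not part of this commit): the new module lives under tests/src/, so it can be run on its own with pytest from the repository root. A small sketch of an equivalent programmatic invocation, assuming pytest is installed:

# Run the new unit tests; equivalent to `pytest tests/src/test_utils.py -v`
# from the repository root (assumes pytest is installed).
import sys
import pytest

sys.exit(pytest.main(["tests/src/test_utils.py", "-v"]))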