test: add unit tests for utils
Changed files:
- src/utils.py (+20 -13)
- tests/src/test_utils.py (+151 -20)
- tests/test_utils.py (+1 -12)
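Summary: this commit extracts the domain/language filtering in src/utils.py from select_columns into a new get_selected_cols helper (renaming the domain_query/language_query parameters to domains/languages), adds a dedicated unit-test suite in tests/src/test_utils.py covering remove_html, calculate_mean, filter_models, filter_queries, get_default_cols, get_selected_cols, and select_columns, and removes the old test_select_columns from tests/test_utils.py.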
src/utils.py CHANGED

@@ -98,14 +98,7 @@ def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) ->
     return cols, types


-def select_columns(
-    df: pd.DataFrame,
-    domain_query: list,
-    language_query: list,
-    task: TaskType = TaskType.qa,
-    reset_ranking: bool = True,
-    version_slug: str = None,
-) -> pd.DataFrame:
+def get_selected_cols(task, version_slug, domains, languages):
     cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:
@@ -115,21 +108,35 @@ def select_columns(
             eval_col = LongDocBenchmarks[version_slug].value[c].value
         else:
             raise NotImplementedError
-        if eval_col.domain not in domain_query:
+        if eval_col.domain not in domains:
             continue
-        if eval_col.lang not in language_query:
+        if eval_col.lang not in languages:
             continue
         selected_cols.append(c)
     # We use COLS to maintain sorting
+    return selected_cols
+
+
+def select_columns(
+    df: pd.DataFrame,
+    domains: list,
+    languages: list,
+    task: TaskType = TaskType.qa,
+    reset_ranking: bool = True,
+    version_slug: str = None,
+) -> pd.DataFrame:
+    selected_cols = get_selected_cols(
+        task, version_slug, domains, languages)
     fixed_cols, _ = get_fixed_col_names_and_types()
     filtered_df = df[fixed_cols + selected_cols]
     filtered_df.replace({"": pd.NA}, inplace=True)
     if reset_ranking:
-        filtered_df[COL_NAME_AVG] = …
-        …
+        filtered_df[COL_NAME_AVG] = \
+            filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
+        filtered_df.sort_values(
+            by=[COL_NAME_AVG], ascending=False, inplace=True)
     filtered_df.reset_index(inplace=True, drop=True)
     filtered_df = reset_rank(filtered_df)
-
     return filtered_df

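For readers skimming the diff, here is a minimal, self-contained sketch of the select → average → sort → re-rank flow that the refactored select_columns performs. Everything in it is a stand-in rather than the repo's actual code: calculate_mean below only mimics the behavior the new tests exercise (row mean, with -1 as the sentinel for missing scores), and the last line approximates reset_rank.

import pandas as pd

def calculate_mean(row):
    # Stand-in for src.utils.calculate_mean as exercised by the tests:
    # rows with any missing score get the sentinel -1.
    return -1 if row.isna().any() else row.mean()

df = pd.DataFrame({
    "Retrieval Method": ["jina-embeddings-v2-base", "bge-m3"],
    "news_zh": [0.3, 0.4],
    "wiki_zh": [0.4, 0.1],
})

selected_cols = ["news_zh", "wiki_zh"]  # what get_selected_cols would return
df["Average ⬆️"] = df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
df.sort_values(by=["Average ⬆️"], ascending=False, inplace=True)
df.reset_index(drop=True, inplace=True)
df["Rank 🏆"] = df.index + 1  # rough stand-in for reset_rank
print(df)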
tests/src/test_utils.py CHANGED

@@ -1,26 +1,157 @@
-…
-…
-…
-…
-…
-…
-…
-…
-…
+import pytest
+import pandas as pd
+
+from src.utils import remove_html, calculate_mean, filter_models, filter_queries, get_default_cols, select_columns, get_selected_cols
+from src.models import model_hyperlink, TaskType
+from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
+
+
+NUM_QA_BENCHMARKS_24_05 = 53
+NUM_DOC_BENCHMARKS_24_05 = 11
+NUM_QA_BENCHMARKS_24_04 = 13
+NUM_DOC_BENCHMARKS_24_04 = 15
+
+
+@pytest.fixture
+def toy_df():
+    return pd.DataFrame(
+        {
+            "Retrieval Method": [
+                "bge-m3",
+                "bge-m3",
+                "jina-embeddings-v2-base",
+                "jina-embeddings-v2-base"
+            ],
+            "Reranking Model": [
+                "bge-reranker-v2-m3",
+                "NoReranker",
+                "bge-reranker-v2-m3",
+                "NoReranker"
+            ],
+            "Rank 🏆": [1, 2, 3, 4],
+            "Revision": ["", "", "", ""],
+            "Submission Date": ["", "", "", ""],
+            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
+            "wiki_en": [0.8, 0.7, 0.2, 0.1],
+            "wiki_zh": [0.4, 0.1, 0.4, 0.3],
+            "news_en": [0.8, 0.7, 0.2, 0.1],
+            "news_zh": [0.4, 0.1, 0.4, 0.3],
+        }
+    )
+
+
+def test_remove_html():
+    model_name = "jina-embeddings-v3"
+    html_str = model_hyperlink(
+        "https://jina.ai", model_name)
+    output_str = remove_html(html_str)
+    assert output_str == model_name


-def …
-…
-…
+def test_calculate_mean():
+    valid_row = [1, 3]
+    invalid_row = [2, pd.NA]
+    df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
+    result = list(df.apply(calculate_mean, axis=1))
+    assert result[0] == sum(valid_row) / 2
+    assert result[1] == -1


-…
-…
-…
-…
-…
+@pytest.mark.parametrize("models, expected", [
+    (["model1", "model3"], 2),
+    (["model1", "model_missing"], 1),
+    (["model1", "model2", "model3"], 3),
+    (["model1", ], 1),
+    ([], 3),
+])
+def test_filter_models(models, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RERANKING_MODEL: ["model1", "model2", "model3", ],
+            "col2": [1, 2, 3],
+        }
+    )
+    output_df = filter_models(df, models)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize("query, expected", [
+    ("model1;model3", 2),
+    ("model1;model4", 1),
+    ("model1;model2;model3", 3),
+    ("model1", 1),
+    ("", 3),
+])
+def test_filter_queries(query, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RETRIEVAL_MODEL: ["model1", "model2", "model3", ],
+            COL_NAME_RERANKING_MODEL: ["model4", "model5", "model6", ],
+        }
+    )
+    output_df = filter_queries(query, df)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, slug, expected",
+    [
+        (TaskType.qa, "air_bench_2404", NUM_QA_BENCHMARKS_24_04),
+        (TaskType.long_doc, "air_bench_2404", NUM_DOC_BENCHMARKS_24_04),
+        (TaskType.qa, "air_bench_2405", NUM_QA_BENCHMARKS_24_05),
+        (TaskType.long_doc, "air_bench_2405", NUM_DOC_BENCHMARKS_24_05),
+    ]
+)
+def test_get_default_cols(task_type, slug, expected):
+    attr_cols = ['Rank 🏆', 'Retrieval Method', 'Reranking Model', 'Revision', 'Submission Date', 'Average ⬆️']
+    cols, types = get_default_cols(task_type, slug)
+    benchmark_cols = list(frozenset(cols).difference(frozenset(attr_cols)))
+    assert len(benchmark_cols) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, domains, languages, expected",
+    [
+        (TaskType.qa, ["wiki", "news"], ["zh",], ["wiki_zh", "news_zh"]),
+        (TaskType.qa, ["law",], ["zh", "en"], ["law_en"]),
+        (
+            TaskType.long_doc,
+            ["healthcare"],
+            ["zh", "en"],
+            [
+                'healthcare_en_pubmed_100k_200k_1',
+                'healthcare_en_pubmed_100k_200k_2',
+                'healthcare_en_pubmed_100k_200k_3',
+                'healthcare_en_pubmed_40k_50k_5_merged',
+                'healthcare_en_pubmed_30k_40k_10_merged'
+            ]
+        )
+    ]
+)
+def test_get_selected_cols(task_type, domains, languages, expected):
+    slug = "air_bench_2404"
+    cols = get_selected_cols(task_type, slug, domains, languages)
+    assert sorted(cols) == sorted(expected)


-def …
-…
-…
+def test_select_columns(toy_df):
+    expected = [
+        'Rank 🏆',
+        'Retrieval Method',
+        'Reranking Model',
+        'Revision',
+        'Submission Date',
+        'Average ⬆️',
+        'news_zh']
+    df_result = select_columns(
+        toy_df,
+        [
+            "news",
+        ],
+        [
+            "zh",
+        ],
+        version_slug="air_bench_2404",
+    )
+    assert len(df_result.columns) == len(expected)
+    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
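Note that the new suite imports from src.utils, src.models, and src.columns, so it assumes pytest is invoked from the repository root. A hedged programmatic invocation, equivalent to running pytest tests/src/test_utils.py -v on the command line:

import sys
import pytest

# Run the new utils test suite; pytest.main returns an exit code usable in CI.
sys.exit(pytest.main(["tests/src/test_utils.py", "-v"]))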
tests/test_utils.py CHANGED

@@ -75,18 +75,7 @@ def test_filter_queries(toy_df):
     assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"


-def test_select_columns(toy_df):
-    df_result = select_columns(
-        toy_df,
-        [
-            "news",
-        ],
-        [
-            "zh",
-        ],
-    )
-    assert len(df_result.columns) == 4
-    assert df_result["Average ⬆️"].equals(df_result["news_zh"])
+


 def test_update_table_long_doc(toy_df_long_doc):
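The removed test lives on in tests/src/test_utils.py as test_select_columns, where it additionally pins version_slug="air_bench_2404" and derives the expected column count from an explicit list of column names instead of the hard-coded 4.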