Spaces:
AIR-Bench
/
Running on CPU Upgrade

nan commited on
Commit
f997dd6
·
1 Parent(s): 9fcf267

test: add unit tests for models

Browse files
Files changed (2) hide show
  1. src/models.py +10 -9
  2. tests/src/test_models.py +33 -4
src/models.py CHANGED
@@ -109,18 +109,19 @@ class FullEvalResult:
109
  continue
110
  if eval_result.task != task:
111
  continue
112
- results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
113
- results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
 
114
  self.retrieval_model, self.retrieval_model_link
115
  )
116
- results[eval_result.eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
117
  self.reranking_model, self.reranking_model_link
118
  )
119
- results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
120
- results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
121
- results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
122
- results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
123
- results[eval_result.eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
124
 
125
  for result in eval_result.results:
126
  # add result for each domain, language, and dataset
@@ -132,7 +133,7 @@ class FullEvalResult:
132
  benchmark_name = f"{domain}_{lang}"
133
  else:
134
  benchmark_name = f"{domain}_{lang}_{dataset}"
135
- results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
136
  return [v for v in results.values()]
137
 
138
 
 
109
  continue
110
  if eval_result.task != task:
111
  continue
112
+ eval_name = eval_result.eval_name
113
+ results[eval_name]["eval_name"] = eval_name
114
+ results[eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
115
  self.retrieval_model, self.retrieval_model_link
116
  )
117
+ results[eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
118
  self.reranking_model, self.reranking_model_link
119
  )
120
+ results[eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
121
+ results[eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
122
+ results[eval_name][COL_NAME_REVISION] = self.revision
123
+ results[eval_name][COL_NAME_TIMESTAMP] = self.timestamp
124
+ results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
125
 
126
  for result in eval_result.results:
127
  # add result for each domain, language, and dataset
 
133
  benchmark_name = f"{domain}_{lang}"
134
  else:
135
  benchmark_name = f"{domain}_{lang}_{dataset}"
136
+ results[eval_name][get_safe_name(benchmark_name)] = value
137
  return [v for v in results.values()]
138
 
139
 
tests/src/test_models.py CHANGED
@@ -6,6 +6,23 @@ from src.models import EvalResult, FullEvalResult
6
  cur_fp = Path(__file__)
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def test_eval_result():
10
  eval_result = EvalResult(
11
  eval_name="eval_name",
@@ -41,9 +58,21 @@ def test_full_eval_result_init_from_json_file(file_path):
41
  assert len(full_eval_result.results) == 70
42
 
43
 
44
- def test_full_eval_result_to_dict():
45
- json_fp = cur_fp.parents[1] / "toydata/eval_results/" / "AIR-Bench_24.05/bge-m3/NoReranker/results.json"
 
 
 
 
 
 
 
 
46
  full_eval_result = FullEvalResult.init_from_json_file(json_fp)
47
- result_dict_list = full_eval_result.to_dict()
48
  assert len(result_dict_list) == 1
49
- print(len(result_dict_list[0]))
 
 
 
 
 
6
  cur_fp = Path(__file__)
7
 
8
 
9
+ # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
10
+ # 24.05
11
+ # | Task | dev | test |
12
+ # | ---- | --- | ---- |
13
+ # | Long-Doc | 4 | 11 |
14
+ # | QA | 54 | 53 |
15
+ #
16
+ # 24.04
17
+ # | Task | test |
18
+ # | ---- | ---- |
19
+ # | Long-Doc | 15 |
20
+ # | QA | 13 |
21
+ NUM_QA_BENCHMARKS_24_05 = 53
22
+ NUM_DOC_BENCHMARKS_24_05 = 11
23
+ NUM_QA_BENCHMARKS_24_04 = 13
24
+ NUM_DOC_BENCHMARKS_24_04 = 15
25
+
26
  def test_eval_result():
27
  eval_result = EvalResult(
28
  eval_name="eval_name",
 
58
  assert len(full_eval_result.results) == 70
59
 
60
 
61
+ @pytest.mark.parametrize(
62
+ 'file_path, task, expected_num_results',
63
+ [
64
+ ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
65
+ ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_04),
66
+ ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
67
+ ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
68
+ ])
69
+ def test_full_eval_result_to_dict(file_path, task, expected_num_results):
70
+ json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
71
  full_eval_result = FullEvalResult.init_from_json_file(json_fp)
72
+ result_dict_list = full_eval_result.to_dict(task)
73
  assert len(result_dict_list) == 1
74
+ result = result_dict_list[0]
75
+ attr_list = frozenset([
76
+ 'eval_name', 'Retrieval Method', 'Reranking Model', 'Retrieval Model LINK', 'Reranking Model LINK', 'Revision', 'Submission Date', 'Anonymous Submission'])
77
+ result_cols = list(result.keys())
78
+ assert len(result_cols) == (expected_num_results + len(attr_list))