from dataclasses import dataclass
from enum import Enum
from air_benchmark.tasks.tasks import BenchmarkTable


def get_safe_name(name: str) -> str:
    """Normalize a name to lowercase alphanumerics and underscores.

    Hyphens are mapped to underscores; every other non-alphanumeric
    character is dropped.
    """
    name = name.replace('-', '_')
    return ''.join(
        character.lower()
        for character in name
        if character.isalnum() or character == '_')

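# Illustrative sanity check (not from the original module): the kind of
# normalization get_safe_name performs on a "[domain]-[lang]" style label.
#   get_safe_name("Finance-EN")  -> "finance_en"
#   get_safe_name("web zh (v1)") -> "webzhv1"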

# Retrieval metrics reported per task: nDCG, MAP, recall, precision, and MRR
# at cutoffs 1, 3, 5, 10, 100, and 1000.
METRIC_LIST = [
    "ndcg_at_1",
    "ndcg_at_3",
    "ndcg_at_5",
    "ndcg_at_10",
    "ndcg_at_100",
    "ndcg_at_1000",
    "map_at_1",
    "map_at_3",
    "map_at_5",
    "map_at_10",
    "map_at_100",
    "map_at_1000",
    "recall_at_1",
    "recall_at_3",
    "recall_at_5",
    "recall_at_10",
    "recall_at_100",
    "recall_at_1000",
    "precision_at_1",
    "precision_at_3",
    "precision_at_5",
    "precision_at_10",
    "precision_at_100",
    "precision_at_1000",
    "mrr_at_1",
    "mrr_at_3",
    "mrr_at_5",
    "mrr_at_10",
    "mrr_at_100",
    "mrr_at_1000"
]


@dataclass
class Benchmark:
    name: str  # task key in the results JSON: [domain]_[lang] for QA, [domain]_[lang]_[dataset] for long-doc
    metric: str  # metric key in the results JSON, e.g. "ndcg_at_10"
    col_name: str  # column name shown in the leaderboard (same as name)
    domain: str
    lang: str
    task: str  # "qa" or "long-doc"

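# Illustrative only (hypothetical values): the shape of a QA record as built
# by the loop below.
#   Benchmark(name="finance_en", metric="ndcg_at_10", col_name="finance_en",
#             domain="finance", lang="en", task="qa")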

# Build one Benchmark per (domain, lang) for QA and one per
# (domain, lang, dataset) for long-doc. Note that both inner loops reassign
# the same dict key on every iteration, so only the last assigned metric is
# kept; the default metrics are defined at the bottom of this module.
qa_benchmark_dict = {}
long_doc_benchmark_dict = {}
for task, domain_dict in BenchmarkTable['AIR-Bench_24.04'].items():
    for domain, lang_dict in domain_dict.items():
        for lang, dataset_list in lang_dict.items():
            if task == "qa":
                benchmark_name = get_safe_name(f"{domain}_{lang}")
                col_name = benchmark_name
                for metric in dataset_list:
                    qa_benchmark_dict[benchmark_name] = Benchmark(
                        benchmark_name, metric, col_name, domain, lang, task)
            elif task == "long-doc":
                for dataset in dataset_list:
                    benchmark_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                    col_name = benchmark_name
                    for metric in METRIC_LIST:
                        long_doc_benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, metric, col_name, domain, lang, task)

BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)

BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]

# Sorted for a stable order; bare set iteration varies across runs because of
# string hash randomization.
DOMAIN_COLS_QA = sorted({c.domain for c in qa_benchmark_dict.values()})
LANG_COLS_QA = sorted({c.lang for c in qa_benchmark_dict.values()})

DOMAIN_COLS_LONG_DOC = sorted({c.domain for c in long_doc_benchmark_dict.values()})
LANG_COLS_LONG_DOC = sorted({c.lang for c in long_doc_benchmark_dict.values()})

DEFAULT_METRIC_QA = "ndcg_at_10"
DEFAULT_METRIC_LONG_DOC = "recall_at_10"
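

# Minimal smoke test (illustrative; not part of the original module). It uses
# only names defined above and prints a few derived values so the module can
# be sanity-checked standalone.
if __name__ == "__main__":
    print(f"{len(BENCHMARK_COLS_QA)} QA columns, e.g. {BENCHMARK_COLS_QA[:3]}")
    print(f"{len(BENCHMARK_COLS_LONG_DOC)} long-doc columns")
    print("QA domains:", DOMAIN_COLS_QA)
    print("QA languages:", LANG_COLS_QA)
    # Each enum member wraps a full Benchmark record as its value.
    first = next(iter(BenchmarksQA))
    print(first.name, "->", first.value.col_name, first.value.task)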