import glob
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np

# Metrics and benchmarks from the lmeh (lm-eval-harness) eval data;
# METRICS[i] is the metric column read for BENCHMARKS[i].
METRICS = ["acc_norm", "acc_norm", "acc_norm", "mc2"]
BENCHMARKS = ["arc_challenge", "hellaswag", "hendrycks", "truthfulqa_mc"]
BENCH_TO_NAME = {
    "arc_challenge": "ARC (25-shot) ⬆️",
    "hellaswag": "HellaSwag (10-shot) ⬆️",
    "hendrycks": "MMLU (5-shot) ⬆️",
    "truthfulqa_mc": "TruthfulQA (0-shot) ⬆️",
}


def make_clickable_model(model_name):
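    """Return an HTML anchor for `model_name`.

    LLaMA, stable-vicuna and alpaca checkpoints are special-cased to link to
    their upstream announcement pages; everything else links to its Hugging
    Face Hub page.
    """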
    LLAMAS = [
        "huggingface/llama-7b",
        "huggingface/llama-13b",
        "huggingface/llama-30b",
        "huggingface/llama-65b",
    ]
    if model_name in LLAMAS:
        model = model_name.split("/")[1]
        return f'<a target="_blank" href="https://ai.facebook.com/blog/large-language-model-llama-meta-ai/" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model}</a>'

    if model_name == "HuggingFaceH4/stable-vicuna-13b-2904":
        link = "https://huggingface.co./" + "CarperAI/stable-vicuna-13b-delta"
        return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">stable-vicuna-13b</a>'

    if model_name == "HuggingFaceH4/llama-7b-ift-alpaca":
        link = "https://crfm.stanford.edu/2023/03/13/alpaca.html"
        return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">alpaca-13b</a>'

    # remove user from model name
    # model_name_show = ' '.join(model_name.split('/')[1:])

    link = "https://huggingface.co./" + model_name
    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'


@dataclass
class EvalResult:
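    """Evaluation results for one (org, model, revision, precision) run."""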
    eval_name: str
    org: str
    model: str
    revision: str
    is_8bit: bool
    results: dict

    def to_dict(self):
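        """Flatten this result into a leaderboard row: clickable model link,
        revision, 8bit flag, per-benchmark scores, and their average."""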
        if self.org is not None:
            base_model = f"{self.org}/{self.model}"
        else:
            base_model = f"{self.model}"
        data_dict = {}

        data_dict["eval_name"] = self.eval_name
        data_dict["8bit"] = self.is_8bit
        data_dict["Model"] = make_clickable_model(base_model)
        data_dict["model_name_for_query"] = base_model
        data_dict["Revision"] = self.revision
        data_dict["Average ⬆️"] = round(
            sum([v for k, v in self.results.items()]) / 4.0, 1
        )

        for benchmark in BENCHMARKS:
            if benchmark not in self.results:
                self.results[benchmark] = None

        for k, v in BENCH_TO_NAME.items():
            data_dict[v] = self.results[k]

        return data_dict


def parse_eval_result(json_filepath: str) -> Tuple[str, EvalResult]:
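    """Parse one lmeh result file into a (result_key, EvalResult) pair.

    `result_key` identifies the (org, model, revision, precision) run; the
    EvalResult holds the matched benchmark's mean score as a percentage.
    """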
    with open(json_filepath) as fp:
        data = json.load(fp)

    path_split = json_filepath.split("/")
    org = None
    model = path_split[-4]
    is_8bit = path_split[-2] == "8bit"
    revision = path_split[-3]
    if len(path_split) == 7:
        # handles gpt2 type models that don't have an org
        result_key = f"{path_split[-4]}_{path_split[-3]}_{path_split[-2]}"
    else:
        result_key = (
            f"{path_split[-5]}_{path_split[-4]}_{path_split[-3]}_{path_split[-2]}"
        )
        org = path_split[-5]

    eval_result = None
    for benchmark, metric in zip(BENCHMARKS, METRICS):
        if benchmark in json_filepath:
            accs = np.array([v[metric] for k, v in data["results"].items()])
            mean_acc = round(np.mean(accs) * 100.0, 1)
            eval_result = EvalResult(
                result_key, org, model, revision, is_8bit, {benchmark: mean_acc}
            )

    return result_key, eval_result


def get_eval_results(is_public: bool) -> List[EvalResult]:
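    """Collect all result files and merge them into one EvalResult per run.

    Public 16bit results are always read; private results and the 8bit evals
    of public models are added when `is_public` is False.
    """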
    json_filepaths = glob.glob(
        "evals/eval_results/public/**/16bit/*.json", recursive=True
    )
    if not is_public:
        json_filepaths += glob.glob(
            "evals/eval_results/private/**/*.json", recursive=True
        )
        json_filepaths += glob.glob(
            "evals/eval_results/public/**/8bit/*.json", recursive=True
        )  # include the 8bit evals of public models
    eval_results = {}

    for json_filepath in json_filepaths:
        result_key, eval_result = parse_eval_result(json_filepath)
        if result_key in eval_results:
            eval_results[result_key].results.update(eval_result.results)
        else:
            eval_results[result_key] = eval_result

    eval_results = list(eval_results.values())

    return eval_results


def get_eval_results_dicts(is_public: bool = True) -> List[Dict]:
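    """Return the leaderboard rows as a list of dicts."""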
    eval_results = get_eval_results(is_public)

    return [e.to_dict() for e in eval_results]
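

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): assumes the
    # evals/eval_results/... directory layout globbed above exists locally.
    rows = get_eval_results_dicts(is_public=True)
    print(json.dumps(rows[:3], indent=2, ensure_ascii=False))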