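# Evaluates a causal language model on a Chinese multiple-choice QA benchmark:
# builds few-shot prompts from each item's history, greedily decodes an answer
# string such as "C" or "ABD", and writes the raw generations plus accuracy
# summaries (overall, per category, per difficulty) as JSON.
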
import os
import json
import re
import argparse
import torch
import gc
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer

# Pin the process to physical GPU 1; torch initializes CUDA lazily, so setting
# this before any CUDA call makes device 0 below map onto that GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def parse_args():
    parser = argparse.ArgumentParser(description='Validation')
    parser.add_argument('--model_path', dest="model_path")
    parser.add_argument('--dataset_path', dest="dataset_path")
    parser.add_argument('--save_path', dest='save_path')
    # Numeric arguments need type=int; otherwise values passed on the command
    # line arrive as strings and break the length check and generate() call.
    parser.add_argument('--max_length', dest="max_length", type=int, default=2000)
    parser.add_argument('--max_new_tokens', dest="max_new_tokens", type=int, default=48)
    # Template for items with reading material; in English:
    # "Material: {material}\nQuestion: {question}\n{options}\nAnswer: {response}"
    parser.add_argument('--input_template', dest="input_template",
                        default="材料:{material}\n问题:{question}\n{options}\n答案:{response}")
    # Template for items without material; in English:
    # "Question: {question}\n{options}\nAnswer: {response}"
    parser.add_argument('--query_template', dest="query_template", default="问题:{question}\n{options}\n答案:{response}")
    # System prompt; in English: "You are an examinee. Please answer with the
    # correct option, e.g. C. If several options are correct, give them all in
    # order, e.g. ABD."
    parser.add_argument('--system_prompt', dest="system_prompt",
                        default="你是一名考生,请回答问题的正确选项,比如C。如果有多个正确选项,请按顺序回答所有正确选项,比如ABD。")
    parser.add_argument('--level_delimiter', dest="level_delimiter", default="|")
    args = parser.parse_args()
    return args


class Validator:
    def __init__(self, args):
        self.load_model(args.model_path)
        self.load_dataset(args.dataset_path)
        # Save results under a subdirectory named after the model checkpoint.
        self.save_dir = os.path.join(args.save_path, os.path.split(args.model_path)[-1])
        os.makedirs(self.save_dir, exist_ok=True)

        self.max_length = args.max_length
        self.max_new_tokens = args.max_new_tokens
        self.input_template = args.input_template
        self.query_template = args.query_template
        self.system_prompt = args.system_prompt
        self.level_delimiter = args.level_delimiter

    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Place the whole model on the single visible GPU (device 0 after the
        # CUDA_VISIBLE_DEVICES restriction above).
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map={"": 0}, trust_remote_code=True)
        self.model.eval()

    def load_dataset(self, dataset_path):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
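        # Each item is expected to provide (inferred from usage elsewhere in
        # this class): material, question, options, choice, most_wrong,
        # human_acc, categories, difficulty, and a list of few-shot "history"
        # sub-items.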

    def format_prompt(self, material, question, options, response):
        if material:
            return self.input_template.format(material=material, question=question, options=options,
                                              response=response).strip()
        return self.query_template.format(question=question, options=options, response=response).strip()

    def build_prompt(self, item):
        # Few-shot prompt: system prompt, then as many history Q&A examples as
        # fit within max_length tokens, then the current question.
        query_prompt = self.format_prompt(item['material'], item['question'], item['options'], "")
        history_prompts = []
        for sub in item['history']:
            history_prompts.append(self.format_prompt(sub['material'], sub['question'], sub['options'], sub['choice']))
            final_prompt = self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
            if len(self.tokenizer.tokenize(final_prompt)) > self.max_length:
                # Drop the example that overflowed the budget and stop adding more.
                history_prompts.pop()
                break
        return self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt

    def get_predict(self, prompt):
        # Greedy decoding, capped at max_new_tokens.
        gen_kwargs = {"do_sample": False, "max_new_tokens": self.max_new_tokens}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, return_dict_in_generate=True, **gen_kwargs)
        # Strip the prompt by token count rather than by character count:
        # decoding with skip_special_tokens can shift character offsets, so
        # slicing the decoded string by len(prompt) is unreliable.
        prompt_tokens = inputs["input_ids"].shape[1]
        predict = self.tokenizer.decode(outputs.sequences[0][prompt_tokens:], skip_special_tokens=True)
        return predict

    def extract_answer(self, results):
        predict_result = {"acc": 0, "wrong_value": 0, "human_acc": 0, "hit": 0, "total": 0, "wrong_hit": 0,
                          "wrong_total": 0, 'detail': []}
        for result in results:
            answer = result['item']['choice']
            most_wrong = result['item']['most_wrong']
            human_acc = result['item']['human_acc'] * 0.01

            # Keep only the leading ASCII fragments of the generation (e.g. "ABD"),
            # splitting on common ASCII/Chinese delimiters and stopping at the
            # first non-ASCII fragment.
            predict = ""
            for e in re.split(r'[、,.,和\s]\s*', result['predict']):
                if not e.isascii():
                    break
                predict += e
            result['predict'] = predict
            predict_result['detail'].append(
                {"answer": answer, "most_wrong": most_wrong, "predict": predict, "human_acc": human_acc})
            predict_result['hit'] += 1 if answer == predict else 0
            predict_result['wrong_hit'] += 1 if most_wrong == predict else 0
            predict_result['wrong_value'] += 1 - human_acc if most_wrong == predict else 0
            predict_result['human_acc'] += human_acc
            predict_result['total'] += 1
        predict_result['acc'] = predict_result['hit'] / predict_result['total']
        predict_result['wrong_total'] = predict_result['total'] - predict_result['hit']
        # Guard against division by zero when every prediction is correct.
        if predict_result['wrong_total']:
            predict_result['wrong_value'] = predict_result['wrong_value'] / predict_result['wrong_total']
        predict_result['human_acc'] = predict_result['human_acc'] / len(results)

        with open(os.path.join(self.save_dir, "acc_result.json"), "w", encoding="utf-8") as f:
            json.dump(predict_result, f, ensure_ascii=False)

    def category_summary(self, results):
        # Per-category accuracy. "总计" ("overall") aggregates everything; every
        # prefix of a hierarchical category path also gets its own bucket.
        category_result = {"总计": {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            categories_list = result['item']['categories']
            difficulty = result['item']['difficulty']
            human_acc = result['item']['human_acc']
            for categories in categories_list:
                if difficulty not in category_result["总计"]["difficulty"]:
                    category_result["总计"]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                category_result["总计"]["difficulty"][difficulty]['hit'] += hit
                category_result["总计"]["difficulty"][difficulty]['all'] += 1
                category_result["总计"]['hit'] += hit
                category_result["总计"]['all'] += 1
                category_result["总计"]['human_acc'] += human_acc
                # Credit every prefix of the category path, e.g. "a", "a|b", "a|b|c".
                category_subset = []
                for category in categories:
                    category_subset.append(category)
                    category_name = self.level_delimiter.join(category_subset)
                    if not category_name:
                        category_name = "未分类"  # "uncategorized"
                    if category_name not in category_result:
                        category_result[category_name] = {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}
                    if difficulty not in category_result[category_name]["difficulty"]:
                        category_result[category_name]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                    category_result[category_name]["difficulty"][difficulty]['hit'] += hit
                    category_result[category_name]["difficulty"][difficulty]['all'] += 1
                    category_result[category_name]['hit'] += hit
                    category_result[category_name]['all'] += 1
                    category_result[category_name]['human_acc'] += human_acc

        for v in category_result.values():
            v['acc'] = v['hit'] / v['all']
            v['human_acc'] = v['human_acc'] / v['all']
            for sub_v in v['difficulty'].values():
                sub_v['acc'] = sub_v['hit'] / sub_v['all']
        with open(os.path.join(self.save_dir, "category_result.json"), "w", encoding="utf-8") as f:
            json.dump(category_result, f, ensure_ascii=False)
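
    # Illustrative shape of category_result (category names, difficulty keys,
    # and numbers below are made up, not from the original data):
    #   {"总计": {"hit": 8, "all": 10, "acc": 0.8, "human_acc": 52.0,
    #             "difficulty": {"3": {"hit": 5, "all": 6, "acc": 0.833}}},
    #    "言语理解|逻辑填空": {...}, ...}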

    def difficulty_summary(self, results):
        # Accuracy per difficulty level, plus a "总计" ("overall") row.
        difficulty_result = {"总计": {"hit": 0, "all": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            difficulty = result['item']['difficulty']
            if difficulty not in difficulty_result:
                difficulty_result[difficulty] = {"hit": 0, "all": 0}
            difficulty_result[difficulty]['hit'] += hit
            difficulty_result[difficulty]['all'] += 1
            difficulty_result["总计"]['hit'] += hit
            difficulty_result["总计"]['all'] += 1

        for k in difficulty_result:
            difficulty_result[k]['acc'] = difficulty_result[k]['hit'] / difficulty_result[k]['all']

        with open(os.path.join(self.save_dir, "difficulty_result.json"), "w", encoding="utf-8") as f:
            json.dump(difficulty_result, f, ensure_ascii=False)

    def __call__(self):
        # Run inference over the whole dataset, dump the raw generations, then
        # compute the three summaries.
        results = []
        for item in tqdm(self.dataset):
            prompt = self.build_prompt(item)
            predict = self.get_predict(prompt)
            results.append({"item": item, "predict": predict})
            gc.collect()

        with open(os.path.join(self.save_dir, "raw.json"), "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        self.extract_answer(results)
        self.category_summary(results)
        self.difficulty_summary(results)


if __name__ == "__main__":
    args = parse_args()
    Validator(args)()
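
# Example invocation (file and path names are illustrative, not from this repo):
#   python validator.py \
#       --model_path /path/to/checkpoint \
#       --dataset_path /path/to/dataset.json \
#       --save_path results \
#       --max_length 2000 --max_new_tokens 48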