# ANGO-Leaderboard / assets / evaluation.py
import os
import json
import re
import argparse
import torch
import gc
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only GPU 1 to this process
def parse_args():
parser = argparse.ArgumentParser(description='Validation')
parser.add_argument('--model_path', dest="model_path")
parser.add_argument('--dataset_path', dest="dataset_path")
parser.add_argument('--save_path', dest='save_path')
    parser.add_argument('--max_length', dest="max_length", default=2000, type=int)
    parser.add_argument('--max_new_tokens', dest="max_new_tokens", default=48, type=int)
parser.add_argument('--input_template', dest="input_template",
default="材料:{material}\n问题:{question}\n{options}\n答案:{response}")
parser.add_argument('--query_template', dest="query_template", default="问题:{question}\n{options}\n答案:{response}")
parser.add_argument('--system_prompt', dest="system_prompt",
default="你是一名考生,请回答问题的正确选项,比如C。如果有多个正确选项,请按顺序回答所有正确选项,比如ABD。")
parser.add_argument('--level_delimiter', dest="level_delimiter", default="|")
args = parser.parse_args()
return args
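

# A typical invocation might look like the following; the paths below are placeholders,
# not files shipped with the repository:
#   python evaluation.py \
#       --model_path /path/to/hf_model \
#       --dataset_path /path/to/dataset.json \
#       --save_path ./results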
class Validator:
def __init__(self, args):
self.load_model(args.model_path)
self.load_dataset(args.dataset_path)
        self.save_dir = os.path.join(args.save_path, os.path.split(args.model_path)[-1])
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.max_length = args.max_length
self.max_new_tokens = args.max_new_tokens
self.input_template = args.input_template
self.query_template = args.query_template
self.system_prompt = args.system_prompt
self.level_delimiter = args.level_delimiter
def load_model(self, model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map={"": 0}, trust_remote_code=True)
self.model.eval()
def load_dataset(self, dataset_path):
self.dataset = json.load(open(dataset_path, encoding="utf-8"))
def format_prompt(self, material, question, options, response):
if material:
return self.input_template.format(material=material, question=question, options=options,
response=response).strip()
return self.query_template.format(question=question, options=options, response=response).strip()
def build_prompt(self, item):
query_prompt = self.format_prompt(item['material'], item['question'], item['options'], "")
history_prompts = []
for sub in item['history']:
history_prompts.append(self.format_prompt(sub['material'], sub['question'], sub['options'], sub['choice']))
        # Drop history turns from the end until the assembled prompt fits within max_length tokens.
        while True:
            final_prompt = self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
            if len(self.tokenizer.tokenize(final_prompt)) <= self.max_length or not history_prompts:
                break
            history_prompts.pop()
        return final_prompt
def get_predict(self, prompt):
gen_kwargs = {"do_sample": False, "max_new_tokens": self.max_new_tokens}
inputs = self.tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(**inputs, return_dict_in_generate=True, **gen_kwargs)
predict = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)[len(prompt):]
return predict
def extract_answer(self, results):
predict_result = {"acc": 0, "wrong_value": 0, "human_acc": 0, "hit": 0, "total": 0, "wrong_hit": 0,
"wrong_total": 0, 'detail': []}
for result in results:
answer = result['item']['choice']
most_wrong = result['item']['most_wrong']
human_acc = result['item']['human_acc'] * 0.01
predict = ""
for e in re.split(r'[、,.,和\s]\s*', result['predict']):
if not e.isascii():
break
predict += e
result['predict'] = predict
predict_result['detail'].append(
{"answer": answer, "most_wrong": most_wrong, "predict": predict, "human_acc": human_acc})
predict_result['hit'] += 1 if answer == predict else 0
predict_result['wrong_hit'] += 1 if most_wrong == predict else 0
predict_result['wrong_value'] += 1 - human_acc if most_wrong == predict else 0
predict_result['human_acc'] += human_acc
predict_result['total'] += 1
predict_result['acc'] = predict_result['hit'] / predict_result['total']
predict_result['wrong_total'] = predict_result['total'] - predict_result['hit']
        # Guard against division by zero when every prediction is correct.
        if predict_result['wrong_total']:
            predict_result['wrong_value'] = predict_result['wrong_value'] / predict_result['wrong_total']
predict_result['human_acc'] = predict_result['human_acc'] / len(results)
json.dump(predict_result, open(os.path.join(self.save_dir, "acc_result.json"), "w", encoding="utf-8"),
ensure_ascii=False)
def category_summary(self, results):
category_result = {"总计": {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}}
for result in results:
hit = 1 if result['item']['choice'] == result['predict'] else 0
categories_list = result['item']['categories']
difficulty = result['item']['difficulty']
human_acc = result['item']['human_acc']
for categories in categories_list:
if difficulty not in category_result["总计"]["difficulty"]:
category_result["总计"]["difficulty"][difficulty] = {"hit": 0, "all": 0}
category_result["总计"]["difficulty"][difficulty]['hit'] += hit
category_result["总计"]["difficulty"][difficulty]['all'] += 1
category_result["总计"]['hit'] += hit
category_result["总计"]['all'] += 1
category_result["总计"]['human_acc'] += human_acc
category_subset = []
for category in categories:
category_subset.append(category)
category_name = self.level_delimiter.join(category_subset)
if not category_name:
category_name = "未分类"
if category_name not in category_result:
category_result[category_name] = {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}
if difficulty not in category_result[category_name]["difficulty"]:
category_result[category_name]["difficulty"][difficulty] = {"hit": 0, "all": 0}
category_result[category_name]["difficulty"][difficulty]['hit'] += hit
category_result[category_name]["difficulty"][difficulty]['all'] += 1
category_result[category_name]['hit'] += hit
category_result[category_name]['all'] += 1
category_result[category_name]['human_acc'] += human_acc
for k, v in category_result.items():
v['acc'] = v['hit'] / v['all']
v['human_acc'] = v['human_acc'] / v['all']
for d, sub_v in v['difficulty'].items():
sub_v['acc'] = sub_v['hit'] / sub_v['all']
json.dump(category_result, open(os.path.join(self.save_dir, "category_result.json"), "w", encoding="utf-8"),
ensure_ascii=False)
def difficulty_summary(self, results):
difficulty_result = {"总计": {"hit": 0, "all": 0}}
for result in results:
hit = 1 if result['item']['choice'] == result['predict'] else 0
difficulty = result['item']['difficulty']
if difficulty not in difficulty_result:
difficulty_result[difficulty] = {"hit": 0, "all": 0}
difficulty_result[difficulty]['hit'] += hit
difficulty_result[difficulty]['all'] += 1
difficulty_result["总计"]['hit'] += hit
difficulty_result["总计"]['all'] += 1
for k in difficulty_result:
difficulty_result[k]['acc'] = difficulty_result[k]['hit'] / difficulty_result[k]['all']
json.dump(difficulty_result, open(os.path.join(self.save_dir, "difficulty_result.json"), "w", encoding="utf-8"),
ensure_ascii=False)
def __call__(self):
results = []
for item in tqdm(self.dataset):
prompt = self.build_prompt(item)
predict = self.get_predict(prompt)
results.append({"item": item, "predict": predict})
gc.collect()
json.dump(results, open(os.path.join(self.save_dir, "raw.json"), "w", encoding="utf-8"), ensure_ascii=False)
self.extract_answer(results)
self.category_summary(results)
self.difficulty_summary(results)
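

# Running the Validator writes four files under save_dir: raw.json (the raw model output for
# every item), acc_result.json (overall accuracy plus per-item detail), category_result.json
# (per-category breakdown) and difficulty_result.json (per-difficulty breakdown).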
if __name__ == "__main__":
args = parse_args()
Validator(args)()
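
# For reference, the dataset is expected to be a JSON list whose items carry at least the fields
# read above. A minimal sketch of one item, inferred from the field accesses in this script
# (the values shown are illustrative placeholders, not taken from the real dataset):
#   {
#     "material": "",                  # optional shared passage; empty when the question stands alone
#     "question": "...",
#     "options": "A. ...\nB. ...",
#     "choice": "B",                   # gold answer letter(s)
#     "most_wrong": "C",               # the most commonly chosen wrong answer
#     "human_acc": 75.0,               # human accuracy in percent (the script scales it by 0.01)
#     "difficulty": "...",
#     "categories": [["...", "..."]],  # one or more category paths, joined by level_delimiter
#     "history": []                    # few-shot items with the same fields, each including its "choice"
#   }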