# ANGO-Leaderboard/assets/evaluation.py
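"""Evaluate a causal LM on the ANGO multiple-choice dataset.

The script generates an answer for every item, then writes the raw
generations plus overall, per-category, and per-difficulty accuracy
summaries as JSON. Each dataset item is expected to provide the fields
used below: 'material', 'question', 'options', 'history', 'choice',
'most_wrong', 'human_acc', 'categories', and 'difficulty'.
"""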
import os
import json
import re
import argparse
import torch
import gc
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Expose only GPU 1 to this process (takes effect before any CUDA context is created).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def parse_args():
    parser = argparse.ArgumentParser(description='Validation')
    parser.add_argument('--model_path', dest="model_path")
    parser.add_argument('--dataset_path', dest="dataset_path")
    parser.add_argument('--save_path', dest='save_path')
    parser.add_argument('--max_length', dest="max_length", type=int, default=2000)
    parser.add_argument('--max_new_tokens', dest="max_new_tokens", type=int, default=48)
    # Template for questions with accompanying material:
    # "Material: {material}\nQuestion: {question}\n{options}\nAnswer: {response}"
    parser.add_argument('--input_template', dest="input_template",
                        default="材料:{material}\n问题:{question}\n{options}\n答案:{response}")
    # Template for questions without material: "Question: {question}\n{options}\nAnswer: {response}"
    parser.add_argument('--query_template', dest="query_template",
                        default="问题:{question}\n{options}\n答案:{response}")
    # System prompt: "You are an examinee; answer with the correct option, e.g. C.
    # If several options are correct, give them all in order, e.g. ABD."
    parser.add_argument('--system_prompt', dest="system_prompt",
                        default="你是一名考生,请回答问题的正确选项,比如C。如果有多个正确选项,请按顺序回答所有正确选项,比如ABD。")
    parser.add_argument('--level_delimiter', dest="level_delimiter", default="|")
    args = parser.parse_args()
    return args


class Validator:
    def __init__(self, args):
        self.load_model(args.model_path)
        self.load_dataset(args.dataset_path)
        # Save all result files under <save_path>/<model directory name>.
        self.save_dir = os.path.join(args.save_path, os.path.split(args.model_path)[-1])
        os.makedirs(self.save_dir, exist_ok=True)
        self.max_length = args.max_length
        self.max_new_tokens = args.max_new_tokens
        self.input_template = args.input_template
        self.query_template = args.query_template
        self.system_prompt = args.system_prompt
        self.level_delimiter = args.level_delimiter
    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map={"": 0}, trust_remote_code=True)
        self.model.eval()

    def load_dataset(self, dataset_path):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
    def format_prompt(self, material, question, options, response):
        if material:
            return self.input_template.format(material=material, question=question, options=options,
                                              response=response).strip()
        return self.query_template.format(question=question, options=options, response=response).strip()
    def build_prompt(self, item):
        query_prompt = self.format_prompt(item['material'], item['question'], item['options'], "")
        history_prompts = []
        for sub in item['history']:
            history_prompts.append(self.format_prompt(sub['material'], sub['question'], sub['options'], sub['choice']))
            final_prompt = self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
            # Stop accumulating few-shot history once the tokenized prompt exceeds max_length.
            if len(self.tokenizer.tokenize(final_prompt)) > self.max_length:
                history_prompts.pop()
                break
        return self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
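    # With the default templates, the assembled prompt looks like:
    #   <system_prompt>
    #   问题:...\n<options>\n答案:<choice>    <- one block per kept history item
    #   问题:...\n<options>\n答案:            <- the query, answer left blank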
    def get_predict(self, prompt):
        gen_kwargs = {"do_sample": False, "max_new_tokens": self.max_new_tokens}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, return_dict_in_generate=True, **gen_kwargs)
        # Decode the full sequence and drop the prompt prefix, keeping only the completion.
        predict = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)[len(prompt):]
        return predict
    def extract_answer(self, results):
        predict_result = {"acc": 0, "wrong_value": 0, "human_acc": 0, "hit": 0, "total": 0, "wrong_hit": 0,
                          "wrong_total": 0, 'detail': []}
        for result in results:
            answer = result['item']['choice']
            most_wrong = result['item']['most_wrong']
            human_acc = result['item']['human_acc'] * 0.01
            # Split the completion on separators and keep the leading ASCII
            # fragments (the option letters), stopping at the first non-ASCII one.
            predict = ""
            for e in re.split(r'[、,.,和\s]\s*', result['predict']):
                if not e.isascii():
                    break
                predict += e
            result['predict'] = predict
            predict_result['detail'].append(
                {"answer": answer, "most_wrong": most_wrong, "predict": predict, "human_acc": human_acc})
            predict_result['hit'] += 1 if answer == predict else 0
            predict_result['wrong_hit'] += 1 if most_wrong == predict else 0
            predict_result['wrong_value'] += 1 - human_acc if most_wrong == predict else 0
            predict_result['human_acc'] += human_acc
            predict_result['total'] += 1
        predict_result['acc'] = predict_result['hit'] / predict_result['total']
        predict_result['wrong_total'] = predict_result['total'] - predict_result['hit']
        predict_result['wrong_value'] = predict_result['wrong_value'] / predict_result['wrong_total']
        predict_result['human_acc'] = predict_result['human_acc'] / len(results)
        with open(os.path.join(self.save_dir, "acc_result.json"), "w", encoding="utf-8") as f:
            json.dump(predict_result, f, ensure_ascii=False)
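    # 'wrong_hit' counts questions where the model picked the most common human
    # error, and 'wrong_value' averages (1 - human_acc) over those picks.
    # Illustrative acc_result.json shape (values made up):
    # {"acc": 0.55, "wrong_value": 0.4, "human_acc": 0.62, "hit": 55,
    #  "total": 100, "wrong_hit": 18, "wrong_total": 45, "detail": [...]}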
    def category_summary(self, results):
        category_result = {"总计": {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            categories_list = result['item']['categories']
            difficulty = result['item']['difficulty']
            human_acc = result['item']['human_acc']
            for categories in categories_list:
                # Overall tally ("总计" = total), bucketed by difficulty.
                if difficulty not in category_result["总计"]["difficulty"]:
                    category_result["总计"]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                category_result["总计"]["difficulty"][difficulty]['hit'] += hit
                category_result["总计"]["difficulty"][difficulty]['all'] += 1
                category_result["总计"]['hit'] += hit
                category_result["总计"]['all'] += 1
                category_result["总计"]['human_acc'] += human_acc
                # Tally every prefix of the category path, e.g. "A", "A|B", "A|B|C".
                category_subset = []
                for category in categories:
                    category_subset.append(category)
                    category_name = self.level_delimiter.join(category_subset)
                    if not category_name:
                        category_name = "未分类"  # "uncategorized"
                    if category_name not in category_result:
                        category_result[category_name] = {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}
                    if difficulty not in category_result[category_name]["difficulty"]:
                        category_result[category_name]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                    category_result[category_name]["difficulty"][difficulty]['hit'] += hit
                    category_result[category_name]["difficulty"][difficulty]['all'] += 1
                    category_result[category_name]['hit'] += hit
                    category_result[category_name]['all'] += 1
                    category_result[category_name]['human_acc'] += human_acc
        for k, v in category_result.items():
            v['acc'] = v['hit'] / v['all']
            v['human_acc'] = v['human_acc'] / v['all']
            for d, sub_v in v['difficulty'].items():
                sub_v['acc'] = sub_v['hit'] / sub_v['all']
        with open(os.path.join(self.save_dir, "category_result.json"), "w", encoding="utf-8") as f:
            json.dump(category_result, f, ensure_ascii=False)
    def difficulty_summary(self, results):
        difficulty_result = {"总计": {"hit": 0, "all": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            difficulty = result['item']['difficulty']
            if difficulty not in difficulty_result:
                difficulty_result[difficulty] = {"hit": 0, "all": 0}
            difficulty_result[difficulty]['hit'] += hit
            difficulty_result[difficulty]['all'] += 1
            difficulty_result["总计"]['hit'] += hit
            difficulty_result["总计"]['all'] += 1
        for k in difficulty_result:
            difficulty_result[k]['acc'] = difficulty_result[k]['hit'] / difficulty_result[k]['all']
        with open(os.path.join(self.save_dir, "difficulty_result.json"), "w", encoding="utf-8") as f:
            json.dump(difficulty_result, f, ensure_ascii=False)
    def __call__(self):
        results = []
        for item in tqdm(self.dataset):
            prompt = self.build_prompt(item)
            predict = self.get_predict(prompt)
            results.append({"item": item, "predict": predict})
            gc.collect()
        with open(os.path.join(self.save_dir, "raw.json"), "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        self.extract_answer(results)
        self.category_summary(results)
        self.difficulty_summary(results)


if __name__ == "__main__":
    args = parse_args()
    Validator(args)()
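
# Example invocation (paths are illustrative):
#   python evaluation.py --model_path /path/to/model \
#       --dataset_path data/dataset.json --save_path results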