"""Validation script: run a causal LM over a multiple-choice exam dataset and
report accuracy overall, per category, and per difficulty level."""
import os

# Pin the run to a single GPU; this must happen before torch is imported in
# order to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import argparse
import gc
import json
import re

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_args():
    parser = argparse.ArgumentParser(description='Validation')
    parser.add_argument('--model_path', dest="model_path")
    parser.add_argument('--dataset_path', dest="dataset_path")
    parser.add_argument('--save_path', dest='save_path')
    parser.add_argument('--max_length', dest="max_length", type=int, default=2000)
    parser.add_argument('--max_new_tokens', dest="max_new_tokens", type=int, default=48)
    # Chinese exam-style prompt templates; the slots are material / question /
    # options / answer.
    parser.add_argument('--input_template', dest="input_template",
                        default="材料:{material}\n问题:{question}\n{options}\n答案:{response}")
    parser.add_argument('--query_template', dest="query_template",
                        default="问题:{question}\n{options}\n答案:{response}")
    # System prompt: "You are an exam candidate; answer with the correct option,
    # e.g. C. If several options are correct, give them all in order, e.g. ABD."
    parser.add_argument('--system_prompt', dest="system_prompt",
                        default="你是一名考生,请回答问题的正确选项,比如C。如果有多个正确选项,请按顺序回答所有正确选项,比如ABD。")
    parser.add_argument('--level_delimiter', dest="level_delimiter", default="|")
    args = parser.parse_args()
    return args
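# Expected dataset item shape, inferred from the fields this script reads
# (key names come from the code below; values are illustrative):
# {
#   "material": "...", "question": "...", "options": "...",
#   "choice": "C", "most_wrong": "A", "human_acc": 75.0,
#   "difficulty": "...", "categories": [["...", "..."]],
#   "history": [{"material": "", "question": "...", "options": "...", "choice": "B"}]
# }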
class Validator:
    def __init__(self, args):
        self.load_model(args.model_path)
        self.load_dataset(args.dataset_path)
        # Results are written under save_path/<model directory name>/.
        self.save_dir = os.path.join(args.save_path, os.path.split(args.model_path)[-1])
        os.makedirs(self.save_dir, exist_ok=True)
        self.max_length = args.max_length
        self.max_new_tokens = args.max_new_tokens
        self.input_template = args.input_template
        self.query_template = args.query_template
        self.system_prompt = args.system_prompt
        self.level_delimiter = args.level_delimiter
    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map={"": 0}, trust_remote_code=True)
        self.model.eval()
    def load_dataset(self, dataset_path):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)
    def format_prompt(self, material, question, options, response):
        # Items that come with reading material use the full template; standalone
        # questions use the shorter query template.
        if material:
            return self.input_template.format(material=material, question=question, options=options,
                                              response=response).strip()
        return self.query_template.format(question=question, options=options, response=response).strip()
    def build_prompt(self, item):
        query_prompt = self.format_prompt(item['material'], item['question'], item['options'], "")
        history_prompts = []
        for sub in item['history']:
            history_prompts.append(self.format_prompt(sub['material'], sub['question'], sub['options'], sub['choice']))
        # Drop few-shot history entries from the end until the prompt fits the
        # token budget.
        while True:
            final_prompt = self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
            if len(self.tokenizer.tokenize(final_prompt)) <= self.max_length or not history_prompts:
                break
            history_prompts.pop()
        return final_prompt
    def get_predict(self, prompt):
        # Greedy decoding; decode only the newly generated tokens so the answer
        # does not depend on the prompt round-tripping through the tokenizer.
        gen_kwargs = {"do_sample": False, "max_new_tokens": self.max_new_tokens}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, return_dict_in_generate=True, **gen_kwargs)
        new_tokens = outputs.sequences[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
    def extract_answer(self, results):
        # Aggregate accuracy, plus how often the model picks the most common
        # wrong answer ("wrong_hit") and the human accuracy on those items.
        predict_result = {"acc": 0, "wrong_value": 0, "human_acc": 0, "hit": 0, "total": 0, "wrong_hit": 0,
                          "wrong_total": 0, 'detail': []}
        for result in results:
            answer = result['item']['choice']
            most_wrong = result['item']['most_wrong']
            human_acc = result['item']['human_acc'] * 0.01
            # Keep only the leading ASCII chunks (the option letters); stop at
            # the first non-ASCII fragment of the model output.
            predict = ""
            for e in re.split(r'[、,.,和\s]\s*', result['predict']):
                if not e.isascii():
                    break
                predict += e
            result['predict'] = predict
            predict_result['detail'].append(
                {"answer": answer, "most_wrong": most_wrong, "predict": predict, "human_acc": human_acc})
            predict_result['hit'] += 1 if answer == predict else 0
            predict_result['wrong_hit'] += 1 if most_wrong == predict else 0
            predict_result['wrong_value'] += 1 - human_acc if most_wrong == predict else 0
            predict_result['human_acc'] += human_acc
            predict_result['total'] += 1
        predict_result['acc'] = predict_result['hit'] / predict_result['total']
        predict_result['wrong_total'] = predict_result['total'] - predict_result['hit']
        if predict_result['wrong_total']:
            predict_result['wrong_value'] = predict_result['wrong_value'] / predict_result['wrong_total']
        predict_result['human_acc'] = predict_result['human_acc'] / len(results)
        with open(os.path.join(self.save_dir, "acc_result.json"), "w", encoding="utf-8") as f:
            json.dump(predict_result, f, ensure_ascii=False)
    def category_summary(self, results):
        # Per-category accuracy. "总计" ("overall") aggregates across all
        # category paths; each hierarchical prefix of a path gets its own
        # bucket, with levels joined by level_delimiter.
        category_result = {"总计": {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            categories_list = result['item']['categories']
            difficulty = result['item']['difficulty']
            human_acc = result['item']['human_acc']
            for categories in categories_list:
                if difficulty not in category_result["总计"]["difficulty"]:
                    category_result["总计"]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                category_result["总计"]["difficulty"][difficulty]['hit'] += hit
                category_result["总计"]["difficulty"][difficulty]['all'] += 1
                category_result["总计"]['hit'] += hit
                category_result["总计"]['all'] += 1
                category_result["总计"]['human_acc'] += human_acc
                category_subset = []
                for category in categories:
                    category_subset.append(category)
                    category_name = self.level_delimiter.join(category_subset)
                    if not category_name:
                        category_name = "未分类"  # "uncategorized"
                    if category_name not in category_result:
                        category_result[category_name] = {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}
                    if difficulty not in category_result[category_name]["difficulty"]:
                        category_result[category_name]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                    category_result[category_name]["difficulty"][difficulty]['hit'] += hit
                    category_result[category_name]["difficulty"][difficulty]['all'] += 1
                    category_result[category_name]['hit'] += hit
                    category_result[category_name]['all'] += 1
                    category_result[category_name]['human_acc'] += human_acc
        for k, v in category_result.items():
            v['acc'] = v['hit'] / v['all']
            v['human_acc'] = v['human_acc'] / v['all']
            for d, sub_v in v['difficulty'].items():
                sub_v['acc'] = sub_v['hit'] / sub_v['all']
        with open(os.path.join(self.save_dir, "category_result.json"), "w", encoding="utf-8") as f:
            json.dump(category_result, f, ensure_ascii=False)
    def difficulty_summary(self, results):
        difficulty_result = {"总计": {"hit": 0, "all": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            difficulty = result['item']['difficulty']
            if difficulty not in difficulty_result:
                difficulty_result[difficulty] = {"hit": 0, "all": 0}
            difficulty_result[difficulty]['hit'] += hit
            difficulty_result[difficulty]['all'] += 1
            difficulty_result["总计"]['hit'] += hit
            difficulty_result["总计"]['all'] += 1
        for k in difficulty_result:
            difficulty_result[k]['acc'] = difficulty_result[k]['hit'] / difficulty_result[k]['all']
        with open(os.path.join(self.save_dir, "difficulty_result.json"), "w", encoding="utf-8") as f:
            json.dump(difficulty_result, f, ensure_ascii=False)
    def __call__(self):
        results = []
        for item in tqdm(self.dataset):
            prompt = self.build_prompt(item)
            predict = self.get_predict(prompt)
            results.append({"item": item, "predict": predict})
            gc.collect()
        # Persist raw generations, then the three summary reports.
        with open(os.path.join(self.save_dir, "raw.json"), "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        self.extract_answer(results)
        self.category_summary(results)
        self.difficulty_summary(results)
if __name__ == "__main__":
    args = parse_args()
    Validator(args)()
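# Example invocation (script name and paths are illustrative):
#   python validate.py --model_path /path/to/hf_model \
#       --dataset_path data/dev.json --save_path results
# Outputs land in results/<model_name>/: raw.json, acc_result.json,
# category_result.json, and difficulty_result.json.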