import argparse
import gc
import json
import os
import re

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pin the process to the second GPU before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def parse_args():
    parser = argparse.ArgumentParser(description='Validation')
    parser.add_argument('--model_path', dest="model_path")
    parser.add_argument('--dataset_path', dest="dataset_path")
    parser.add_argument('--save_path', dest='save_path')
    parser.add_argument('--max_length', dest="max_length", type=int, default=2000)
    parser.add_argument('--max_new_tokens', dest="max_new_tokens", type=int, default=48)
    parser.add_argument('--input_template', dest="input_template",
                        default="材料:{material}\n问题:{question}\n{options}\n答案:{response}")
    parser.add_argument('--query_template', dest="query_template",
                        default="问题:{question}\n{options}\n答案:{response}")
    parser.add_argument('--system_prompt', dest="system_prompt",
                        default="你是一名考生,请回答问题的正确选项,比如C。如果有多个正确选项,请按顺序回答所有正确选项,比如ABD。")
    parser.add_argument('--level_delimiter', dest="level_delimiter", default="|")
    return parser.parse_args()


class Validator:
    def __init__(self, args):
        self.load_model(args.model_path)
        self.load_dataset(args.dataset_path)
        # Write all result files under <save_path>/<model directory name>/.
        self.save_dir = os.path.join(args.save_path, os.path.split(args.model_path)[-1])
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        self.max_length = args.max_length
        self.max_new_tokens = args.max_new_tokens
        self.input_template = args.input_template
        self.query_template = args.query_template
        self.system_prompt = args.system_prompt
        self.level_delimiter = args.level_delimiter

    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map={"": 0}, trust_remote_code=True)
        self.model.eval()

    def load_dataset(self, dataset_path):
        with open(dataset_path, encoding="utf-8") as f:
            self.dataset = json.load(f)

    def format_prompt(self, material, question, options, response):
        # Questions attached to reading material use the full input template;
        # stand-alone questions use the shorter query template.
        if material:
            return self.input_template.format(material=material, question=question,
                                              options=options, response=response).strip()
        return self.query_template.format(question=question, options=options,
                                          response=response).strip()

    def build_prompt(self, item):
        query_prompt = self.format_prompt(item['material'], item['question'], item['options'], "")
        history_prompts = [self.format_prompt(sub['material'], sub['question'],
                                              sub['options'], sub['choice'])
                           for sub in item['history']]
        # Drop few-shot history entries from the end until the prompt fits
        # max_length (or no history remains to drop).
        while True:
            final_prompt = self.system_prompt + "\n" + "\n".join(history_prompts) + "\n" + query_prompt
            if len(self.tokenizer.tokenize(final_prompt)) > self.max_length and history_prompts:
                history_prompts.pop()
            else:
                break
        return final_prompt

    def get_predict(self, prompt):
        gen_kwargs = {"do_sample": False, "max_new_tokens": self.max_new_tokens}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, return_dict_in_generate=True, **gen_kwargs)
        # Decode the full sequence, then strip the prompt prefix to keep only the completion.
        predict = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)[len(prompt):]
        return predict

    def extract_answer(self, results):
        predict_result = {"acc": 0, "wrong_value": 0, "human_acc": 0, "hit": 0,
                          "total": 0, "wrong_hit": 0, "wrong_total": 0, "detail": []}
        for result in results:
            answer = result['item']['choice']
            most_wrong = result['item']['most_wrong']
            human_acc = result['item']['human_acc'] * 0.01
            # Keep only the leading ASCII segments of the completion (the option
            # letters); stop at the first non-ASCII segment, which marks the
            # start of free-form explanation text.
            predict = ""
            for e in re.split(r'[、,.,和\s]\s*', result['predict']):
                if not e.isascii():
                    break
                predict += e
            result['predict'] = predict
            predict_result['detail'].append(
                {"answer": answer, "most_wrong": most_wrong, "predict": predict,
                 "human_acc": human_acc})
            predict_result['hit'] += 1 if answer == predict else 0
            predict_result['wrong_hit'] += 1 if most_wrong == predict else 0
            predict_result['wrong_value'] += (1 - human_acc) if most_wrong == predict else 0
            predict_result['human_acc'] += human_acc
            predict_result['total'] += 1
        predict_result['acc'] = predict_result['hit'] / predict_result['total']
        predict_result['wrong_total'] = predict_result['total'] - predict_result['hit']
        # Guard against division by zero when every question was answered correctly.
        if predict_result['wrong_total']:
            predict_result['wrong_value'] /= predict_result['wrong_total']
        predict_result['human_acc'] /= len(results)
        with open(os.path.join(self.save_dir, "acc_result.json"), "w", encoding="utf-8") as f:
            json.dump(predict_result, f, ensure_ascii=False)

    def category_summary(self, results):
        category_result = {"总计": {"hit": 0, "all": 0, "difficulty": {}, "human_acc": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            categories_list = result['item']['categories']
            difficulty = result['item']['difficulty']
            human_acc = result['item']['human_acc']
            for categories in categories_list:
                if difficulty not in category_result["总计"]["difficulty"]:
                    category_result["总计"]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                category_result["总计"]["difficulty"][difficulty]['hit'] += hit
                category_result["总计"]["difficulty"][difficulty]['all'] += 1
                category_result["总计"]['hit'] += hit
                category_result["总计"]['all'] += 1
                category_result["总计"]['human_acc'] += human_acc
                # Accumulate statistics at every prefix of the category path,
                # e.g. "A", "A|B", "A|B|C" for the path [A, B, C].
                category_subset = []
                for category in categories:
                    category_subset.append(category)
                    category_name = self.level_delimiter.join(category_subset)
                    if not category_name:
                        category_name = "未分类"
                    if category_name not in category_result:
                        category_result[category_name] = {"hit": 0, "all": 0,
                                                          "difficulty": {}, "human_acc": 0}
                    if difficulty not in category_result[category_name]["difficulty"]:
                        category_result[category_name]["difficulty"][difficulty] = {"hit": 0, "all": 0}
                    category_result[category_name]["difficulty"][difficulty]['hit'] += hit
                    category_result[category_name]["difficulty"][difficulty]['all'] += 1
                    category_result[category_name]['hit'] += hit
                    category_result[category_name]['all'] += 1
                    category_result[category_name]['human_acc'] += human_acc
        for k, v in category_result.items():
            v['acc'] = v['hit'] / v['all']
            v['human_acc'] = v['human_acc'] / v['all']
            for d, sub_v in v['difficulty'].items():
                sub_v['acc'] = sub_v['hit'] / sub_v['all']
        with open(os.path.join(self.save_dir, "category_result.json"), "w", encoding="utf-8") as f:
            json.dump(category_result, f, ensure_ascii=False)

    def difficulty_summary(self, results):
        difficulty_result = {"总计": {"hit": 0, "all": 0}}
        for result in results:
            hit = 1 if result['item']['choice'] == result['predict'] else 0
            difficulty = result['item']['difficulty']
            if difficulty not in difficulty_result:
                difficulty_result[difficulty] = {"hit": 0, "all": 0}
            difficulty_result[difficulty]['hit'] += hit
            difficulty_result[difficulty]['all'] += 1
            difficulty_result["总计"]['hit'] += hit
            difficulty_result["总计"]['all'] += 1
        for k in difficulty_result:
            difficulty_result[k]['acc'] = difficulty_result[k]['hit'] / difficulty_result[k]['all']
        with open(os.path.join(self.save_dir, "difficulty_result.json"), "w", encoding="utf-8") as f:
            json.dump(difficulty_result, f, ensure_ascii=False)

    def __call__(self):
        results = []
        for item in tqdm(self.dataset):
            prompt = self.build_prompt(item)
            predict = self.get_predict(prompt)
            results.append({"item": item, "predict": predict})
            gc.collect()
        with open(os.path.join(self.save_dir, "raw.json"), "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        self.extract_answer(results)
        self.category_summary(results)
        self.difficulty_summary(results)


if __name__ == "__main__":
    args = parse_args()
    Validator(args)()
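# Example invocation (a sketch; the script filename and the paths below are
# placeholders, not files shipped with this code):
#
#   python validate.py \
#       --model_path /path/to/model \
#       --dataset_path /path/to/dataset.json \
#       --save_path results
#
# The dataset is expected to be a JSON list whose items carry the fields read
# above: 'material', 'question', 'options', 'history' (few-shot examples with
# the same fields plus 'choice'), 'choice', 'most_wrong', 'human_acc' (a
# percentage), 'categories' (a list of category paths), and 'difficulty'.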