import os
import re
from collections import defaultdict

import datasets
from rouge_score import rouge_scorer
from sklearn.metrics import accuracy_score, mean_squared_error


# Module names that receive LoRA adapters, keyed by base-model alias.
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'llama2': [
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
        # 'embed_tokens', 'lm_head',
    ],
}

# Aliases accepted by parse_model_name, mapped to (remote id, local path).
_BASE_MODELS = {
    'chatglm2': ('THUDM/chatglm2-6b', 'base_models/chatglm2-6b'),
    'llama2': ('meta-llama/Llama-2-7b-chat-hf', 'base_models/Llama-2-7b-chat-hf'),
}

# Patterns used by parse_answer, compiled once at import time.
_ANSWER_RE = re.compile(
    r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction & Analysis\]:\s*(.*)\s*$",
    flags=re.DOTALL,
)
_PREDICTION_RE = re.compile(
    r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', flags=re.DOTALL
)
_UP_RE = re.compile(r'up|increase')
_DOWN_RE = re.compile(r'down|decrease|decline')
# `\d+` (not a single `\d`) so that multi-digit percentages such as
# "10-12%" or "more than 10%" yield 10 instead of a stray trailing digit.
_RANGE_RE = re.compile(r'(\d+)-(\d+)%')
_SINGLE_RE = re.compile(r'(?:more than )?(\d+)%')


def tokenize(args, tokenizer, feature):
    """Encode one prompt/answer pair into input ids and labels.

    Args:
        args: Object exposing ``max_length``, used as the truncation limit for
            each of the two ``encode`` calls and as the overflow threshold.
        tokenizer: Object exposing ``encode``, ``eos_token_id`` and
            ``pad_token_id``.
        feature: Mapping with string values under ``'prompt'`` and ``'answer'``.

    Returns:
        A dict with ``input_ids`` (prompt ids followed by answer ids, plus an
        EOS id when there is room), ``labels`` (same length, with every prompt
        position replaced by ``pad_token_id``) and ``exceed_max_length``
        (True when the concatenation reaches ``args.max_length``).
    """
    prompt_ids = tokenizer.encode(
        feature['prompt'].strip(), padding=False,
        max_length=args.max_length, truncation=True
    )
    target_ids = tokenizer.encode(
        feature['answer'].strip(), padding=False,
        max_length=args.max_length, truncation=True, add_special_tokens=False
    )

    input_ids = prompt_ids + target_ids
    exceed_max_length = len(input_ids) >= args.max_length

    # Add EOS Token. The emptiness check avoids an IndexError when both the
    # prompt and the answer encode to nothing.
    if not exceed_max_length and (
        not input_ids or input_ids[-1] != tokenizer.eos_token_id
    ):
        input_ids.append(tokenizer.eos_token_id)

    # Prompt positions carry the pad id in the labels; answer (and EOS)
    # positions keep their token ids.
    # NOTE(review): presumably a downstream collator excludes pad ids from the
    # loss -- confirm against the training script.
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]

    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "exceed_max_length": exceed_max_length
    }


def parse_model_name(name, from_remote=False):
    """Map a base-model alias to its remote identifier or local path.

    Args:
        name: ``'chatglm2'`` or ``'llama2'``.
        from_remote: When True return the remote identifier, otherwise the
            path under ``base_models/``.

    Raises:
        ValueError: If ``name`` is not a known alias.
    """
    try:
        remote_id, local_path = _BASE_MODELS[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable input, which the original if/elif chain
        # also rejected with ValueError.
        raise ValueError(f"Undefined base model {name}") from None
    return remote_id if from_remote else local_path


def load_dataset(names, from_remote=False):
    """Load one or more datasets described by a comma-separated string.

    Each entry is either an existing filesystem path, used as-is, or a short
    name that is prefixed with ``'FinGPT/fingpt-forecaster-'`` (remote) or
    ``'data/fingpt-forecaster-'`` (local). A short name may carry a ``*N``
    suffix, in which case the loaded dataset appears ``N`` times in the result.

    Args:
        names: Comma-separated dataset entries.
        from_remote: When True use ``datasets.load_dataset``, otherwise
            ``datasets.load_from_disk``.

    Returns:
        A list of loaded datasets, each containing a ``'test'`` entry.
    """
    prefix = 'FinGPT/fingpt-forecaster-' if from_remote else 'data/fingpt-forecaster-'
    dataset_list = []

    for name in names.split(','):
        rep = 1
        if not os.path.exists(name):
            parts = name.split('*')
            if len(parts) > 1:
                rep = int(parts[1])
            name = prefix + parts[0]

        if from_remote:
            tmp_dataset = datasets.load_dataset(name)
        else:
            tmp_dataset = datasets.load_from_disk(name)

        # Without a 'test' entry, derive one with a fixed seed so the split is
        # reproducible across runs.
        if 'test' not in tmp_dataset:
            tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)

        # The same object is repeated (not copied) `rep` times.
        dataset_list.extend([tmp_dataset] * rep)

    return dataset_list


def parse_answer(answer):
    """Split a forecaster answer into its sections and numeric prediction.

    The answer must contain, in order, ``[Positive Developments]:``,
    ``[Potential Concerns]:`` and ``[Prediction & Analysis]:``; the last
    section must start with ``Prediction:`` and contain ``Analysis:``.

    Args:
        answer: Full answer text.

    Returns:
        ``None`` when the text does not follow the layout above. Otherwise a
        dict with ``'positive developments'``, ``'potential concerns'`` and
        ``'analysis'`` (strings), ``'prediction_binary'`` (1 for up/increase,
        -1 for down/decrease/decline, 0 otherwise) and ``'prediction'``
        (the direction times the first percentage found plus 0.5, or 0.0
        when no percentage is present).
    """
    sections = _ANSWER_RE.match(answer)
    if not sections:
        return None
    pros, cons, pna = sections.group(1), sections.group(2), sections.group(3)

    pred_parts = _PREDICTION_RE.match(pna)
    if not pred_parts:
        return None
    pred, anal = pred_parts.group(1), pred_parts.group(2)

    pred_lower = pred.lower()
    if _UP_RE.search(pred_lower):
        pred_bin = 1
    elif _DOWN_RE.search(pred_lower):
        pred_bin = -1
    else:
        pred_bin = 0

    # Prefer an "A-B%" range; fall back to a single "N%" figure.
    margin = _RANGE_RE.search(pred) or _SINGLE_RE.search(pred)
    # The lower bound plus 0.5 is used as the point estimate of the move.
    pred_margin = pred_bin * (int(margin.group(1)) + 0.5) if margin else 0.

    return {
        "positive developments": pros,
        "potential concerns": cons,
        "prediction": pred_margin,
        "prediction_binary": pred_bin,
        "analysis": anal
    }


def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F-measures over paired references and answers.

    Args:
        references: Iterable of reference strings.
        answers: Iterable of candidate strings, paired with ``references`` by
            position (extra items on the longer side are ignored by ``zip``).

    Returns:
        A dict with keys ``'rouge1'``, ``'rouge2'`` and ``'rougeL'``. All three
        are 0.0 when there are no pairs to score.
    """
    metric_names = ['rouge1', 'rouge2', 'rougeL']
    scorer = rouge_scorer.RougeScorer(metric_names, use_stemmer=True)
    scores_per_pair = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]

    # Avoid dividing by zero on empty input.
    if not scores_per_pair:
        return {metric: 0.0 for metric in metric_names}

    pair_count = len(scores_per_pair)
    return {
        metric: sum(score[metric].fmeasure for score in scores_per_pair) / pair_count
        for metric in metric_names
    }


def calc_metrics(answers, gts):
    """Score generated answers against ground truths and print a summary.

    Only pairs where both the answer and the ground truth are accepted by
    ``parse_answer`` contribute to the metrics.

    Args:
        answers: Iterable of generated answer strings.
        gts: Iterable of ground-truth strings, paired by position.

    Returns:
        An empty dict when no pair could be parsed. Otherwise a dict with
        ``valid_count``, ``bin_acc``, ``mse`` and the three ROUGE score dicts
        (``pros_rouge_scores``, ``cons_rouge_scores``, ``anal_rouge_scores``).
    """
    answers_dict = defaultdict(list)
    gts_dict = defaultdict(list)

    for answer, gt in zip(answers, gts):
        answer_dict = parse_answer(answer)
        gt_dict = parse_answer(gt)

        if answer_dict and gt_dict:
            for key, value in answer_dict.items():
                answers_dict[key].append(value)
                gts_dict[key].append(gt_dict[key])

    if not answers_dict['prediction']:
        return {}

    bin_acc = accuracy_score(gts_dict['prediction_binary'], answers_dict['prediction_binary'])
    mse = mean_squared_error(gts_dict['prediction'], answers_dict['prediction'])

    pros_rouge_scores = calc_rouge_score(gts_dict['positive developments'], answers_dict['positive developments'])
    cons_rouge_scores = calc_rouge_score(gts_dict['potential concerns'], answers_dict['potential concerns'])
    anal_rouge_scores = calc_rouge_score(gts_dict['analysis'], answers_dict['analysis'])

    print(f"\nBinary Accuracy: {bin_acc:.2f}  |  Mean Square Error: {mse:.2f}")
    print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
    print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
    print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")

    return {
        "valid_count": len(answers_dict['prediction']),
        "bin_acc": bin_acc,
        "mse": mse,
        "pros_rouge_scores": pros_rouge_scores,
        "cons_rouge_scores": cons_rouge_scores,
        "anal_rouge_scores": anal_rouge_scores
    }