# FinGPT_Forecaster/utils.py
import re
import os
import datasets
from sklearn.metrics import accuracy_score, mean_squared_error
from collections import defaultdict
from rouge_score import rouge_scorer

# Modules to wrap with LoRA adapters for each supported base model
lora_module_dict = {
'chatglm2': ['query_key_value'],
'llama2': [
'q_proj', 'k_proj', 'v_proj',
'o_proj', 'gate_proj', 'up_proj', 'down_proj',
# 'embed_tokens', 'lm_head',
],
}
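
# A minimal usage sketch (not part of the original module): these per-model
# target-module lists are typically handed to a PEFT LoraConfig when attaching
# LoRA adapters. The peft import and the hyperparameters below are illustrative
# assumptions, not values taken from this repo.
#
#     from peft import LoraConfig, TaskType
#
#     lora_config = LoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         r=8, lora_alpha=16, lora_dropout=0.1,
#         target_modules=lora_module_dict['llama2'],
#     )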

def tokenize(args, tokenizer, feature):
    """Tokenize a prompt/answer pair into input_ids and loss labels."""
    prompt_ids = tokenizer.encode(
        feature['prompt'].strip(), padding=False,
        max_length=args.max_length, truncation=True
    )
target_ids = tokenizer.encode(
feature['answer'].strip(), padding=False,
max_length=args.max_length, truncation=True, add_special_tokens=False
)
input_ids = prompt_ids + target_ids
exceed_max_length = len(input_ids) >= args.max_length
    # Append the EOS token if there is room and it is not already present
    if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
        input_ids.append(tokenizer.eos_token_id)
    # Mask the prompt portion of the labels so loss is computed only on the answer tokens
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]
return {
"input_ids": input_ids,
"labels": label_ids,
"exceed_max_length": exceed_max_length
}
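
# Hedged sketch of how `tokenize` would typically be applied with datasets.map;
# `raw_dataset` and the max_length value are illustrative assumptions. The
# 'prompt'/'answer' columns match what the function reads above.
#
#     from types import SimpleNamespace
#     from transformers import AutoTokenizer
#
#     args = SimpleNamespace(max_length=512)
#     tokenizer = AutoTokenizer.from_pretrained(parse_model_name('llama2', from_remote=True))
#     tokenized = raw_dataset.map(
#         lambda feature: tokenize(args, tokenizer, feature),
#         remove_columns=['prompt', 'answer'],
#     )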

def parse_model_name(name, from_remote=False):
    """Map a short model alias to its HF Hub id (remote) or local base_models path."""
    if name == 'chatglm2':
return 'THUDM/chatglm2-6b' if from_remote else 'base_models/chatglm2-6b'
elif name == 'llama2':
return 'meta-llama/Llama-2-7b-chat-hf' if from_remote else 'base_models/Llama-2-7b-chat-hf'
else:
raise ValueError(f"Undefined base model {name}")
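
# For example (values taken directly from the mapping above):
#     parse_model_name('llama2', from_remote=True)   # -> 'meta-llama/Llama-2-7b-chat-hf'
#     parse_model_name('llama2', from_remote=False)  # -> 'base_models/Llama-2-7b-chat-hf'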

def load_dataset(names, from_remote=False):
    """Load one or more datasets from a comma-separated spec; a 'name*N' entry repeats that dataset N times."""
    dataset_names = names.split(',')
dataset_list = []
for name in dataset_names:
rep = 1
        if not os.path.exists(name):
            # Not a local path: resolve the optional '*N' repeat suffix and add the dataset prefix
            rep = int(name.split('*')[1]) if '*' in name else 1
            name = ('FinGPT/fingpt-forecaster-' if from_remote else 'data/fingpt-forecaster-') + name.split('*')[0]
        tmp_dataset = datasets.load_dataset(name) if from_remote else datasets.load_from_disk(name)
        if 'test' not in tmp_dataset:
            # Hold out a test split when the dataset ships without one
            tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
dataset_list.extend([tmp_dataset] * rep)
return dataset_list
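
# Hedged usage sketch: the dataset names below are placeholders, but the spec
# format matches what the function expects ('name' optionally suffixed with
# '*N' to repeat that dataset N times in the returned list).
#
#     # two copies of data/fingpt-forecaster-<name-a>, one of data/fingpt-forecaster-<name-b>
#     dataset_list = load_dataset('<name-a>*2,<name-b>', from_remote=False)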

def parse_answer(answer):
    """Split a structured forecast answer into its sections and extract a numeric prediction."""
    match_res = re.match(r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction & Analysis\]:\s*(.*)\s*$", answer, flags=re.DOTALL)
if not match_res:
return None
pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(3)
match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
if not match_res:
return None
pred, anal = match_res.group(1), match_res.group(2)
    # Direction: +1 for up, -1 for down, 0 if unclear
    if re.search(r'up|increase', pred.lower()):
pred_bin = 1
elif re.search(r'down|decrease|decline', pred.lower()):
pred_bin = -1
else:
pred_bin = 0
    # Magnitude: match a range like "2-3%" or a phrase like "more than 5%"
    match_res = re.search(r'(\d+)-(\d+)%', pred)
    if not match_res:
        match_res = re.search(r'(?:more than )?(\d+)%', pred)
    # Use the lower bound plus 0.5 as the signed point estimate
    pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.
return {
"positive developments": pros,
"potential concerns": cons,
"prediction": pred_margin,
"prediction_binary": pred_bin,
"analysis": anal
}
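
# Hedged sketch of the answer format parse_answer expects (text is illustrative):
#
#     sample = (
#         "[Positive Developments]: Strong earnings beat.\n"
#         "[Potential Concerns]: Rising input costs.\n"
#         "[Prediction & Analysis]:\n"
#         "Prediction: Up by 2-3%\n"
#         "Analysis: Momentum should continue next week."
#     )
#     parse_answer(sample)['prediction_binary']  # -> 1
#     parse_answer(sample)['prediction']         # -> 2.5  (lower bound 2 + 0.5)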

def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F1 scores over reference/answer pairs."""
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores_per_pair = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]
rouge1 = sum(score['rouge1'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
rouge2 = sum(score['rouge2'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
rougeL = sum(score['rougeL'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
return {'rouge1': rouge1, 'rouge2': rouge2, 'rougeL': rougeL}

def calc_metrics(answers, gts):
    """Compute binary accuracy, MSE, and ROUGE scores over the answer/ground-truth pairs that parse."""
    answers_dict = defaultdict(list)
    gts_dict = defaultdict(list)
for answer, gt in zip(answers, gts):
answer_dict = parse_answer(answer)
gt_dict = parse_answer(gt)
        # Keep the pair only when both the answer and the ground truth parse
        if answer_dict and gt_dict:
for k in answer_dict.keys():
answers_dict[k].append(answer_dict[k])
gts_dict[k].append(gt_dict[k])
    # No pair parsed successfully: nothing to score
    if not answers_dict['prediction']:
        return {}
bin_acc = accuracy_score(gts_dict['prediction_binary'], answers_dict['prediction_binary'])
mse = mean_squared_error(gts_dict['prediction'], answers_dict['prediction'])
pros_rouge_scores = calc_rouge_score(gts_dict['positive developments'], answers_dict['positive developments'])
cons_rouge_scores = calc_rouge_score(gts_dict['potential concerns'], answers_dict['potential concerns'])
anal_rouge_scores = calc_rouge_score(gts_dict['analysis'], answers_dict['analysis'])
print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")
return {
"valid_count": len(answers_dict['prediction']),
"bin_acc": bin_acc,
"mse": mse,
"pros_rouge_scores": pros_rouge_scores,
"cons_rouge_scores": cons_rouge_scores,
"anal_rouge_scores": anal_rouge_scores
}
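
# Hedged end-to-end sketch (variable names are hypothetical): `answers` would be
# model generations and `gts` the reference completions, both written in the
# structured format handled by parse_answer above.
#
#     metrics = calc_metrics(answers, gts)
#     # -> {'valid_count': ..., 'bin_acc': ..., 'mse': ...,
#     #     'pros_rouge_scores': {...}, 'cons_rouge_scores': {...}, 'anal_rouge_scores': {...}}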