Spaces:
Runtime error
Runtime error
import re | |
import os | |
import datasets | |
from sklearn.metrics import accuracy_score, mean_squared_error | |
from collections import defaultdict | |
from rouge_score import rouge_scorer | |
# Module names keyed by base-model alias ('chatglm2', 'llama2' — the same
# aliases parse_model_name accepts). Presumably consumed as LoRA
# target modules by the training script — confirm at the call site.
lora_module_dict = dict(
    chatglm2=['query_key_value'],
    llama2=[
        # attention projections
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        # MLP projections
        'gate_proj', 'up_proj', 'down_proj',
        # 'embed_tokens', 'lm_head',  (left disabled in this configuration)
    ],
)
def tokenize(args, tokenizer, feature):
    """Build causal-LM training ids for one prompt/answer example.

    The stripped prompt is encoded with the tokenizer's special tokens, and
    the stripped answer without them. Each is truncated independently to
    ``args.max_length``; the concatenation itself is NOT truncated.

    An EOS id is appended when the sequence does not already end with one
    and the combined length is below ``args.max_length``. In the labels,
    every prompt position is replaced by ``tokenizer.pad_token_id``, so only
    answer (and EOS) positions carry real targets.

    NOTE(review): whether pad_token_id is ignored by the loss depends on the
    training setup — confirm in the trainer/collator.

    Args:
        args: Object exposing ``max_length``.
        tokenizer: Object exposing ``encode``, ``eos_token_id`` and
            ``pad_token_id``.
        feature: Mapping with string ``'prompt'`` and ``'answer'`` entries.

    Returns:
        dict with ``input_ids``, ``labels`` and ``exceed_max_length`` (True
        when the concatenated ids reach ``args.max_length``). Presumably
        callers filter on that flag — TODO confirm.
    """
    limit = args.max_length
    prompt_ids = tokenizer.encode(
        feature['prompt'].strip(),
        padding=False,
        max_length=limit,
        truncation=True,
    )
    answer_ids = tokenizer.encode(
        feature['answer'].strip(),
        padding=False,
        max_length=limit,
        truncation=True,
        add_special_tokens=False,
    )

    sequence = prompt_ids + answer_ids  # fresh list; prompt_ids is untouched
    too_long = len(sequence) >= limit

    # Terminate with EOS only when there is room and it is not already there.
    if sequence[-1] != tokenizer.eos_token_id and not too_long:
        sequence.append(tokenizer.eos_token_id)

    prompt_len = len(prompt_ids)
    labels = [tokenizer.pad_token_id] * prompt_len + sequence[prompt_len:]

    return {
        "input_ids": sequence,
        "labels": labels,
        "exceed_max_length": too_long,
    }
def parse_model_name(name, from_remote=False):
    """Map a base-model alias to a Hub repo id or a local checkpoint path.

    Args:
        name: Alias, either ``'chatglm2'`` or ``'llama2'``.
        from_remote: When truthy, return the Hub repo id; otherwise return
            the path under ``base_models/``.

    Raises:
        ValueError: if ``name`` is not a known alias.
    """
    # (alias, remote repo id, local directory)
    known_models = (
        ('chatglm2', 'THUDM/chatglm2-6b', 'base_models/chatglm2-6b'),
        ('llama2', 'meta-llama/Llama-2-7b-chat-hf', 'base_models/Llama-2-7b-chat-hf'),
    )
    for alias, remote_id, local_dir in known_models:
        if name == alias:
            return remote_id if from_remote else local_dir
    raise ValueError(f"Undefined base model {name}")
def load_dataset(names, from_remote=False):
    """Load one or more datasets from a comma-separated spec string.

    Each comma-separated entry is handled as follows:

    * If it is an existing filesystem path, it is used as-is and loaded once.
    * Otherwise it is treated as ``short_name`` or ``short_name*N``. The
      short name is prefixed with ``'FinGPT/fingpt-forecaster-'``
      (``from_remote``) or ``'data/fingpt-forecaster-'`` (local), and the
      result is repeated ``N`` times (default 1).

    Loading uses ``datasets.load_dataset`` when ``from_remote`` is truthy and
    ``datasets.load_from_disk`` otherwise. A dataset without a ``'test'``
    entry is split with ``train_test_split(0.2, shuffle=True, seed=42)``.

    Returns:
        A flat list of loaded datasets. Repeats are the same object,
        referenced ``N`` times.
    """
    hub_prefix = 'FinGPT/fingpt-forecaster-' if from_remote else 'data/fingpt-forecaster-'
    loaded = []
    for spec in names.split(','):
        repeat = 1
        if not os.path.exists(spec):
            short_name, *suffix = spec.split('*')
            # Only the first '*'-separated field after the name is the count.
            repeat = int(suffix[0]) if suffix else 1
            spec = hub_prefix + short_name

        if from_remote:
            dataset = datasets.load_dataset(spec)
        else:
            dataset = datasets.load_from_disk(spec)

        if 'test' not in dataset:
            dataset = dataset.train_test_split(0.2, shuffle=True, seed=42)

        loaded += [dataset] * repeat
    return loaded
def parse_answer(answer):
    """Split a templated forecaster response into sections plus a numeric prediction.

    The response must look like::

        [Positive Developments]: ...
        [Potential Concerns]: ...
        [Prediction & Analysis]:
        Prediction: ...
        Analysis: ...

    Section texts are returned as captured, so they may carry trailing
    whitespace.

    The direction ``prediction_binary`` is determined as follows:

    * 1 if the prediction line contains a word starting with "up" or
      "increase" (this check runs first);
    * -1 if it contains one starting with "down", "decrease" or "decline";
    * 0 otherwise.

    The signed margin ``prediction`` is derived as follows:

    * For "A-B%" it is ``A + 0.5``, which is the midpoint when B == A + 1.
    * Otherwise, for "N%" or "more than N%", it is ``N + 0.5``.
    * Either value is multiplied by the direction. With no percentage, the
      margin is 0.0.

    Returns:
        dict with keys "positive developments", "potential concerns",
        "prediction", "prediction_binary" and "analysis"; or ``None`` when
        the response does not follow the template.
    """
    match_res = re.match(
        r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction & Analysis\]:\s*(.*)\s*$",
        answer,
        flags=re.DOTALL,
    )
    if not match_res:
        return None
    pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(3)

    match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
    if not match_res:
        return None
    pred, anal = match_res.group(1), match_res.group(2)

    # Anchor at a word start so substrings such as "s-up-port" or "set-up"
    # inside a bearish prediction are not read as an upward call.
    pred_lower = pred.lower()
    if re.search(r'\b(?:up|increase)', pred_lower):
        pred_bin = 1
    elif re.search(r'\b(?:down|decrease|decline)', pred_lower):
        pred_bin = -1
    else:
        pred_bin = 0

    # Capture whole numbers: the previous single-digit groups turned "10-12%"
    # into 2 and "more than 10%" into 0 (a repeated group keeps its last digit).
    match_res = re.search(r'(\d+)-(\d+)%', pred)
    if not match_res:
        match_res = re.search(r'(?:more than )?(\d+)%', pred)
    pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.

    return {
        "positive developments": pros,
        "potential concerns": cons,
        "prediction": pred_margin,
        "prediction_binary": pred_bin,
        "analysis": anal
    }
def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F-measure over paired references and answers.

    Pairs are formed with ``zip``, so extra items in the longer sequence are
    ignored. Scoring uses ``rouge_scorer.RougeScorer`` with stemming, with
    each reference as the target and each answer as the prediction.

    Returns:
        ``{'rouge1': float, 'rouge2': float, 'rougeL': float}`` holding the
        mean F-measure per ROUGE type. When there are no pairs every value is
        0.0; previously this case raised ZeroDivisionError.
    """
    rouge_types = ['rouge1', 'rouge2', 'rougeL']
    pairs = list(zip(references, answers))
    if not pairs:
        # Mean over zero pairs is undefined; report 0.0 instead of crashing.
        return {rouge_type: 0.0 for rouge_type in rouge_types}

    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
    scores_per_pair = [scorer.score(ref, ans) for ref, ans in pairs]
    return {
        rouge_type: sum(score[rouge_type].fmeasure for score in scores_per_pair) / len(scores_per_pair)
        for rouge_type in rouge_types
    }
def calc_metrics(answers, gts):
    """Score generated answers against ground-truth answers and print a summary.

    Each ``(answer, gt)`` pair, formed positionally with ``zip``, is parsed
    with ``parse_answer``. Pairs where either side fails to parse are
    skipped. Over the remaining pairs this computes:

    * ``bin_acc``: ``accuracy_score`` over the ``prediction_binary``
      directions;
    * ``mse``: ``mean_squared_error`` over the signed ``prediction`` margins;
    * ROUGE dicts, via ``calc_rouge_score``, for the positive-developments,
      potential-concerns and analysis sections.

    Accuracy, MSE and the three ROUGE dicts are also printed.

    Returns:
        ``{}`` when no pair parsed. Otherwise a dict with ``valid_count``,
        ``bin_acc``, ``mse``, ``pros_rouge_scores``, ``cons_rouge_scores``
        and ``anal_rouge_scores``.
    """
    predicted = defaultdict(list)
    expected = defaultdict(list)

    for raw_answer, raw_gt in zip(answers, gts):
        answer_fields = parse_answer(raw_answer)
        gt_fields = parse_answer(raw_gt)
        # Keep a pair only when both sides follow the expected template.
        if not (answer_fields and gt_fields):
            continue
        for field, value in answer_fields.items():
            predicted[field].append(value)
            expected[field].append(gt_fields[field])

    valid_predictions = predicted['prediction']
    if not valid_predictions:
        return {}

    bin_acc = accuracy_score(expected['prediction_binary'], predicted['prediction_binary'])
    mse = mean_squared_error(expected['prediction'], valid_predictions)

    def _section_rouge(field):
        # Ground truth is passed as the reference, the model output as the candidate.
        return calc_rouge_score(expected[field], predicted[field])

    pros_rouge_scores = _section_rouge('positive developments')
    cons_rouge_scores = _section_rouge('potential concerns')
    anal_rouge_scores = _section_rouge('analysis')

    print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
    print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
    print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
    print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")

    return {
        "valid_count": len(valid_predictions),
        "bin_acc": bin_acc,
        "mse": mse,
        "pros_rouge_scores": pros_rouge_scores,
        "cons_rouge_scores": cons_rouge_scores,
        "anal_rouge_scores": anal_rouge_scores
    }