ura-llama-evaluation / data_loader.py
naot97's picture
update evaluation result 11 November
443afb0
raw
history blame
7.21 kB
import pandas as pd
import numpy as np
# Path of the Excel workbook holding all evaluation results; it is loaded by
# ``load_data`` at import time (module-level call at the end of this file).
RESULT_FILE = 'evaluation_results.xlsx'
# Metric name -> direction flag (+1 or -1). This file never reads the table.
# NOTE(review): presumably +1 means "higher is better" and -1 means "lower is
# better". That is inferred from which metrics are -1 (error rates, edit
# distances, perplexity, calibration error, bias, toxicity); confirm against
# the code that consumes this table.
metric_ud = {
"Accuracy": 1,
"Average Exact Match": 1,
"Exact Match": 1,
"F1 Score": 1,
"AUC ROC": 1,
"AUC PR": 1,
"Precision": 1,
"Recall": 1,
"Equivalent": 1,
"Bias": -1,
"Demographic representation (race)": -1,
"Demographic representation (gender)": -1,
"Stereotypical associations (race, profession)": -1,
"Stereotypical associations (gender, profession)": -1,
"Toxicity": -1,
"ROUGE-1": 1,
"ROUGE-2": 1,
"ROUGE-L": 1,
"BLEU": 1,
"SummaC": 1,
"BERTScore": 1,
"Coverage": 1,
"Density": 1,
"Compression": 1,
"hLEPOR": 1,
"Character Error Rate": -1,
"Word Error Rate": -1,
"Character Edit Distance": -1,
"Word Edit Distance": -1,
"Perplexity": -1,
"Expected Calibration Error": -1,
"acc@10": 1,
"MRR@10 (Top 30)": 1,
"NDCG@10 (Top 30)": 1,
"MRR@10": 1,
"NDCG@10": 1,
}
# Task display name -> task id. ``load_data`` uses the task id as the
# workbook sheet-name prefix and as the key into ``datasets``.
tasks = {
"Information Retrieval": "informationretrieval",
"Knowledge": "knowledge",
"Language Modelling": "language-modelling",
"Question Answering": "question-answering",
"Reasoning": "reasoning",
"Summarization": "summarization",
"Text Classification": "text-classification",
"Toxicity Detection": "toxicity-detection",
"Translation": "translation",
"Sentiment Analysis": "sentiment-analysis",
}
# Setting display name -> setting id. ``load_data`` appends the id to the task
# id to form the sheet name ("<task_id>-<setting_id>"). An empty id ("Normal")
# means the sheet is named by the task id alone.
# NOTE(review): "Few-shot Leanring" is misspelled, but it is a lookup key shared
# with ``task_w_settings`` (and possibly other modules or UI code), so it is
# left unchanged here. Rename it everywhere at once if you fix it.
settings = {
"Normal": "",
"Few-shot Leanring": "fs",
"Prompt Strategy 0": "pt0",
"Prompt Strategy 1": "pt1",
"Prompt Strategy 2": "pt2",
"Chain-of-Thought": "cot",
"Fairness": "fairness",
"Robustness": "robustness",
}
# Task display name (key of ``tasks``) -> setting display names (keys of
# ``settings``) evaluated for that task. ``load_data`` reads one workbook
# sheet per (task, setting) pair listed here.
task_w_settings = {
"Information Retrieval": ["Normal", "Few-shot Leanring", "Robustness", "Fairness"],
"Knowledge": ["Normal", "Few-shot Leanring", "Robustness"],
"Language Modelling": ["Normal", "Few-shot Leanring", "Fairness"],
"Question Answering": ["Prompt Strategy 0", "Prompt Strategy 1", "Prompt Strategy 2", "Robustness", "Fairness"],
"Reasoning": ["Few-shot Leanring", "Chain-of-Thought"],
"Summarization": ["Prompt Strategy 0", "Prompt Strategy 1", "Prompt Strategy 2", "Robustness"],
"Text Classification": ["Normal", "Few-shot Leanring", "Robustness", "Fairness"],
"Toxicity Detection": ["Normal", "Few-shot Leanring", "Robustness", "Fairness"],
"Translation": ["Few-shot Leanring", "Robustness"],
"Sentiment Analysis": ["Normal", "Few-shot Leanring", "Robustness", "Fairness"],
}
# Task id -> {dataset id: dataset display name}.
# The dataset id is the text after the last "/" in a "Models/<dataset_id>"
# block-header row of a result sheet. ``load_data`` uses the display name as
# the key of the per-dataset results, so these values are runtime keys, not
# just labels.
datasets = {
"question-answering": {
# NOTE(review): "xQUAD EXTREME" is likely meant to read "XQuAD XTREME";
# confirm before renaming, since it is a runtime key.
"xquad_xtreme": "xQUAD EXTREME",
"mlqa": "MLQA",
},
"summarization": {
"vietnews": "VietNews",
"wikilingua": "WikiLingua",
},
"text-classification": {
"vsmec": "VSMEC",
"phoatis": "PhoATIS",
},
"toxicity-detection": {
"victsd": "UIT-ViCTSD",
"vihsd": "UIT-ViHSD",
},
"translation": {
"phomt-envi": "PhoMT English-Vietnamese",
"phomt-vien": "PhoMT Vietnamese-English",
"opus100-envi": "OPUS-100 English-Vietnamese",
"opus100-vien": "OPUS-100 Vietnamese-English",
},
"sentiment-analysis": {
"vlsp": "VLSP 2016",
"vsfc": "UIT-VSFC",
},
"informationretrieval": {
"mmarco": "mMARCO",
"mrobust": "mRobust",
},
"knowledge": {
"zaloe2e": "ZaloE2E",
"vimmrc": "ViMMRC",
},
"language-modelling": {
"mlqa-mlm": "MLQA",
"vsec": "VSEC",
},
"reasoning": {
"srnatural-azr": "Synthetic Reasoning (Natural) - Azure",
"srnatural-gcp": "Synthetic Reasoning (Natural) - Google Cloud",
"srabstract-azr": "Synthetic Reasoning (Abstract Symbol)- Azure",
"srabstract-gcp": "Synthetic Reasoning (Abstract Symbol)- Google Cloud",
"srinduction-azr": "Synthetic Reasoning (Induction) - Azure",
"srinduction-gcp": "Synthetic Reasoning (Induction) - Google Cloud",
# NOTE(review): the four "Synthetic Introduction (...)" labels below look like
# typos for "Synthetic Reasoning (...)", and "(Abstract Symbol)- " above lacks
# a space. These are runtime keys, so confirm with consumers before renaming.
"srpattern-azr": "Synthetic Introduction (Pattern Match) - Azure",
"srpattern-gcp": "Synthetic Introduction (Pattern Match) - Google Cloud",
"srsubstitution-azr": "Synthetic Introduction (Variable Substitution) - Azure",
"srsubstitution-gcp": "Synthetic Introduction (Variable Substitution) - Google Cloud",
"math-azr-Algebra": "MATH Level 1 (Algebra) - Azure",
"math-azr-Counting&Probability": "MATH Level 1 (Counting&Probability) - Azure",
"math-azr-Geometry": "MATH Level 1 (Geometry) - Azure",
"math-azr-IntermediateAlgebra": "MATH Level 1 (IntermediateAlgebra) - Azure",
"math-azr-NumberTheory": "MATH Level 1 (NumberTheory) - Azure",
"math-azr-Prealgebra": "MATH Level 1 (Prealgebra) - Azure",
"math-azr-Precalculus": "MATH Level 1 (Precalculus) - Azure",
"math-gcp-Algebra": "MATH Level 1 (Algebra) - Google Cloud",
"math-gcp-Counting&Probability": "MATH Level 1 (Counting&Probability) - Google Cloud",
"math-gcp-Geometry": "MATH Level 1 (Geometry) - Google Cloud",
"math-gcp-IntermediateAlgebra": "MATH Level 1 (IntermediateAlgebra) - Google Cloud",
"math-gcp-NumberTheory": "MATH Level 1 (NumberTheory) - Google Cloud",
"math-gcp-Prealgebra": "MATH Level 1 (Prealgebra) - Google Cloud",
"math-gcp-Precalculus": "MATH Level 1 (Precalculus) - Google Cloud",
},
}
def load_data(file_name):
    """
    Load the evaluation-results Excel workbook into per-dataset DataFrames.

    Every sheet is read as a raw grid (``header=None``). For each task in
    ``tasks`` and each setting listed for it in ``task_w_settings``, the sheet
    named ``"<task_id>-<setting_id>"`` is parsed. When the setting id is
    empty, the sheet is named just ``"<task_id>"``. See ``_parse_sheet`` for
    the layout expected inside a sheet.

    Args:
        file_name: Path (or any source ``pandas.read_excel`` accepts) of the
            results workbook.

    Returns:
        dict mapping ``f"{task_id}-{setting_id}"`` to a dict of
        ``{dataset display name: DataFrame}``. The key keeps the trailing
        ``"-"`` when the setting id is empty (e.g. ``"knowledge-"``).

    Raises:
        KeyError: if an expected sheet is missing from the workbook, or a
            block header names a dataset id not listed in ``datasets``.
    """
    # sheet_name=None returns a {sheet name: DataFrame} dict for all sheets;
    # header=None keeps every row so the block-header rows can be located.
    workbook = pd.read_excel(file_name, sheet_name=None, header=None)
    results = {}
    for task_name, task_id in tasks.items():
        for setting_name in task_w_settings[task_name]:
            setting_id = settings[setting_name]
            sheet_name = f"{task_id}-{setting_id}" if setting_id else task_id
            try:
                sheet_data = workbook[sheet_name]
            except KeyError as err:
                raise KeyError(
                    f"sheet {sheet_name!r} not found in {file_name!r}"
                ) from err
            # Unlike the sheet name, the result key always contains the "-"
            # separator. It is kept as-is because callers may look results up
            # with the same f-string.
            results[f"{task_id}-{setting_id}"] = _parse_sheet(
                sheet_data, datasets.get(task_id, {}), sheet_name
            )
    return results


def _parse_sheet(sheet_data, dataset_names, sheet_name=""):
    """
    Split one raw result sheet into ``{dataset display name: DataFrame}``.

    A dataset block starts at a row whose first cell is a string containing
    ``"Models/"``. It runs until the next such row or the end of the sheet.
    The text after the last ``"/"`` of that cell is the dataset id, which is
    looked up in ``dataset_names``. Column names for every block come from
    the sheet's first row, with the first cell replaced by ``"Models"``.
    Missing cells are filled with ``'-'``.
    """
    if sheet_data.empty:
        return {}

    # Only string cells can be block headers. Empty cells read as NaN (a
    # float), and ``"Models/" in nan`` would raise TypeError.
    starts = [
        pos
        for pos, cell in enumerate(sheet_data.iloc[:, 0])
        if isinstance(cell, str) and "Models/" in cell
    ]
    if not starts:
        return {}

    # Build the column names once, as a plain list. This avoids mutating a
    # row Series that may be a view of ``sheet_data`` (chained assignment)
    # and avoids leaking the row label as the column-axis name.
    columns = ["Models", *sheet_data.iloc[0, 1:].tolist()]

    bounds = starts + [len(sheet_data)]
    parsed = {}
    for start, stop in zip(bounds, bounds[1:]):
        dataset_id = sheet_data.iat[start, 0].split("/")[-1]
        try:
            dataset_name = dataset_names[dataset_id]
        except KeyError as err:
            raise KeyError(
                f"unknown dataset id {dataset_id!r} in sheet {sheet_name!r}"
            ) from err
        # The block's data rows are the rows after its header row.
        block = sheet_data.iloc[start + 1:stop].fillna("-")
        parsed[dataset_name] = pd.DataFrame(block.to_numpy(), columns=columns)
    return parsed
# Parsed evaluation results, loaded once when the module is imported.
results = load_data(RESULT_FILE)
# ``resutls`` is the original (misspelled) name. It is kept as an alias in
# case other modules import it; new code should use ``results``.
resutls = results