"""Streamlit app that evaluates the Short-Answer-Feedback (SAF) models on their
held-out test splits and reports generation and classification metrics."""

import streamlit as st
import pandas as pd
import numpy as np
import torch

from datasets import load_dataset
from evaluate import load as load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sklearn.metrics import accuracy_score, f1_score
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

# Run generation on GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

select = st.selectbox('Which model would you like to evaluate?',
                      ('Bart', 'mBart'))
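
# Streamlit reruns this script from top to bottom whenever the selection
# changes, so the datasets, models and metrics below always reflect the
# currently selected model family.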

def get_datasets():
    """Return the names of the evaluation datasets for the selected model."""
    if select == 'Bart':
        all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
    else:  # 'mBart'
        all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
    return all_datasets


all_datasets = get_datasets()

# Hub repo and test split for each dataset name.
DATASET_SPLITS = {
    "Communication Networks: unseen questions": ("Short-Answer-Feedback/saf_communication_networks_english", "test_unseen_questions"),
    "Communication Networks: unseen answers": ("Short-Answer-Feedback/saf_communication_networks_english", "test_unseen_answers"),
    "Micro Job: unseen questions": ("Short-Answer-Feedback/saf_micro_job_german", "test_unseen_questions"),
    "Micro Job: unseen answers": ("Short-Answer-Feedback/saf_micro_job_german", "test_unseen_answers"),
    "Legal Domain: unseen questions": ("Short-Answer-Feedback/saf_legal_domain_german", "test_unseen_questions"),
    "Legal Domain: unseen answers": ("Short-Answer-Feedback/saf_legal_domain_german", "test_unseen_answers"),
}


def get_split(dataset_name):
    """Load the test split that corresponds to the given dataset name."""
    repo, split_name = DATASET_SPLITS[dataset_name]
    return load_dataset(repo, split=split_name)
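
# For example, get_split("Micro Job: unseen answers") downloads
# Short-Answer-Feedback/saf_micro_job_german and returns its
# test_unseen_answers split.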

# Fine-tuned model repo for each dataset name; the same repo also hosts the tokenizer.
MODEL_REPOS = {
    "Communication Networks: unseen questions": "Short-Answer-Feedback/bart-finetuned-saf-communication-networks",
    "Communication Networks: unseen answers": "Short-Answer-Feedback/bart-finetuned-saf-communication-networks",
    "Micro Job: unseen questions": "Short-Answer-Feedback/mbart-finetuned-saf-micro-job",
    "Micro Job: unseen answers": "Short-Answer-Feedback/mbart-finetuned-saf-micro-job",
    "Legal Domain: unseen questions": "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain",
    "Legal Domain: unseen answers": "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain",
}


def get_model(datasetname):
    """Return the Hub id of the fine-tuned model for the given dataset."""
    return MODEL_REPOS[datasetname]

def get_tokenizer(datasetname):
    """Return the Hub id of the tokenizer, which is published alongside each model."""
    return get_model(datasetname)
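
# For example, get_model("Legal Domain: unseen answers") and
# get_tokenizer("Legal Domain: unseen answers") both return
# "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain".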

# Text-generation metrics from the evaluate library.
sacrebleu = load_metric('sacrebleu')
rouge = load_metric('rouge')
meteor = load_metric('meteor')
bertscore = load_metric('bertscore')

# Token budgets for the tokenized inputs and the generated feedback.
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 128

def preprocess_function(examples):
    """
    Preprocess entries of the given dataset.

    Params:
        examples (Dataset): dataset entries to be preprocessed
    Returns:
        model_inputs (BatchEncoding): tokenized dataset entries
    """
    inputs, targets = [], []
    for i in range(len(examples['question'])):
        inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
        targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")

    # `tokenizer` is the module-level tokenizer set in load_data for the current model.
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)

    model_inputs['labels'] = labels['input_ids']

    return model_inputs
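
# Each serialized input marks the student answer, reference answer and
# question with German field labels, e.g. (illustrative values):
#   input:  "Antwort: <student answer> Lösung: <reference answer> Frage: <question>"
#   target: "Correct Feedback: <textual feedback>"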

def flatten_list(l):
    """
    Utility function to convert a list of lists into a flattened list.

    Params:
        l (list of lists): list to be flattened
    Returns:
        A flattened list with the elements of the original list
    """
    return [item for sublist in l for item in sublist]
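
# e.g. flatten_list([[1, 2], [3]]) -> [1, 2, 3]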

def extract_feedback(predictions):
    """
    Utility function to extract the feedback from the predictions of the model.

    Params:
        predictions (list): complete model predictions
    Returns:
        feedback (list): feedback extracted from the model's predictions
    """
    feedback = []
    for pred in predictions:
        try:
            # The expected format is '<label> Feedback: <text>'.
            fb = pred.split(':', 1)[1]
        except IndexError:
            # No colon in the prediction: drop the leading label words instead.
            try:
                if pred.lower().startswith('partially correct'):
                    fb = pred.split(' ', 2)[2]
                else:
                    fb = pred.split(' ', 1)[1]
            except IndexError:
                fb = pred
        feedback.append(fb.strip())

    return feedback
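
# e.g. extract_feedback(['Incorrect Feedback: The TCP handshake has three steps.'])
#      -> ['The TCP handshake has three steps.']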

def extract_labels(predictions):
    """
    Utility function to extract the labels from the predictions of the model.

    Params:
        predictions (list): complete model predictions
    Returns:
        labels (list): labels extracted from the model's predictions
    """
    labels = []
    for pred in predictions:
        if pred.lower().startswith('correct'):
            label = 'Correct'
        elif pred.lower().startswith('partially correct'):
            label = 'Partially correct'
        elif pred.lower().startswith('incorrect'):
            label = 'Incorrect'
        else:
            label = 'Unknown label'
        labels.append(label)

    return labels
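
# e.g. extract_labels(['Partially correct Feedback: ...', 'Correct Feedback: ...'])
#      -> ['Partially correct', 'Correct']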

def get_predictions_labels(model, dataloader):
    """
    Run the model on the given dataset and decode its outputs.

    Params:
        model (PreTrainedModel): seq2seq model
        dataloader (torch DataLoader): dataloader of the dataset to be used for evaluation
    Returns:
        predictions (list): decoded predictions of the model
        labels (list): decoded reference targets
    """
    decoded_preds, decoded_labels = [], []

    model.eval()

    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}

            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH
            )

            labels_batch = batch['labels']

            decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)

            decoded_preds.append(decoded_preds_batch)
            decoded_labels.append(decoded_labels_batch)

    predictions = flatten_list(decoded_preds)
    labels = flatten_list(decoded_labels)

    return predictions, labels
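
# Note: model.generate defaults to greedy decoding here; a num_beams argument
# could be passed above if beam search were preferred.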

def load_data():
    """Evaluate every dataset of the selected model family and collect the metrics in a DataFrame."""
    # preprocess_function and get_predictions_labels read the module-level tokenizer.
    global tokenizer

    df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
    for ds in all_datasets:
        split = get_split(ds)
        model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds)).to(device)
        tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))

        processed_dataset = split.map(
            preprocess_function,
            batched=True,
            remove_columns=split.column_names
        )
        processed_dataset.set_format('torch')

        dataloader = DataLoader(processed_dataset, batch_size=4)

        predictions, labels = get_predictions_labels(model, dataloader)

        predicted_feedback = extract_feedback(predictions)
        predicted_labels = extract_labels(predictions)

        # References have the form '<label> Feedback: <text>'.
        reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
        reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]

        rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
        bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
        meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
        bert_score = bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)
        # bertscore returns per-sentence scores; report the mean F1.
        bert_score_value = np.mean(bert_score['f1'])

        reference_labels_np = np.array(reference_labels)

        accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
        f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
        f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])

        new_row = pd.DataFrame([{'Model': get_model(ds), 'Dataset': ds, 'SacreBLEU': bleu_score,
                                 'ROUGE-2': rouge_score, 'METEOR': meteor_score, 'BERTScore': bert_score_value,
                                 'Accuracy': accuracy_value, 'Weighted F1': f1_weighted_value, 'Macro F1': f1_macro_value}])

        df = pd.concat([df, new_row], ignore_index=True)
    return df

dataframe = load_data()

st.dataframe(dataframe)