MCK-02 committed on
Commit
32f760f
1 Parent(s): 8075e4b
Files changed (1)
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import torch
+
+ from datasets import load_dataset
+ from evaluate import load as load_metric
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from sklearn.metrics import accuracy_score, f1_score
+ from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader
+
+ # run generation on a GPU if one is available
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ select = st.selectbox('Which model would you like to evaluate?',
+                       ('Bart', 'mBart'))
+
+ def get_datasets():
+     if select == "Bart":
+         all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
+     elif select == "mBart":
+         all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
+     return all_datasets
+
+ all_datasets = get_datasets()
+
+ def get_split(dataset_name):
+     if dataset_name == "Communication Networks: unseen questions":
+         split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
+     elif dataset_name == "Communication Networks: unseen answers":
+         split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
+     elif dataset_name == "Micro Job: unseen questions":
+         split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
+     elif dataset_name == "Micro Job: unseen answers":
+         split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
+     elif dataset_name == "Legal Domain: unseen questions":
+         split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
+     elif dataset_name == "Legal Domain: unseen answers":
+         split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
+     return split
+
+ def get_model(datasetname):
+     if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
+         model = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
+     elif datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
+         model = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
+     elif datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
+         model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
+     return model
+
+ # the tokenizer is loaded from the same checkpoint as the model
+ def get_tokenizer(datasetname):
+     if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers":
+         tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
+     elif datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers":
+         tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
+     elif datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers":
+         tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
+     return tokenizer
+
+ sacrebleu = load_metric('sacrebleu')
+ rouge = load_metric('rouge')
+ meteor = load_metric('meteor')
+ bertscore = load_metric('bertscore')
+
+ MAX_INPUT_LENGTH = 256
+ MAX_TARGET_LENGTH = 128
+
+ def preprocess_function(examples):
+     """
+     Preprocess entries of the given dataset
+
+     Params:
+         examples (Dataset): dataset to be preprocessed
+     Returns:
+         model_inputs (BatchEncoding): tokenized dataset entries
+     """
+     inputs, targets = [], []
+     # build the model input ("Antwort: ... Lösung: ... Frage: ...") and the
+     # target ("<label> Feedback: <text>") for every entry
+     for i in range(len(examples['question'])):
+         inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
+         targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
+
+     # apply tokenization to inputs and labels
+     model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
+     labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
+
+     model_inputs['labels'] = labels['input_ids']
+
+     return model_inputs
+
+
+ def flatten_list(l):
+     """
+     Utility function to convert a list of lists into a flattened list
+     Params:
+         l (list of lists): list to be flattened
+     Returns:
+         A flattened list with the elements of the original list
+     """
+     return [item for sublist in l for item in sublist]
+
+
+ def extract_feedback(predictions):
+     """
+     Utility function to extract the feedback from the predictions of the model
+     Params:
+         predictions (list): complete model predictions
+     Returns:
+         feedback (list): extracted feedback from the model's predictions
+     """
+     feedback = []
+     # iterate through predictions and try to extract predicted feedback
+     for pred in predictions:
+         try:
+             # the feedback follows the first colon ("<label> Feedback: <text>")
+             fb = pred.split(':', 1)[1]
+         except IndexError:
+             try:
+                 # no colon present: drop the leading label words instead
+                 if pred.lower().startswith('partially correct'):
+                     fb = pred.split(' ', 2)[2]
+                 else:
+                     fb = pred.split(' ', 1)[1]
+             except IndexError:
+                 fb = pred
+         feedback.append(fb.strip())
+
+     return feedback
+
+
+ def extract_labels(predictions):
+     """
+     Utility function to extract the labels from the predictions of the model
+     Params:
+         predictions (list): complete model predictions
+     Returns:
+         labels (list): extracted labels from the model's predictions
+     """
+     labels = []
+     for pred in predictions:
+         if pred.lower().startswith('correct'):
+             label = 'Correct'
+         elif pred.lower().startswith('partially correct'):
+             label = 'Partially correct'
+         elif pred.lower().startswith('incorrect'):
+             label = 'Incorrect'
+         else:
+             label = 'Unknown label'
+         labels.append(label)
+
+     return labels
+
+
+ def get_predictions_labels(model, dataloader):
+     """
+     Evaluate model on the given dataset
+
+     Params:
+         model (PreTrainedModel): seq2seq model
+         dataloader (torch DataLoader): dataloader of the dataset to be used for evaluation
+     Returns:
+         predictions (list): list of the decoded predictions of the model
+         labels (list): list of the decoded golden labels of the dataset
+     """
+     decoded_preds, decoded_labels = [], []
+
+     model.to(device)
+     model.eval()
+     # iterate through batches in the dataloader
+     for batch in tqdm(dataloader):
+         with torch.no_grad():
+             batch = {k: v.to(device) for k, v in batch.items()}
+             # generate tokens from batch
+             generated_tokens = model.generate(
+                 batch['input_ids'],
+                 attention_mask=batch['attention_mask'],
+                 max_length=MAX_TARGET_LENGTH
+             )
+             # get golden labels from batch
+             labels_batch = batch['labels']
+
+             # decode model predictions and golden labels
+             decoded_preds_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+             decoded_labels_batch = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
+
+             decoded_preds.append(decoded_preds_batch)
+             decoded_labels.append(decoded_labels_batch)
+
+     # convert predictions and golden labels into flattened lists
+     predictions = flatten_list(decoded_preds)
+     labels = flatten_list(decoded_labels)
+
+     return predictions, labels
+
+
+ def load_data():
+     # preprocess_function and get_predictions_labels read the module-level tokenizer
+     global tokenizer
+     df = pd.DataFrame(columns=['Model', 'Dataset', 'SacreBLEU', 'ROUGE-2', 'METEOR', 'BERTScore', 'Accuracy', 'Weighted F1', 'Macro F1'])
+     for ds in all_datasets:
+         split = get_split(ds)
+         model = AutoModelForSeq2SeqLM.from_pretrained(get_model(ds))
+         tokenizer = AutoTokenizer.from_pretrained(get_tokenizer(ds))
+
+         processed_dataset = split.map(
+             preprocess_function,
+             batched=True,
+             remove_columns=split.column_names
+         )
+         processed_dataset.set_format('torch')
+
+         dataloader = DataLoader(processed_dataset, batch_size=4)
+
+         predictions, labels = get_predictions_labels(model, dataloader)
+
+         predicted_feedback = extract_feedback(predictions)
+         predicted_labels = extract_labels(predictions)
+
+         # reference targets have the form "<label> Feedback: <text>"
+         reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
+         reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
+
+         rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
+         bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
+         meteor_score = meteor.compute(predictions=predicted_feedback, references=reference_feedback)['meteor']
+         # BERTScore returns per-example precision/recall/F1; report the mean F1
+         bert_score = np.mean(bertscore.compute(predictions=predicted_feedback, references=reference_feedback, lang='de', model_type='bert-base-multilingual-cased', rescale_with_baseline=True)['f1'])
+
+         reference_labels_np = np.array(reference_labels)
+
+         accuracy_value = accuracy_score(reference_labels_np, predicted_labels)
+         f1_weighted_value = f1_score(reference_labels_np, predicted_labels, average='weighted')
+         f1_macro_value = f1_score(reference_labels_np, predicted_labels, average='macro', labels=['Incorrect', 'Partially correct', 'Correct'])
+
+         new_row = pd.DataFrame([{"Model": get_model(ds), "Dataset": ds, "SacreBLEU": bleu_score, "ROUGE-2": rouge_score, "METEOR": meteor_score, "BERTScore": bert_score, "Accuracy": accuracy_value, "Weighted F1": f1_weighted_value, "Macro F1": f1_macro_value}])
+
+         df = pd.concat([df, new_row], ignore_index=True)
+     return df
+
+ dataframe = load_data()
+
+ st.dataframe(dataframe)
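
For reference, a minimal sketch of the output format the extraction step assumes: the fine-tuned SAF models are expected to emit predictions of the form "<label> Feedback: <text>". The sample strings below are hypothetical and only illustrate how extract_labels and extract_feedback split such predictions when the functions above are in scope.

# hypothetical sample predictions in the "<label> Feedback: <text>" format
sample_preds = [
    'Correct Feedback: The response is correct.',
    'Partially correct Feedback: Die Antwort ist unvollständig.',
    'Incorrect missing colon in this prediction',
]

print(extract_labels(sample_preds))
# -> ['Correct', 'Partially correct', 'Incorrect']
print(extract_feedback(sample_preds))
# -> ['The response is correct.', 'Die Antwort ist unvollständig.', 'missing colon in this prediction']

The app itself is started the usual Streamlit way, with `streamlit run app.py`; evaluation runs once on startup for every dataset of the selected model, so the first render takes a while.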