fix syntax errors
Browse files
app.py
CHANGED
@@ -14,44 +14,44 @@ select = st.selectbox('Which model would you like to evaluate?',
|
|
14 |
('Bart', 'mBart'))
|
15 |
|
16 |
def get_datasets():
|
17 |
-
if select == 'Bart'
|
18 |
all_datasets = ["Communication Networks: unseen questions", "Communication Networks: unseen answers"]
|
19 |
-
if select == 'mBart'
|
20 |
all_datasets = ["Micro Job: unseen questions", "Micro Job: unseen answers", "Legal Domain: unseen questions", "Legal Domain: unseen answers"]
|
21 |
return all_datasets
|
22 |
|
23 |
all_datasets = get_datasets()
|
24 |
|
25 |
def get_split(dataset_name):
|
26 |
-
if dataset_name == "Communication Networks: unseen questions"
|
27 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_questions")
|
28 |
-
if dataset_name == "Communication Networks: unseen answers"
|
29 |
split = load_dataset("Short-Answer-Feedback/saf_communication_networks_english", split="test_unseen_answers")
|
30 |
-
if dataset_name == "Micro Job: unseen questions"
|
31 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_questions")
|
32 |
-
if dataset_name == "Micro Job: unseen answers"
|
33 |
split = load_dataset("Short-Answer-Feedback/saf_micro_job_german", split="test_unseen_answers")
|
34 |
-
if dataset_name == "Legal Domain: unseen questions"
|
35 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_questions")
|
36 |
-
if dataset_name == "Legal Domain: unseen answers"
|
37 |
split = load_dataset("Short-Answer-Feedback/saf_legal_domain_german", split="test_unseen_answers")
|
38 |
return split
|
39 |
|
40 |
def get_model(datasetname):
|
41 |
-
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers"
|
42 |
model = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
43 |
-
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers"
|
44 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
45 |
-
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers"
|
46 |
model = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
47 |
return model
|
48 |
|
49 |
def get_tokenizer(datasetname):
|
50 |
-
if datasetname == "Communication Networks: unseen questions" or datasetname == "Communication Networks: unseen answers"
|
51 |
tokenizer = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
|
52 |
-
if datasetname == "Micro Job: unseen questions" or datasetname == "Micro Job: unseen answers"
|
53 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
|
54 |
-
if datasetname == "Legal Domain: unseen questions" or datasetname == "Legal Domain: unseen answers"
|
55 |
tokenizer = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
|
56 |
return tokenizer
|
57 |
|
@@ -212,7 +212,7 @@ def load_data():
|
|
212 |
predicted_labels = extract_labels(predictions)
|
213 |
|
214 |
reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
|
215 |
-
|
216 |
|
217 |
rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
|
218 |
bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
|
|
|
14 |
('Bart', 'mBart'))
|
15 |
|
16 |
def get_datasets(model_name=None):
    """Return the evaluation dataset names available for a model family.

    Args:
        model_name: Either ``'Bart'`` or ``'mBart'``. When omitted, the
            module-level ``select`` value (the model selectbox choice) is
            used, so existing ``get_datasets()`` calls keep working.

    Returns:
        A fresh list of dataset display names for that model family.

    Raises:
        ValueError: If the model family is not one of the known choices.
    """
    datasets_by_model = {
        'Bart': [
            "Communication Networks: unseen questions",
            "Communication Networks: unseen answers",
        ],
        'mBart': [
            "Micro Job: unseen questions",
            "Micro Job: unseen answers",
            "Legal Domain: unseen questions",
            "Legal Domain: unseen answers",
        ],
    }
    if model_name is None:
        # Fall back to the selection made in the UI.
        model_name = select
    if model_name not in datasets_by_model:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(f"Unknown model selection: {model_name!r}")
    return list(datasets_by_model[model_name])
|
22 |
|
23 |
# Dataset names offered for the model family currently chosen in `select`.
all_datasets = get_datasets()
|
24 |
|
25 |
def get_split(dataset_name):
    """Load the test split that belongs to a dataset display name.

    Args:
        dataset_name: One of the six display names listed in the table
            below, e.g. ``"Micro Job: unseen answers"``.

    Returns:
        Whatever ``load_dataset`` returns for the matching repository and
        split.

    Raises:
        ValueError: If ``dataset_name`` is not a known display name.
    """
    # Display name -> (dataset repository, split name).
    split_sources = {
        "Communication Networks: unseen questions": (
            "Short-Answer-Feedback/saf_communication_networks_english",
            "test_unseen_questions",
        ),
        "Communication Networks: unseen answers": (
            "Short-Answer-Feedback/saf_communication_networks_english",
            "test_unseen_answers",
        ),
        "Micro Job: unseen questions": (
            "Short-Answer-Feedback/saf_micro_job_german",
            "test_unseen_questions",
        ),
        "Micro Job: unseen answers": (
            "Short-Answer-Feedback/saf_micro_job_german",
            "test_unseen_answers",
        ),
        "Legal Domain: unseen questions": (
            "Short-Answer-Feedback/saf_legal_domain_german",
            "test_unseen_questions",
        ),
        "Legal Domain: unseen answers": (
            "Short-Answer-Feedback/saf_legal_domain_german",
            "test_unseen_answers",
        ),
    }
    if dataset_name not in split_sources:
        # Reject unknown names up front rather than failing on an unbound
        # local at the return statement.
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    repo_id, split_name = split_sources[dataset_name]
    return load_dataset(repo_id, split=split_name)
|
39 |
|
40 |
def get_model(datasetname):
    """Return the fine-tuned model checkpoint name for a dataset.

    Args:
        datasetname: One of the six dataset display names; both the
            "unseen questions" and "unseen answers" variants of a domain
            map to the same checkpoint.

    Returns:
        The checkpoint identifier string for that dataset's domain.

    Raises:
        ValueError: If ``datasetname`` is not a known display name.
    """
    communication_networks = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
    micro_job = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
    legal_domain = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
    model_by_dataset = {
        "Communication Networks: unseen questions": communication_networks,
        "Communication Networks: unseen answers": communication_networks,
        "Micro Job: unseen questions": micro_job,
        "Micro Job: unseen answers": micro_job,
        "Legal Domain: unseen questions": legal_domain,
        "Legal Domain: unseen answers": legal_domain,
    }
    if datasetname not in model_by_dataset:
        # An explicit error replaces the UnboundLocalError the if-chain
        # produced when no branch matched.
        raise ValueError(f"Unknown dataset name: {datasetname!r}")
    return model_by_dataset[datasetname]
|
48 |
|
49 |
def get_tokenizer(datasetname):
    """Return the tokenizer checkpoint name for a dataset.

    Args:
        datasetname: One of the six dataset display names; both the
            "unseen questions" and "unseen answers" variants of a domain
            map to the same tokenizer.

    Returns:
        The tokenizer identifier string for that dataset's domain.

    Raises:
        ValueError: If ``datasetname`` is not a known display name.
    """
    communication_networks = "Short-Answer-Feedback/bart-finetuned-saf-communication-networks"
    micro_job = "Short-Answer-Feedback/mbart-finetuned-saf-micro-job"
    legal_domain = "Short-Answer-Feedback/mbart-finetuned-saf-legal-domain"
    tokenizer_by_dataset = {
        "Communication Networks: unseen questions": communication_networks,
        "Communication Networks: unseen answers": communication_networks,
        "Micro Job: unseen questions": micro_job,
        "Micro Job: unseen answers": micro_job,
        "Legal Domain: unseen questions": legal_domain,
        "Legal Domain: unseen answers": legal_domain,
    }
    if datasetname not in tokenizer_by_dataset:
        # An explicit error replaces the UnboundLocalError the if-chain
        # produced when no branch matched.
        raise ValueError(f"Unknown dataset name: {datasetname!r}")
    return tokenizer_by_dataset[datasetname]
|
57 |
|
|
|
212 |
predicted_labels = extract_labels(predictions)
|
213 |
|
214 |
reference_feedback = [x.split('Feedback:', 1)[1].strip() for x in labels]
|
215 |
+
reference_labels = [x.split('Feedback:', 1)[0].strip() for x in labels]
|
216 |
|
217 |
rouge_score = rouge.compute(predictions=predicted_feedback, references=reference_feedback)['rouge2']
|
218 |
bleu_score = sacrebleu.compute(predictions=predicted_feedback, references=[[x] for x in reference_feedback])['score']
|