alvinwatner
committed on
Commit
·
9525173
1
Parent(s):
7e9570d
Updating training metrics
Browse files
- run_summarization_flax.py +26 -9
run_summarization_flax.py
CHANGED
@@ -589,8 +589,10 @@ def main():
|
|
589 |
desc="Running tokenizer on prediction dataset",
|
590 |
)
|
591 |
|
592 |
-
# Metric
|
593 |
-
|
|
|
|
|
594 |
|
595 |
def postprocess_text(preds, labels):
|
596 |
preds = [pred.strip() for pred in preds]
|
@@ -609,14 +611,29 @@ def main():
|
|
609 |
# Some simple post-processing
|
610 |
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
611 |
|
612 |
-
|
|
|
|
|
613 |
# Extract a few results from ROUGE
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
|
621 |
# Enable tensorboard only on the master node
|
622 |
has_tensorboard = is_tensorboard_available()
|
|
|
589 |
desc="Running tokenizer on prediction dataset",
|
590 |
)
|
591 |
|
592 |
+
# Metric
|
593 |
+
rouge_metric = load_metric("rouge")
|
594 |
+
bleu_metric = load_metric("bleu")
|
595 |
+
meteor_metric = load_metric("meteor")
|
596 |
|
597 |
def postprocess_text(preds, labels):
|
598 |
preds = [pred.strip() for pred in preds]
|
|
|
611 |
# Some simple post-processing
|
612 |
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
613 |
|
614 |
+
results = {}
|
615 |
+
rouge_scores = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer = True, \
|
616 |
+
rouge_types=['rougeL'])
|
617 |
# Extract a few results from ROUGE
|
618 |
+
rouge_scores = {key: value.mid.fmeasure * 100 for key, value in rouge_scores.items()}
|
619 |
+
rouge_scores = {k: round(v, 4) for k, v in rouge_scores.items()}
|
620 |
+
meteor_scores = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)
|
621 |
+
meteor_scores = {k: round(v, 4) for k, v in meteor_scores.items()}
|
622 |
+
|
623 |
+
# Compute bleu-1,2,3,4 scores
|
624 |
+
# Postprocess the predictions and references to compute bleu scores
|
625 |
+
tokenized_predictions = [decoded_preds[i].split() for i in range(len(decoded_preds))]
|
626 |
+
tokenized_labels = [[decoded_labels[i].split()] for i in range(len(decoded_labels))]
|
627 |
+
bleu_scores = {f'bleu-{i}' : \
|
628 |
+
bleu_metric.compute(predictions=tokenized_predictions, references=tokenized_labels, max_order=i)['bleu']\
|
629 |
+
for i in range(1,5)}
|
630 |
+
bleu_scores = {k: round(v, 4) for k, v in bleu_scores.items()}
|
631 |
+
|
632 |
+
results.update(bleu_scores)
|
633 |
+
results.update(rouge_scores)
|
634 |
+
results.update(meteor_scores)
|
635 |
+
|
636 |
+
return results
|
637 |
|
638 |
# Enable tensorboard only on the master node
|
639 |
has_tensorboard = is_tensorboard_available()
|