alvinwatner commited on
Commit
9525173
·
1 Parent(s): 7e9570d

Updating training metrics

Browse files
Files changed (1) hide show
  1. run_summarization_flax.py +26 -9
run_summarization_flax.py CHANGED
@@ -589,8 +589,10 @@ def main():
589
  desc="Running tokenizer on prediction dataset",
590
  )
591
 
592
- # Metric
593
- metric = load_metric("rouge")
 
 
594
 
595
  def postprocess_text(preds, labels):
596
  preds = [pred.strip() for pred in preds]
@@ -609,14 +611,29 @@ def main():
609
  # Some simple post-processing
610
  decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
611
 
612
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
 
 
613
  # Extract a few results from ROUGE
614
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
615
-
616
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
617
- result["gen_len"] = np.mean(prediction_lens)
618
- result = {k: round(v, 4) for k, v in result.items()}
619
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
  # Enable tensorboard only on the master node
622
  has_tensorboard = is_tensorboard_available()
 
589
  desc="Running tokenizer on prediction dataset",
590
  )
591
 
592
+ # Metric
593
+ rouge_metric = load_metric("rouge")
594
+ bleu_metric = load_metric("bleu")
595
+ meteor_metric = load_metric("meteor")
596
 
597
  def postprocess_text(preds, labels):
598
  preds = [pred.strip() for pred in preds]
 
611
  # Some simple post-processing
612
  decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
613
 
614
+ results = {}
615
+ rouge_scores = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer = True, \
616
+ rouge_types=['rougeL'])
617
  # Extract a few results from ROUGE
618
+ rouge_scores = {key: value.mid.fmeasure * 100 for key, value in rouge_scores.items()}
619
+ rouge_scores = {k: round(v, 4) for k, v in rouge_scores.items()}
620
+ meteor_scores = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)
621
+ meteor_scores = {k: round(v, 4) for k, v in meteor_scores.items()}
622
+
623
+ # Compute bleu-1,2,3,4 scores
624
+ # Postprocess the predictions and references to compute bleu scores
625
+ tokenized_predictions = [decoded_preds[i].split() for i in range(len(decoded_preds))]
626
+ tokenized_labels = [[decoded_labels[i].split()] for i in range(len(decoded_labels))]
627
+ bleu_scores = {f'bleu-{i}' : \
628
+ bleu_metric.compute(predictions=tokenized_predictions, references=tokenized_labels, max_order=i)['bleu']\
629
+ for i in range(1,5)}
630
+ bleu_scores = {k: round(v, 4) for k, v in bleu_scores.items()}
631
+
632
+ results.update(bleu_scores)
633
+ results.update(rouge_scores)
634
+ results.update(meteor_scores)
635
+
636
+ return results
637
 
638
  # Enable tensorboard only on the master node
639
  has_tensorboard = is_tensorboard_available()