# Source: Hugging Face Space file view (Space status at capture time: "Runtime error").
# File size: 1,308 bytes; revision 35c1cfd.
from aac_metrics import evaluate
import sys
def _read_captions(path):
    """Read a ``<key> <caption words...>`` file into an ordered key -> caption dict.

    Each non-blank line is split on whitespace. The first token is the key and
    the remaining tokens are re-joined with single spaces to form the caption.
    A line holding only a key maps to the empty caption. Blank lines (e.g. a
    trailing newline) are skipped. Insertion order follows the file.

    Raises:
        ValueError: if the same key appears on more than one line.
    """
    captions = {}
    with open(path, 'r', encoding='utf-8') as reader:
        for line_no, line in enumerate(reader, start=1):
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank lines instead of crashing on tokens[0]
            key = tokens[0]
            if key in captions:
                raise ValueError(f"{path}:{line_no}: duplicate key {key!r}")
            # evaluate() expects plain strings, not token lists.
            captions[key] = ' '.join(tokens[1:])
    return captions


def compute_wer(ref_file,
                hyp_file):
    """Score hypothesis captions against reference captions with aac_metrics.

    Despite its name, this does not compute word error rate. It passes the
    captions to ``aac_metrics.evaluate`` and prints the corpus-level scores.

    Both files hold one ``<key> <caption words...>`` entry per line.
    Hypotheses are matched to references by key, not by line position. Each
    hypothesis is scored against its single reference caption, in hypothesis
    file order. References without a matching hypothesis are ignored.

    Args:
        ref_file: path to the reference (ground-truth) caption file.
        hyp_file: path to the hypothesis (predicted) caption file.

    Returns:
        The corpus-score dict returned by ``evaluate``. The original inline
        note lists keys such as "bleu_1".."bleu_4", "rouge_l", "meteor",
        "cider_d", "spice" and "spider", with tensor values.

    Raises:
        ValueError: if the hypothesis file has no entries, if a hypothesis
            key has no reference, or if either file repeats a key.
    """
    hyps = _read_captions(hyp_file)
    refs = _read_captions(ref_file)
    if not hyps:
        raise ValueError(f"no hypothesis captions found in {hyp_file}")
    missing = [key for key in hyps if key not in refs]
    if missing:
        raise ValueError(
            f"{len(missing)} hypothesis key(s) have no reference in "
            f"{ref_file}, e.g. {missing[:5]}"
        )
    print('Used lines:', len(hyps))
    candidates: list[str] = list(hyps.values())
    # One reference per candidate, aligned by key.
    mult_references: list[list[str]] = [[refs[key]] for key in hyps]
    corpus_scores, _ = evaluate(candidates, mult_references)
    print(corpus_scores)
    return corpus_scores
if __name__ == '__main__':
    # Expect exactly two positional arguments: the reference and hypothesis files.
    if len(sys.argv) != 3:
        print("usage : python compute_aac_metrics.py test.ref test.hyp", file=sys.stderr)
        sys.exit(1)  # a usage error must not report success
    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    compute_wer(ref_file, hyp_file)