File size: 1,308 Bytes
35c1cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from aac_metrics import evaluate
import sys

def _read_captions(path):
    """Return the captions stored in *path*, one per non-blank line, in file order.

    Each line is expected to look like ``<key> <word> <word> ...``: the first
    whitespace-separated token is an utterance key and is dropped, the
    remaining tokens are re-joined with single spaces into one sentence
    string. A line holding only a key yields an empty caption so that line
    alignment with the other file is preserved.
    """
    captions = []
    with open(path, 'r', encoding='utf-8') as reader:
        for line in reader:
            tokens = line.split()
            if not tokens:
                # Blank line (e.g. trailing newline at end of file): there is
                # no key to strip, so skip it instead of raising IndexError.
                continue
            captions.append(' '.join(tokens[1:]))
    return captions


def compute_wer(ref_file, hyp_file):
    """Score hypothesis captions against reference captions with aac_metrics.

    Despite its name, this function does not compute a word error rate; it
    forwards the captions to ``aac_metrics.evaluate`` and prints the
    corpus-level scores it returns.

    Hypotheses and references are paired by line order (the keys at the start
    of each line are ignored), and every hypothesis gets exactly one reference.

    Args:
        ref_file: Path to the reference file, one ``<key> <caption>`` per line.
        hyp_file: Path to the hypothesis file, same format and same line order.

    Returns:
        The corpus-level score mapping returned by ``evaluate``.

    Raises:
        ValueError: If the two files do not contain the same number of
            captions, since pairing by position would then be meaningless.
    """
    pred_captions = _read_captions(hyp_file)
    gt_captions = _read_captions(ref_file)

    if len(pred_captions) != len(gt_captions):
        raise ValueError(
            f"hypothesis/reference count mismatch: {len(pred_captions)} "
            f"hypotheses in {hyp_file!r} vs {len(gt_captions)} references "
            f"in {ref_file!r}"
        )

    print('Used lines:', len(pred_captions))
    candidates: list[str] = pred_captions
    # evaluate() takes a list of references per candidate; here there is one.
    mult_references: list[list[str]] = [[gt] for gt in gt_captions]

    corpus_scores, _ = evaluate(candidates, mult_references)
    print(corpus_scores)
    # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider"
    # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...}
    return corpus_scores


if __name__ == '__main__':
    # Exactly two positional arguments are accepted: the reference file and
    # the hypothesis file (sys.argv[0] is the script name, hence length 3).
    if len(sys.argv) != 3:
        print("usage : python compute_aac_metrics.py test.ref test.hyp")
        # Non-zero status so calling scripts can detect the usage error.
        sys.exit(1)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    # compute_wer takes only the two file paths; a third argument
    # (sys.argv[3]) can never exist once the length check above has passed.
    compute_wer(ref_file, hyp_file)