File size: 2,019 Bytes
d5175d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


@dataclass
class WerScorerConfig(FairseqDataclass):
    """Configuration options for the ``wer`` scorer."""

    # Tokenizer type; the annotation restricts it to the choices exposed by
    # EvaluationTokenizer.ALL_TOKENIZER_TYPES.
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none",
        metadata={"help": "sacreBLEU tokenizer to use for evaluation"},
    )
    # Whether punctuation is stripped before scoring.
    wer_remove_punct: bool = field(
        default=False,
        metadata={"help": "remove punctuation"},
    )
    # Whether scoring is done on characters instead of words.
    wer_char_level: bool = field(
        default=False,
        metadata={"help": "evaluate at character level"},
    )
    # Whether text is lowercased before scoring.
    wer_lowercase: bool = field(
        default=False,
        metadata={"help": "lowercasing"},
    )


@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    """Corpus-level word error rate (WER) scorer.

    Accumulates the token-level edit distance and the number of reference
    tokens over all ``add_string`` calls, and reports
    ``100 * distance / ref_length``.
    """

    def __init__(self, cfg):
        """Set up the accumulators, the edit-distance backend and the tokenizer.

        Args:
            cfg: scorer configuration providing ``wer_tokenizer``,
                ``wer_lowercase``, ``wer_remove_punct`` and
                ``wer_char_level`` (see ``WerScorerConfig``).

        Raises:
            ImportError: if the ``editdistance`` package cannot be imported.
        """
        super().__init__(cfg)
        self.reset()
        # Imported here rather than at module level so the dependency is only
        # needed when a WER scorer is actually constructed.
        try:
            import editdistance as ed
        except ImportError as err:
            # Chain the original exception so the underlying cause stays
            # visible in the traceback.
            raise ImportError(
                "Please install editdistance to use WER scorer"
            ) from err
        self.ed = ed
        # NOTE(review): self.cfg is presumably stored by BaseScorer.__init__
        # -- confirm against the base class.
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        """Clear the accumulated edit distance and reference length."""
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        """Accumulate statistics for one reference/prediction pair.

        Both strings are passed through the configured tokenizer and split on
        whitespace; the edit distance between the two token lists and the
        number of reference tokens are added to the running totals.

        Args:
            ref: reference text.
            pred: predicted (hypothesis) text.
        """
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        """Return the current score formatted as ``"WER: <score>"``."""
        return f"WER: {self.score():.2f}"

    def score(self):
        """Return the WER as a percentage.

        Returns:
            ``100.0 * distance / ref_length`` as a float, or ``0.0`` when no
            reference tokens have been accumulated (avoids division by zero).
        """
        if self.ref_length > 0:
            return 100.0 * self.distance / self.ref_length
        # Float on every path so callers always receive the same type.
        return 0.0