File size: 2,019 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@dataclass
class WerScorerConfig(FairseqDataclass):
    """Configuration options for the ``wer`` scorer.

    Each option controls how reference and hypothesis strings are
    normalized/tokenized before the word error rate is computed.
    """

    # Tokenizer type, handed to EvaluationTokenizer as ``tokenizer_type``.
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none",
        metadata=dict(help="sacreBLEU tokenizer to use for evaluation"),
    )
    # Handed to EvaluationTokenizer as ``punctuation_removal``.
    wer_remove_punct: bool = field(
        default=False,
        metadata=dict(help="remove punctuation"),
    )
    # Handed to EvaluationTokenizer as ``character_tokenization``.
    wer_char_level: bool = field(
        default=False,
        metadata=dict(help="evaluate at character level"),
    )
    # Handed to EvaluationTokenizer as ``lowercase``.
    wer_lowercase: bool = field(
        default=False,
        metadata=dict(help="lowercasing"),
    )
@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    """Corpus-level word error rate (WER) scorer.

    Accumulates, across calls to ``add_string``, the edit distance between
    tokenized reference and hypothesis item sequences together with the
    total number of reference items; ``score`` reports their ratio as a
    percentage.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.reset()
        try:
            import editdistance as ed
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError("Please install editdistance to use WER scorer") from err
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        """Clear the accumulated edit distance and reference length."""
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        """Accumulate statistics for one reference/hypothesis pair.

        Both strings are passed through the configured tokenizer and then
        split on whitespace; the edit distance is computed over the
        resulting item lists, and the reference item count is added to the
        running denominator.
        """
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        """Return the current score formatted as ``"WER: <score with 2 decimals>"``."""
        return f"WER: {self.score():.2f}"

    def score(self):
        """Return the accumulated WER as a percentage (always a float).

        Returns 0.0 when no reference items have been accumulated, since the
        ratio is undefined in that case.
        """
        if self.ref_length <= 0:
            # Previously returned the int 0 here; keep the return type a float.
            return 0.0
        return 100.0 * self.distance / self.ref_length
|