|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib |
|
import os |
|
from abc import ABC, abstractmethod |
|
|
|
from fairseq import registry |
|
from omegaconf import DictConfig |
|
|
|
|
|
class BaseScorer(ABC):
    """Abstract base class for scorers fed with reference/prediction pairs.

    Pairs are accumulated through :meth:`add_string`; concrete subclasses
    must implement :meth:`score` and :meth:`result_string`.
    """

    def __init__(self, cfg):
        """Keep ``cfg`` and start with empty reference and prediction lists."""
        self.cfg = cfg
        self.ref, self.pred = [], []

    def add_string(self, ref, pred):
        """Append ``ref`` to ``self.ref`` and ``pred`` to ``self.pred``."""
        self.ref.append(ref)
        self.pred.append(pred)

    @abstractmethod
    def score(self) -> float:
        """Return the score as a float; must be provided by subclasses."""

    @abstractmethod
    def result_string(self) -> str:
        """Return the result as a string; must be provided by subclasses."""
|
|
|
|
|
# Set up the registry for the "--scoring" option (default choice: "bleu").
# The fourth value returned by setup_registry is not used here and is bound
# to the throwaway name ``_``.
(
    _build_scorer,
    register_scorer,
    SCORER_REGISTRY,
    _,
) = registry.setup_registry("--scoring", default="bleu")
|
|
|
|
|
def build_scorer(choice, tgt_dict):
    """Build the scorer selected by ``choice``.

    Args:
        choice: Either a ``DictConfig``, whose ``_name`` attribute selects
            the scorer, or a value that is compared against the scorer
            name directly.
        tgt_dict: Object whose ``pad()``, ``eos()`` and ``unk()`` results
            are passed to the BLEU configuration. It is only consulted
            when the selected scorer is ``"bleu"``.

    Returns:
        A ``bleu.Scorer`` when ``"bleu"`` is selected; otherwise the result
        of the registry builder called with the original ``choice``.
    """
    if isinstance(choice, DictConfig):
        scorer_name = choice._name
    else:
        scorer_name = choice

    if scorer_name == "bleu":
        # Imported here rather than at module top level, so the bleu module
        # is only loaded on this branch.
        from fairseq.scoring import bleu

        bleu_cfg = bleu.BleuConfig(
            pad=tgt_dict.pad(),
            eos=tgt_dict.eos(),
            unk=tgt_dict.unk(),
        )
        return bleu.Scorer(bleu_cfg)

    # Every other scorer goes through the registry, using the unmodified
    # ``choice`` (not the extracted name).
    return _build_scorer(choice)
|
|
|
|
|
|
|
# Import every ``.py`` file in this directory whose name does not start with
# an underscore as a ``fairseq.scoring`` submodule, in sorted filename order.
# NOTE(review): presumably this is what lets scorer modules register
# themselves on import -- confirm against the individual modules.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.startswith("_") or not file.endswith(".py"):
        continue
    # Module name is everything before the first ".py" in the filename.
    module = file.partition(".py")[0]
    importlib.import_module(f"fairseq.scoring.{module}")
|
|