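"""Evaluates a wav2vec2 CTC model on the Esperanto Common Voice 13 validation set.

Loads MODEL from the Hugging Face Hub, transcribes a slice of the validation
split, and prints a markdown table comparing each normalized reference
sentence with the model's prediction and its character error rate (CER).

Note: the mozilla-foundation Common Voice datasets are gated on the Hub, so
downloading the data is expected to require accepting the dataset terms and
an authenticated login.
"""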
import dataclasses
import os.path
import re

from datasets import Audio, load_dataset
import jiwer
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC, Wav2Vec2Processor

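# Model checkpoint to evaluate and the Common Voice split slice to run on.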
MODEL = "xekri/wav2vec2-common_voice_13_0-eo-10"
DATA = "validation[:10]"

# Punctuation and symbols stripped from reference transcripts.
chars_to_ignore_regex = "[-!\"'(),.:;=?_`¨«¸»ʼ‑–—‘’“”„…‹›♫？]"

# Spelling normalizations applied to reference transcripts: x-system digraphs
# (cx, sx) become the proper Esperanto letters, and typographic ligatures and
# foreign characters are mapped to plain equivalents.
chars_to_substitute = {
    "przy": "pŝe",
    "byn": "bin",
    "cx": "ĉ",
    "sx": "ŝ",
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ǔ": "ŭ",
    "ñ": "nj",
    "á": "a",
    "é": "e",
    "ü": "ŭ",
    "y": "j",
    "qu": "ku",
}


def remove_special_characters(text: str) -> str:
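    """Strips ignorable punctuation and lowercases the text."""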
    text = re.sub(chars_to_ignore_regex, "", text)
    text = text.lower()
    return text


def substitute_characters(text: str) -> str:
    """Applies the spelling substitutions in chars_to_substitute."""
    for k, v in chars_to_substitute.items():
        # str.replace returns a new string; the result must be reassigned.
        text = text.replace(k, v)
    text = text.lower()
    return text

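# Example of the two-step cleanup applied to reference sentences (illustrative):
#   remove_special_characters("Ĉu ne?") -> "ĉu ne"
#   substitute_characters("cxu ne")     -> "ĉu ne"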

@dataclasses.dataclass
class EvalResult:
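    """Evaluation output for a single audio clip."""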
    filename: str
    cer: float
    loss: float
    actual: str
    predicted: str

    def print(self) -> None:
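        """Prints the result fields, one tagged line each."""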
        print(f"FILE {self.filename}")
        print(f"CERR {self.cer}")
        print(f"LOSS {self.loss}")
        print(f"ACTU {self.actual}")
        print(f"PRED {self.predicted}")


def evaluate(processor: Wav2Vec2Processor, model, example) -> EvalResult:
    """Evaluates a single example."""
    # Only the basename of the clip path is reported in the result.
    audio_file = example["path"]

    inputs = processor(
        audio=example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    )

    # Inference only: no gradients are needed for decoding or the loss.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = logits.argmax(dim=-1)
    predict = processor.batch_decode(predicted_ids)[0]

    # Clean the reference sentence with the same rules, score the prediction
    # against it, and compute the CTC loss for the labeled input.
    actual = example["sentence"]
    actual = substitute_characters(remove_special_characters(actual))
    inputs["labels"] = processor(text=actual, return_tensors="pt").input_ids
    with torch.no_grad():
        loss = model(**inputs).loss
    cer = jiwer.cer(actual, predict)

    return EvalResult(
        os.path.basename(audio_file), cer, loss.item(), actual, predict
    )


def run() -> None:
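    """Loads the dataset and model, then prints a results table in markdown."""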
    cv13 = load_dataset(
        "mozilla-foundation/common_voice_13_0",
        "eo",
        split=DATA,
    )
    cv13 = cv13.cast_column("audio", Audio(sampling_rate=16000))

    processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)

    print("| Actual<br>Predicted | CER |")
    print("|:--------------------|:----|")

    for example in cv13:
        results = evaluate(processor, model, example)
        print(f"| `{results.actual}`<br>`{results.predicted}` | {results.cer} |")


if __name__ == "__main__":
    run()