"""Evaluates a fine-tuned wav2vec2 CTC model on Esperanto Common Voice 13."""

import dataclasses
import os
import re
from datasets import Audio, load_dataset
import jiwer
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC, Wav2Vec2Processor
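
# Checkpoint on the Hugging Face Hub and the dataset slice to evaluate.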
MODEL = "xekri/wav2vec2-common_voice_13_0-eo-10"
DATA = "validation[:10]"
chars_to_ignore_regex = "[-!\"'(),.:;=?_`¨«¸»ʼ‑–—‘’“”„…‹›♫?]"
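
# Orthography fixes: x-system digraphs, ligature characters, and foreign
# letters mapped to Esperanto spelling.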
chars_to_substitute = {
"przy": "pŝe",
"byn": "bin",
"cx": "ĉ",
"sx": "ŝ",
"fi": "fi",
"fl": "fl",
"ǔ": "ŭ",
"ñ": "nj",
"á": "a",
"é": "e",
"ü": "ŭ",
"y": "j",
"qu": "ku",
}


def remove_special_characters(text: str) -> str:
    """Strips ignorable punctuation and lowercases the text."""
    text = re.sub(chars_to_ignore_regex, "", text)
    text = text.lower()
    return text


def substitute_characters(text: str) -> str:
    """Applies the substitution table, then lowercases the text."""
    for k, v in chars_to_substitute.items():
        # str.replace returns a new string, so reassign the result.
        text = text.replace(k, v)
    text = text.lower()
    return text


@dataclasses.dataclass
class EvalResult:
    """Evaluation output for a single utterance."""

    filename: str
    cer: float
    loss: float
    actual: str
    predicted: str

    def print(self) -> None:
        print(f"FILE {self.filename}")
        print(f"CERR {self.cer}")
        print(f"LOSS {self.loss}")
        print(f"ACTU {self.actual}")
        print(f"PRED {self.predicted}")


def evaluate(
    processor: Wav2Vec2Processor, model: Wav2Vec2ForCTC, example
) -> EvalResult:
    """Evaluates a single example."""
    # Rebuild the on-disk path, assuming the clip sits in the first
    # subdirectory of its recorded parent; only the basename is reported.
    audio_file = example["path"]
    d, n = os.path.split(audio_file)
    f = os.listdir(d)[0]
    audio_file = os.path.join(d, f, n)
    inputs = processor(
        audio=example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    )
    # Greedy CTC decoding: argmax over the vocabulary at every frame.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = logits.argmax(dim=-1)
    predict = processor.batch_decode(predicted_ids)[0]
    actual = example["sentence"]
    actual = substitute_characters(remove_special_characters(actual))
    # A second forward pass with labels attached yields the CTC loss.
    inputs["labels"] = processor(text=actual, return_tensors="pt").input_ids
    with torch.no_grad():
        loss = model(**inputs).loss.item()
    cer = jiwer.cer(actual, predict)
    return EvalResult(os.path.basename(audio_file), cer, loss, actual, predict)


def run() -> None:
    cv13 = load_dataset(
        "mozilla-foundation/common_voice_13_0",
        "eo",
        split=DATA,
    )
    # Resample to the 16 kHz rate the model expects.
    cv13 = cv13.cast_column("audio", Audio(sampling_rate=16000))
    processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)
    # Emit a two-column Markdown table; <br> keeps both transcripts in one cell.
    print("| Actual<br>Predicted | CER |")
    print("|:--------------------|:----|")
    for example in cv13:
        results = evaluate(processor, model, example)
        print(f"| `{results.actual}`<br>`{results.predicted}` | {results.cer} |")


if __name__ == "__main__":
    run()