"""Evaluates a wav2vec2 CTC model on a slice of Common Voice 13 Esperanto.

For each example, prints the normalized reference transcript, the model's
prediction, and the character error rate (CER) as a markdown table row.
"""

import dataclasses
import os
import re

import jiwer
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, Wav2Vec2ForCTC, Wav2Vec2Processor

MODEL = "xekri/wav2vec2-common_voice_13_0-eo-10"
DATA = "validation[:10]"

# Punctuation and symbols stripped from reference transcripts before scoring.
# The class ends with a fullwidth "？" (assumed, since ASCII "?" already
# appears earlier in the class).
chars_to_ignore_regex = "[-!\"'(),.:;=?_`¨«¸»ʼ‑–—‘’“”„…‹›♫？]"
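
# Maps x-system digraphs, ligatures, and stray foreign spellings in the
# reference transcripts onto the model's Esperanto output alphabet.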
chars_to_substitute = {
    "przy": "pŝe",
    "byn": "bin",
    "cx": "ĉ",
    "sx": "ŝ",
    "ﬁ": "fi",  # fi-ligature key (assumption: a plain "fi" -> "fi" entry would be a no-op)
    "ﬂ": "fl",  # fl-ligature key (assumption: a plain "fl" -> "fl" entry would be a no-op)
    "ǔ": "ŭ",
    "ñ": "nj",
    "á": "a",
    "é": "e",
    "ü": "ŭ",
    "y": "j",
    "qu": "ku",
}


def remove_special_characters(text: str) -> str:
    """Strips ignorable punctuation and lowercases the text."""
    text = re.sub(chars_to_ignore_regex, "", text)
    text = text.lower()
    return text
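
# Illustrative behavior (assumed example sentence):
#
#   >>> remove_special_characters("Ĉu vi parolas Esperanton?")
#   'ĉu vi parolas esperanton'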


def substitute_characters(text: str) -> str:
    """Applies the substitution table, then lowercases the text."""
    for k, v in chars_to_substitute.items():
        # str.replace returns a new string, so the result must be reassigned.
        text = text.replace(k, v)
    text = text.lower()
    return text
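
# Illustrative behavior (assumed example sentence):
#
#   >>> substitute_characters("cxu sxi parolas?")
#   'ĉu ŝi parolas?'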


@dataclasses.dataclass
class EvalResult:
    filename: str
    cer: float
    loss: float
    actual: str
    predicted: str

    def print(self) -> None:
        print(f"FILE {self.filename}")
        print(f"CERR {self.cer}")
        print(f"LOSS {self.loss}")
        print(f"ACTU {self.actual}")
        print(f"PRED {self.predicted}")


def evaluate(processor: Wav2Vec2Processor, model, example) -> EvalResult:
    """Evaluates a single example."""
    audio_file = example["path"]
    # Re-derive the on-disk clip location: insert the single subdirectory
    # between the recorded directory and the filename (assumes the directory
    # contains exactly one entry).
    d, n = os.path.split(audio_file)
    f = os.listdir(d)[0]
    audio_file = os.path.join(d, f, n)

    inputs = processor(
        audio=example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    )

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = logits.argmax(dim=-1)
    predict = processor.batch_decode(predicted_ids)[0]

    # Normalize the reference transcript, then compute the CTC loss and CER
    # against it.
    actual = example["sentence"]
    actual = substitute_characters(remove_special_characters(actual))
    inputs["labels"] = processor(text=actual, return_tensors="pt").input_ids
    with torch.no_grad():
        loss = model(**inputs).loss.item()
    cer = jiwer.cer(actual, predict)

    return EvalResult(os.path.basename(audio_file), cer, loss, actual, predict)


def run() -> None:
    cv13 = load_dataset(
        "mozilla-foundation/common_voice_13_0",
        "eo",
        split=DATA,
    )
    cv13 = cv13.cast_column("audio", Audio(sampling_rate=16000))

    processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)

    # Emit results as a markdown table.
    print("| Actual<br>Predicted | CER |")
    print("|:--------------------|:----|")

    for example in cv13:
        results = evaluate(processor, model, example)
        print(f"| `{results.actual}`<br>`{results.predicted}` | {results.cer} |")


if __name__ == "__main__":
    run()
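
# Usage sketch (assumption: access to the gated mozilla-foundation/
# common_voice_13_0 dataset has been granted on the Hugging Face Hub and a
# token is configured, e.g. via `huggingface-cli login`):
#
#   $ python eval_cv13_eo.py   # hypothetical filename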