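"""Evaluate a local wav2vec2 CTC model on the Common Voice Spanish test split.

Usage: python <script> <do_lm: 0|1> <eval_size>

Prints each prediction/target pair followed by the overall WER and CER.
"""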
import itertools
import re
import sys

import torch
import torchaudio.functional as F
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2Processor, AutoModelForCTC
from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM

# Command-line configuration.
#   argv[1]: integer flag; any non-zero value selects the LM-backed processor.
#   argv[2]: number of samples to evaluate.
# Fail with a usage message instead of a bare IndexError when arguments
# are missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <do_lm: 0|1> <eval_size>")

do_lm = bool(int(sys.argv[1]))
eval_size = int(sys.argv[2])

# The processor and model are both loaded from this path (current directory).
model_path = "./"

# Metrics used for the final report. `force_redownload` asks `datasets` to
# fetch each metric again rather than reuse a cached copy.
wer = load_metric("wer", download_mode="force_redownload")
cer = load_metric("cer", download_mode="force_redownload")

# Pick the processor class that matches the requested decoding mode; both
# variants are loaded from the same model directory.
if do_lm:
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path)
else:
    processor = Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path)

# Stream the Spanish Common Voice test split and keep a single iterator over
# it for the evaluation loop to consume.
ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds_iter = iter(ds)

# Parallel lists: references[i] is the normalized target for predictions[i].
references = []
predictions = []

# Punctuation and symbols that are stripped from reference sentences before
# scoring. Repeated entries are harmless: the list is only ever joined into a
# single regex character class below.
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
                   "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
                   "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
                   "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
                   "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]

# Character class matching any single ignored character. re.escape keeps
# "]", "\\" and "^" literal inside the brackets.
chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"

# Sampling rate the audio is resampled to before it is passed to the processor.
TARGET_SAMPLING_RATE = 16_000

# Evaluate at most `eval_size` samples. islice ends the loop cleanly when the
# stream holds fewer samples, instead of letting next() raise StopIteration.
# max(..., 0) keeps a negative eval_size a no-op, as range() treated it.
for sample in itertools.islice(ds_iter, max(eval_size, 0)):
    audio = sample["audio"]

    # Resample from the rate recorded on the sample itself rather than
    # assuming every clip is 48 kHz.
    resampled_audio = F.resample(
        torch.tensor(audio["array"]), audio["sampling_rate"], TARGET_SAMPLING_RATE
    ).numpy()

    input_values = processor(
        resampled_audio, return_tensors="pt", sampling_rate=TARGET_SAMPLING_RATE
    ).input_values
    with torch.no_grad():
        logits = model(input_values).logits.cpu()

    if do_lm:
        # The LM processor decodes directly from the logits.
        # NOTE(review): confirm against the installed transformers version
        # that indexing this result with [0] yields a string; the documented
        # form is `batch_decode(logits.numpy()).text[0]`.
        output_str = processor.batch_decode(logits)[0].lower()
    else:
        # Greedy decoding: most likely token id per frame, then decode.
        pred_ids = torch.argmax(logits, dim=-1)
        output_str = processor.batch_decode(pred_ids)[0].lower()

    # Normalize the reference: strip ignored characters, lowercase, and
    # collapse runs of whitespace into single spaces.
    ref_str = re.sub(chars_to_ignore_regex, "", sample["sentence"]).lower()
    ref_str = " ".join(ref_str.split())

    print(f"Pred: {output_str} | Target: {ref_str}")
    print(50 * "=")

    references.append(ref_str)
    predictions.append(output_str)

# Report the scores over all evaluated samples, scaled by 100.
# With nothing evaluated (eval_size of 0 or an empty stream) the metrics are
# undefined and compute() would presumably fail on empty lists, so say so
# explicitly instead.
if references:
    wer_score = wer.compute(predictions=predictions, references=references)
    print(f"WER: {wer_score * 100}")
    cer_score = cer.compute(predictions=predictions, references=references)
    print(f"CER: {cer_score * 100}")
else:
    print("No samples were evaluated; WER and CER are undefined.")