patrickvonplaten's picture
up
eaa092f
raw
history blame
2.67 kB
#!/usr/bin/env python3
import sys
import torch
import re
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2Processor, AutoModelForCTC
from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
import torchaudio.functional as F
import torch
# Command-line arguments:
#   argv[1]: 1 to decode with the language model (Wav2Vec2ProcessorWithLM),
#            0 for plain greedy (argmax) CTC decoding.
#   argv[2]: number of streamed test samples to evaluate.
if len(sys.argv) < 3:
    # Fail with a usage hint instead of a bare IndexError.
    sys.exit(f"usage: {sys.argv[0]} <use_lm: 0|1> <eval_size>")
do_lm = bool(int(sys.argv[1]))
eval_size = int(sys.argv[2])
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model and processor files are loaded from the current directory.
model_path = "./"
wer = load_metric("wer")
cer = load_metric("cer")
# Choose the processor class up front: the LM-aware processor when decoding
# with a language model, otherwise the plain CTC processor. Both read their
# files from model_path.
processor_cls = Wav2Vec2ProcessorWithLM if do_lm else Wav2Vec2Processor
processor = processor_cls.from_pretrained(model_path)

model = AutoModelForCTC.from_pretrained(model_path)
model = model.to(device)

# streaming=True gives an iterable dataset; samples are pulled one at a time
# from ds_iter during evaluation.
ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds_iter = iter(ds)
# Per-sample normalized reference transcripts and model outputs, scored at the end.
references, predictions = [], []

# Punctuation and symbols stripped from reference sentences before scoring.
CHARS_TO_IGNORE = [
    ",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
    "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
    "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
    "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
    "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]

# A single regex character class matching any of the ignored characters;
# re.escape keeps "]", "\\", "^" and friends literal inside the brackets.
chars_to_ignore_regex = "[" + re.escape("".join(CHARS_TO_IGNORE)) + "]"
# Evaluate on at most the first `eval_size` streamed test samples. Pairing
# with range() stops cleanly when the split runs out, instead of a bare
# next() raising StopIteration and aborting the run.
target_sampling_rate = 16_000
for _, sample in zip(range(eval_size), ds_iter):
    audio = sample["audio"]
    # Resample from the clip's own rate (read from the sample rather than
    # assuming 48 kHz) to the rate the processor is called with.
    resampled_audio = F.resample(
        torch.tensor(audio["array"]), audio["sampling_rate"], target_sampling_rate
    ).numpy()
    input_values = processor(
        resampled_audio, return_tensors="pt", sampling_rate=target_sampling_rate
    ).input_values

    with torch.no_grad():
        logits = model(input_values.to(device)).logits.cpu()

    if do_lm:
        # Released versions of Wav2Vec2ProcessorWithLM.batch_decode take numpy
        # logits and return an output object whose `.text` holds the decoded
        # strings; fall back to the raw return value if it is already a list.
        decoded = processor.batch_decode(logits.numpy())
        output_str = getattr(decoded, "text", decoded)[0].lower()
    else:
        # Greedy CTC decoding: most likely token id per frame.
        pred_ids = torch.argmax(logits, dim=-1)
        output_str = processor.batch_decode(pred_ids)[0].lower()

    # Normalize the reference: strip ignored characters, lowercase, and
    # collapse runs of whitespace into single spaces.
    ref_str = re.sub(chars_to_ignore_regex, "", sample["sentence"]).lower()
    ref_str = " ".join(ref_str.split())

    print(f"Pred: {output_str} | Target: {ref_str}")
    print(50 * "=")

    references.append(ref_str)
    predictions.append(output_str)
# Report corpus-level error rates as percentages. Scoring an empty run
# (e.g. eval_size == 0) is meaningless and fails inside the metric, so exit
# with a clear message instead.
if not references:
    sys.exit("No samples were evaluated; nothing to score.")
for metric_name, metric in (("WER", wer), ("CER", cer)):
    score = metric.compute(predictions=predictions, references=references)
    print(f"{metric_name}: {score * 100}")