File size: 2,646 Bytes
4f20072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
import sys
import torch
import re
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2Processor, AutoModelForCTC
from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
import torchaudio.functional as F

# Command-line arguments:
#   argv[1]: integer flag; non-zero decodes with the language model via
#            Wav2Vec2ProcessorWithLM, 0 uses plain argmax (greedy) decoding.
#   argv[2]: number of test samples to evaluate.
# Fail with a usage message instead of an IndexError traceback when arguments
# are missing (extra arguments are still ignored, as before).
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <use_lm: 0|1> <eval_size>")
do_lm = bool(int(sys.argv[1]))
eval_size = int(sys.argv[2])

# Model and processor (including the LM files when do_lm is set) are loaded
# from the current directory.
model_path = "./"

# force_redownload re-fetches the metric scripts on every run instead of
# reusing a cached copy.
wer = load_metric("wer", download_mode="force_redownload")
cer = load_metric("cer", download_mode="force_redownload")

# Load the LM-aware processor only when LM decoding was requested.
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path) if do_lm else Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path)

# Spanish Common Voice test split, opened with streaming=True so samples are
# pulled lazily from the iterator below instead of loading the split up front.
ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds_iter = iter(ds)


# Index-aligned accumulators: normalized reference transcripts and the
# model's transcriptions for the same samples.
references, predictions = [], []


# Characters removed from the reference transcripts before scoring (applied
# with re.sub in the evaluation loop). All entries are joined into a single
# regex character class; re.escape makes every character match literally.
# NOTE(review): several entries look duplicated (e.g. ";", "~", ",", "{", "(",
# ":", "!", "?"). They may be distinct full-width code points that render
# alike, so verify the exact code points before de-duplicating. Duplicates
# are harmless inside a character class, and the two-character entry '""'
# contributes nothing beyond '"'.
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
                   "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
                   "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
                   "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
                   "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"


# Transcribe up to `eval_size` test samples, collecting normalized references
# and predictions for the WER/CER computation.
for _ in range(eval_size):
    # Stop cleanly (instead of raising StopIteration) when the split has fewer
    # than `eval_size` samples.
    sample = next(ds_iter, None)
    if sample is None:
        print(f"Dataset exhausted after {len(references)} samples.")
        break

    audio = sample["audio"]
    # The processor is fed 16 kHz audio (sampling_rate below). Resample from
    # the clip's own rate rather than assuming a fixed 48 kHz input.
    resampled_audio = F.resample(torch.tensor(audio["array"]), audio["sampling_rate"], 16_000).numpy()

    input_values = processor(resampled_audio, return_tensors="pt", sampling_rate=16_000).input_values
    with torch.no_grad():
        logits = model(input_values).logits.cpu()

    if do_lm:
        # Wav2Vec2ProcessorWithLM.batch_decode expects numpy logits and returns
        # an output object whose `.text` holds one string per batch item.
        # Indexing the object itself with [0] would yield the whole text list.
        output_str = processor.batch_decode(logits.numpy()).text[0].lower()
    else:
        # Greedy decoding: highest-scoring token id per frame, converted to
        # text by the processor.
        pred_ids = torch.argmax(logits, dim=-1)
        output_str = processor.batch_decode(pred_ids)[0].lower()

    # Normalize the reference: drop ignored characters and lowercase.
    ref_str = re.sub(chars_to_ignore_regex, "", sample["sentence"]).lower()

    # Collapse runs of whitespace (left behind by removed characters) into
    # single spaces and trim the ends.
    ref_str = " ".join(ref_str.split())

    print(f"Pred: {output_str} | Target: {ref_str}")
    print(50 * "=")

    references.append(ref_str)
    predictions.append(output_str)

# Report error rates over all collected samples as percentages. Skip the
# computation when nothing was evaluated (eval_size <= 0 or an empty split):
# the metrics have no input to score and would fail.
if predictions:
    print(f"WER: {wer.compute(predictions=predictions, references=references) * 100}")
    print(f"CER: {cer.compute(predictions=predictions, references=references) * 100}")
else:
    print("No samples were evaluated; WER/CER are undefined.")