patrickvonplaten
commited on
Commit
•
4f20072
1
Parent(s):
0b5509c
up
Browse files- run_ngram_wav2vec2.py +65 -0
run_ngram_wav2vec2.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
+
from datasets import load_dataset, load_metric
|
6 |
+
from transformers import Wav2Vec2Processor, AutoModelForCTC
|
7 |
+
from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
8 |
+
import torchaudio.functional as F
|
9 |
+
|
10 |
+
# decide if lm should be used for decoding or not via command line
|
11 |
+
do_lm = bool(int(sys.argv[1]))
|
12 |
+
eval_size = int(sys.argv[2])
|
13 |
+
|
14 |
+
model_path = "./"
|
15 |
+
|
16 |
+
wer = load_metric("wer", download_mode="force_redownload")
|
17 |
+
cer = load_metric("cer", download_mode="force_redownload")
|
18 |
+
|
19 |
+
# load model and processor
|
20 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path) if do_lm else Wav2Vec2Processor.from_pretrained(model_path)
|
21 |
+
model = AutoModelForCTC.from_pretrained(model_path)
|
22 |
+
|
23 |
+
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
24 |
+
ds_iter = iter(ds)
|
25 |
+
|
26 |
+
|
27 |
+
references = []
|
28 |
+
predictions = []
|
29 |
+
|
30 |
+
|
31 |
+
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
|
32 |
+
"؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
|
33 |
+
"{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
|
34 |
+
"、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
|
35 |
+
"『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
|
36 |
+
chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
|
37 |
+
|
38 |
+
|
39 |
+
for _ in range(eval_size):
|
40 |
+
sample = next(ds_iter)
|
41 |
+
resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
|
42 |
+
|
43 |
+
input_values = processor(resampled_audio, return_tensors="pt", sampling_rate=16_000).input_values
|
44 |
+
with torch.no_grad():
|
45 |
+
logits = model(input_values).logits.cpu()
|
46 |
+
|
47 |
+
if do_lm:
|
48 |
+
output_str = processor.batch_decode(logits)[0].lower()
|
49 |
+
else:
|
50 |
+
pred_ids = torch.argmax(logits, dim=-1)
|
51 |
+
output_str = processor.batch_decode(pred_ids)[0].lower()
|
52 |
+
|
53 |
+
ref_str = re.sub(chars_to_ignore_regex, "", sample["sentence"]).lower()
|
54 |
+
|
55 |
+
# replace long empty strings by a single string
|
56 |
+
ref_str = " ".join(ref_str.split())
|
57 |
+
|
58 |
+
print(f"Pred: {output_str} | Target: {ref_str}")
|
59 |
+
print(50 * "=")
|
60 |
+
|
61 |
+
references.append(ref_str)
|
62 |
+
predictions.append(output_str)
|
63 |
+
|
64 |
+
print(f"WER: {wer.compute(predictions=predictions, references=references) * 100}")
|
65 |
+
print(f"CER: {cer.compute(predictions=predictions, references=references) * 100}")
|