patrickvonplaten committed on
Commit 4f20072
1 Parent(s): 0b5509c
Files changed (1)
  1. run_ngram_wav2vec2.py +65 -0
run_ngram_wav2vec2.py ADDED
@@ -0,0 +1,65 @@
+ #!/usr/bin/env python3
+ import sys
+ import torch
+ import re
+ from datasets import load_dataset, load_metric
+ from transformers import Wav2Vec2Processor, AutoModelForCTC
+ from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
+ import torchaudio.functional as F
+
+ # decide via the command line whether the n-gram LM should be used for decoding
+ do_lm = bool(int(sys.argv[1]))
+ eval_size = int(sys.argv[2])
+
+ model_path = "./"
+
+ wer = load_metric("wer", download_mode="force_redownload")
+ cer = load_metric("cer", download_mode="force_redownload")
+
+ # load model and processor
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path) if do_lm else Wav2Vec2Processor.from_pretrained(model_path)
+ model = AutoModelForCTC.from_pretrained(model_path)
+
+ ds = load_dataset("common_voice", "es", split="test", streaming=True)
+ ds_iter = iter(ds)
+
+
+ references = []
+ predictions = []
+
+
+ CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
+                    "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
+                    "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
+                    "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
+                    "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
+ chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
+
+
+ for _ in range(eval_size):
+     sample = next(ds_iter)
+     resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
+
+     input_values = processor(resampled_audio, return_tensors="pt", sampling_rate=16_000).input_values
+     with torch.no_grad():
+         logits = model(input_values).logits.cpu()
+
+     if do_lm:
+         output_str = processor.batch_decode(logits)[0].lower()
+     else:
+         pred_ids = torch.argmax(logits, dim=-1)
+         output_str = processor.batch_decode(pred_ids)[0].lower()
+
+     ref_str = re.sub(chars_to_ignore_regex, "", sample["sentence"]).lower()
+
+     # collapse runs of whitespace into a single space
+     ref_str = " ".join(ref_str.split())
+
+     print(f"Pred: {output_str} | Target: {ref_str}")
+     print(50 * "=")
+
+     references.append(ref_str)
+     predictions.append(output_str)
+
+ print(f"WER: {wer.compute(predictions=predictions, references=references) * 100}")
+ print(f"CER: {cer.compute(predictions=predictions, references=references) * 100}")
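For reference, the script reads its two positional arguments as a 0/1 flag for LM decoding and the number of streamed Common Voice test samples, and model_path is "./", so it is meant to be run from the root of the model repository. A minimal invocation sketch (the sample count of 2000 is only an illustrative value, not taken from the commit):

    python run_ngram_wav2vec2.py 0 2000   # greedy CTC decoding without the LM
    python run_ngram_wav2vec2.py 1 2000   # beam-search decoding with the n-gram LM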