Update eval.py
Browse files
eval.py
CHANGED
@@ -6,8 +6,8 @@ from typing import Dict
|
|
6 |
import torch
|
7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
8 |
|
9 |
-
from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
|
10 |
-
from pyctcdecode import BeamSearchDecoderCTC
|
11 |
|
12 |
|
13 |
def log_results(result: Dataset, args: Dict[str, str]):
|
@@ -16,7 +16,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
|
|
16 |
log_outputs = args.log_outputs
|
17 |
lm = "withLM" if args.use_lm else "noLM"
|
18 |
model_id = args.model_id.replace("/", "_").replace(".", "")
|
19 |
-
dataset_id = "_".join(args.dataset.split("/") + [
|
20 |
|
21 |
# load metric
|
22 |
wer = load_metric("wer")
|
@@ -112,11 +112,27 @@ def main(args):
|
|
112 |
args.device = 0 if torch.cuda.is_available() else -1
|
113 |
# asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
-
asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
|
120 |
|
121 |
# map function to decode audio
|
122 |
def map_to_pred(batch):
|
|
|
6 |
import torch
|
7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
8 |
|
9 |
+
from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
|
10 |
+
# from pyctcdecode import BeamSearchDecoderCTC
|
11 |
|
12 |
|
13 |
def log_results(result: Dataset, args: Dict[str, str]):
|
|
|
16 |
log_outputs = args.log_outputs
|
17 |
lm = "withLM" if args.use_lm else "noLM"
|
18 |
model_id = args.model_id.replace("/", "_").replace(".", "")
|
19 |
+
dataset_id = "_".join([model_id] + args.dataset.split("/") + [args.config, args.split, lm])
|
20 |
|
21 |
# load metric
|
22 |
wer = load_metric("wer")
|
|
|
112 |
args.device = 0 if torch.cuda.is_available() else -1
|
113 |
# asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
114 |
|
115 |
+
model_instance = AutoModelForCTC.from_pretrained(args.model_id)
|
116 |
+
if args.use_lm:
|
117 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
|
118 |
+
decoder = processor.decoder
|
119 |
+
else:
|
120 |
+
processor = Wav2Vec2Processor.from_pretrained(args.model_id)
|
121 |
+
decoder = None
|
122 |
+
asr = pipeline(
|
123 |
+
"automatic-speech-recognition",
|
124 |
+
model=model_instance,
|
125 |
+
tokenizer=processor.tokenizer,
|
126 |
+
feature_extractor=processor.feature_extractor,
|
127 |
+
decoder=decoder,
|
128 |
+
device=args.device
|
129 |
+
)
|
130 |
+
|
131 |
+
# feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
|
132 |
+
# feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
|
133 |
+
# feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
|
134 |
|
135 |
+
# asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
|
136 |
|
137 |
# map function to decode audio
|
138 |
def map_to_pred(batch):
|