versae committed on
Commit
8290317
1 Parent(s): 1bc73ac

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +23 -7
eval.py CHANGED
@@ -6,8 +6,8 @@ from typing import Dict
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
- from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
10
- from pyctcdecode import BeamSearchDecoderCTC
11
 
12
 
13
  def log_results(result: Dataset, args: Dict[str, str]):
@@ -16,7 +16,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
16
  log_outputs = args.log_outputs
17
  lm = "withLM" if args.use_lm else "noLM"
18
  model_id = args.model_id.replace("/", "_").replace(".", "")
19
- dataset_id = "_".join(args.dataset.split("/") + [model_id, args.config, args.split, lm])
20
 
21
  # load metric
22
  wer = load_metric("wer")
@@ -112,11 +112,27 @@ def main(args):
112
  args.device = 0 if torch.cuda.is_available() else -1
113
  # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
114
 
115
- feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
116
- feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
117
- feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
120
 
121
  # map function to decode audio
122
  def map_to_pred(batch):
 
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
+ from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
10
+ # from pyctcdecode import BeamSearchDecoderCTC
11
 
12
 
13
  def log_results(result: Dataset, args: Dict[str, str]):
 
16
  log_outputs = args.log_outputs
17
  lm = "withLM" if args.use_lm else "noLM"
18
  model_id = args.model_id.replace("/", "_").replace(".", "")
19
+ dataset_id = "_".join([model_id] + args.dataset.split("/") + [args.config, args.split, lm])
20
 
21
  # load metric
22
  wer = load_metric("wer")
 
112
  args.device = 0 if torch.cuda.is_available() else -1
113
  # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
114
 
115
+ model_instance = AutoModelForCTC.from_pretrained(args.model_id)
116
+ if args.use_lm:
117
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
118
+ decoder = processor.decoder
119
+ else:
120
+ processor = Wav2Vec2Processor.from_pretrained(args.model_id)
121
+ decoder = None
122
+ asr = pipeline(
123
+ "automatic-speech-recognition",
124
+ model=model_instance,
125
+ tokenizer=processor.tokenizer,
126
+ feature_extractor=processor.feature_extractor,
127
+ decoder=decoder,
128
+ device=args.device
129
+ )
130
+
131
+ # feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
132
+ # feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
133
+ # feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
134
 
135
+ # asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
136
 
137
  # map function to decode audio
138
  def map_to_pred(batch):