Update eval.py
Browse files
eval.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Dict
|
|
6 |
import torch
|
7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
8 |
|
9 |
-
from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer, pipeline
|
10 |
|
11 |
|
12 |
def log_results(result: Dataset, args: Dict[str, str]):
|
@@ -81,7 +81,6 @@ def normalize_text(text: str) -> str:
|
|
81 |
def main(args):
|
82 |
# load dataset
|
83 |
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
|
84 |
-
#dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True).filter(lambda entry: re.search("nb-nn", entry["sentence_language_code"], flags=re.IGNORECASE))
|
85 |
|
86 |
# for testing: only process the first ten examples
|
87 |
# dataset = dataset.select(range(10))
|
@@ -96,12 +95,7 @@ def main(args):
|
|
96 |
# load eval pipeline
|
97 |
if args.device is None:
|
98 |
args.device = 0 if torch.cuda.is_available() else -1
|
99 |
-
asr = pipeline("automatic-speech-recognition",
|
100 |
-
model=AutoModel.from_pretrained(args.model_id),
|
101 |
-
tokenizer=AutoTokenizer.from_pretrained(args.model_id),
|
102 |
-
feature_extractor=feature_extractor,
|
103 |
-
device=args.device
|
104 |
-
)
|
105 |
|
106 |
# map function to decode audio
|
107 |
def map_to_pred(batch):
|
|
|
6 |
import torch
|
7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
8 |
|
9 |
+
from transformers import AutoFeatureExtractor, pipeline
|
10 |
|
11 |
|
12 |
def log_results(result: Dataset, args: Dict[str, str]):
|
|
|
81 |
def main(args):
|
82 |
# load dataset
|
83 |
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
|
|
|
84 |
|
85 |
# for testing: only process the first ten examples
|
86 |
# dataset = dataset.select(range(10))
|
|
|
95 |
# load eval pipeline
|
96 |
if args.device is None:
|
97 |
args.device = 0 if torch.cuda.is_available() else -1
|
98 |
+
asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
# map function to decode audio
|
101 |
def map_to_pred(batch):
|