Rolv-Arild committed on
Commit
9a7a0bd
1 Parent(s): 7fcdd24

Add NST+NPSC dataset script

Browse files
Files changed (2) hide show
  1. run.sh +1 -3
  2. run_speech_recognition_ctc.py +100 -67
run.sh CHANGED
@@ -1,8 +1,6 @@
1
  WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_speech_recognition_ctc.py \
2
- --dataset_name="NbAiLab/NST" \
3
  --model_name_or_path="KBLab/wav2vec2-large-voxrex" \
4
- --hub_model_id="NbAiLab/wav2vec2-large-voxrex-nst" \
5
- --dataset_config_name="no-close" \
6
  --output_dir="./" \
7
  --overwrite_output_dir \
8
  --num_train_epochs="15" \
 
1
  WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_speech_recognition_ctc.py \
 
2
  --model_name_or_path="KBLab/wav2vec2-large-voxrex" \
3
+ --hub_model_id="NbAiLab/wav2vec2-large-voxrex-npsc-nst" \
 
4
  --output_dir="./" \
5
  --overwrite_output_dir \
6
  --num_train_epochs="15" \
run_speech_recognition_ctc.py CHANGED
@@ -47,13 +47,11 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
47
  from transformers.utils import check_min_version
48
  from transformers.utils.versions import require_version
49
 
50
-
51
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
52
  check_min_version("4.16.0.dev0")
53
 
54
  require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
55
 
56
-
57
  logger = logging.getLogger(__name__)
58
 
59
 
@@ -102,8 +100,8 @@ class ModelArguments:
102
  default=0.05,
103
  metadata={
104
  "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
105
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
106
- "vectors will be masked along the time axis."
107
  },
108
  )
109
  mask_time_length: int = field(
@@ -114,7 +112,7 @@ class ModelArguments:
114
  default=0.0,
115
  metadata={
116
  "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
117
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
118
  },
119
  )
120
  mask_feature_length: int = field(
@@ -129,6 +127,7 @@ class ModelArguments:
129
  default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
130
  )
131
 
 
132
  @dataclass
133
  class DataTrainingArguments:
134
  """
@@ -176,14 +175,14 @@ class DataTrainingArguments:
176
  default=None,
177
  metadata={
178
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
179
- "value if set."
180
  },
181
  )
182
  max_eval_samples: Optional[int] = field(
183
  default=None,
184
  metadata={
185
  "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
186
- "value if set."
187
  },
188
  )
189
  chars_to_ignore: Optional[List[str]] = list_field(
@@ -207,16 +206,16 @@ class DataTrainingArguments:
207
  default=False,
208
  metadata={
209
  "help": "Whether to only do data preprocessing and skip training. "
210
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
211
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
212
- "so that the cached datasets can consequently be loaded in distributed training"
213
  },
214
  )
215
  use_auth_token: bool = field(
216
  default=False,
217
  metadata={
218
  "help": "If :obj:`True`, will use the token generated when running"
219
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
220
  },
221
  )
222
  unk_token: str = field(
@@ -235,9 +234,9 @@ class DataTrainingArguments:
235
  default=None,
236
  metadata={
237
  "help": "The target language that should be used be"
238
- " passed to the tokenizer for tokenization. Note that"
239
- " this is only relevant if the model classifies the"
240
- " input audio to a sequence of phoneme sequences."
241
  },
242
  )
243
 
@@ -303,10 +302,10 @@ class DataCollatorCTCWithPadding:
303
 
304
 
305
  def create_vocabulary_from_data(
306
- datasets: DatasetDict,
307
- word_delimiter_token: Optional[str] = None,
308
- unk_token: Optional[str] = None,
309
- pad_token: Optional[str] = None,
310
  ):
311
  # Given training and test labels create vocabulary
312
  def extract_all_chars(batch):
@@ -344,6 +343,85 @@ def create_vocabulary_from_data(
344
  return vocab_dict
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  def main():
348
  # See all possible arguments in src/transformers/training_args.py
349
  # or by passing the --help flag to this script.
@@ -393,45 +471,10 @@ def main():
393
  # Set seed before initializing model.
394
  set_seed(training_args.seed)
395
 
396
- # Pre-processing dataset
397
- import re
398
-
399
- def map_dataset(entry):
400
- text = entry["text"].lower()
401
- text = text.replace("(...Vær stille under dette opptaket...)", "")
402
- text = re.sub('[áàâ]', 'a', text)
403
- text = re.sub('[ä]', 'æ', text)
404
- text = re.sub('[éèëê]', 'e', text)
405
- text = re.sub('[íìïî]', 'i', text)
406
- text = re.sub('[óòöô]', 'o', text)
407
- text = re.sub('[ö]', 'ø', text)
408
- text = re.sub('[ç]', 'c', text)
409
- text = re.sub('[úùüû]', 'u', text)
410
- # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
411
- text = re.sub('\s+', ' ', text)
412
- return {"text": text}
413
-
414
-
415
- def filter_dataset(entry):
416
- if not (len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3):
417
- return False # Too short
418
- if re.match(entry["type"], "pIW|CA"):
419
- return False # Spelling out words
420
- return True
421
-
422
  # 1. First, let's load the dataset
423
- raw_datasets = DatasetDict()
424
 
425
  if training_args.do_train:
426
- raw_datasets["train"] = load_dataset(
427
- data_args.dataset_name,
428
- data_args.dataset_config_name,
429
- split=data_args.train_split_name,
430
- use_auth_token=data_args.use_auth_token,
431
- ).shuffle()
432
- raw_datasets["train"] = raw_datasets["train"].filter(filter_dataset)
433
- raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
434
-
435
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
436
  raise ValueError(
437
  f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
@@ -450,28 +493,18 @@ def main():
450
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
451
 
452
  if training_args.do_eval:
453
- raw_datasets["eval"] = load_dataset(
454
- data_args.dataset_name,
455
- data_args.dataset_config_name,
456
- split=data_args.eval_split_name,
457
- use_auth_token=data_args.use_auth_token,
458
- ).shuffle()
459
- raw_datasets["eval"] = raw_datasets["eval"].filter(filter_dataset)
460
- raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
461
-
462
  if data_args.max_eval_samples is not None:
463
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
464
 
465
-
466
  # 2. We remove some special characters from the datasets
467
  # that make training complicated and do not help in transcribing the speech
468
  # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
469
  # that could be easily picked up by the model
470
- #chars_to_ignore_regex = (
471
  # f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
472
- #)
473
  chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/]'
474
-
475
  text_column_name = data_args.text_column_name
476
 
477
  def remove_special_characters(batch):
 
47
  from transformers.utils import check_min_version
48
  from transformers.utils.versions import require_version
49
 
 
50
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
51
  check_min_version("4.16.0.dev0")
52
 
53
  require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
54
 
 
55
  logger = logging.getLogger(__name__)
56
 
57
 
 
100
  default=0.05,
101
  metadata={
102
  "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
103
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
104
+ "vectors will be masked along the time axis."
105
  },
106
  )
107
  mask_time_length: int = field(
 
112
  default=0.0,
113
  metadata={
114
  "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
115
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
116
  },
117
  )
118
  mask_feature_length: int = field(
 
127
  default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
128
  )
129
 
130
+
131
  @dataclass
132
  class DataTrainingArguments:
133
  """
 
175
  default=None,
176
  metadata={
177
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
178
+ "value if set."
179
  },
180
  )
181
  max_eval_samples: Optional[int] = field(
182
  default=None,
183
  metadata={
184
  "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
185
+ "value if set."
186
  },
187
  )
188
  chars_to_ignore: Optional[List[str]] = list_field(
 
206
  default=False,
207
  metadata={
208
  "help": "Whether to only do data preprocessing and skip training. "
209
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
210
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
211
+ "so that the cached datasets can consequently be loaded in distributed training"
212
  },
213
  )
214
  use_auth_token: bool = field(
215
  default=False,
216
  metadata={
217
  "help": "If :obj:`True`, will use the token generated when running"
218
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
219
  },
220
  )
221
  unk_token: str = field(
 
234
  default=None,
235
  metadata={
236
  "help": "The target language that should be used be"
237
+ " passed to the tokenizer for tokenization. Note that"
238
+ " this is only relevant if the model classifies the"
239
+ " input audio to a sequence of phoneme sequences."
240
  },
241
  )
242
 
 
302
 
303
 
304
  def create_vocabulary_from_data(
305
+ datasets: DatasetDict,
306
+ word_delimiter_token: Optional[str] = None,
307
+ unk_token: Optional[str] = None,
308
+ pad_token: Optional[str] = None,
309
  ):
310
  # Given training and test labels create vocabulary
311
  def extract_all_chars(batch):
 
343
  return vocab_dict
344
 
345
 
346
def make_dataset(seed=42):
    """Build the combined NST + NPSC dataset.

    Loads ``NbAiLab/NST`` (config ``no-close``) and ``NbAiLab/NPSC`` (config
    ``16K_mp3``), carves a validation split out of the NST training data using
    the same train/validation ratio as NPSC, filters and normalises the
    transcriptions of both corpora, drops every column except ``text`` and
    ``audio``, and interleaves the two corpora split by split with sampling
    probabilities proportional to their sizes.

    Args:
        seed: Seed passed to the NST train/validation split, to the shuffles
            and to ``interleave_datasets`` so the result is reproducible.

    Returns:
        A ``datasets.DatasetDict`` with the keys ``"train"``, ``"validation"``
        and ``"test"``.
        NOTE(review): callers that index ``raw_datasets["eval"]`` will not find
        that key here - confirm which split name the rest of the script uses.

    Raises:
        ValueError: If an NPSC transcription still contains ``<`` after the
            known ``<...>`` tags have been replaced (i.e. an unexpected tag).
    """
    # Kept as a function-local import, as in the original block.
    import re

    # Accented characters folded onto the target alphabet. "ö" maps to "ø"
    # (and "ä" to "æ"); the original listed "ö" in the "o" group as well,
    # which made its dedicated "ö" -> "ø" substitution unreachable.
    folded = {
        "a": "áàâ",
        "æ": "ä",
        "e": "éèëê",
        "i": "íìïî",
        "o": "óòô",
        "ø": "ö",
        "c": "ç",
        "u": "úùüû",
    }
    char_table = str.maketrans(
        {source: target for target, sources in folded.items() for source in sources}
    )

    def has_usable_length(entry):
        # At most one character per 320 audio samples (presumably the model's
        # frame stride - TODO confirm) and at least 3 non-blank characters.
        text = entry["text"]
        return len(text) <= len(entry["audio"]["array"]) // 320 and len(text.strip()) >= 3

    def map_nst(entry):
        text = entry["text"].lower()
        # The text is already lower-cased, so the marker must be lower-case
        # too; the original searched for a capital "V" and never matched.
        text = text.replace("(...vær stille under dette opptaket...)", "")
        text = text.translate(char_table)
        text = re.sub(r"\s+", " ", text)  # collapse whitespace runs
        return {"text": text}

    def filter_nst(entry):
        if not has_usable_length(entry):
            return False  # Too short / too long for the audio
        # Drop recordings where words are spelled out. re.match takes the
        # pattern first and the string second; the original had them swapped.
        return re.match("pIW|CA", entry["type"]) is None

    def filter_npsc(entry):
        if not has_usable_length(entry):
            return False  # Too short / too long for the audio
        # False if there are digits in the text
        return re.search(r"\d", entry["text"]) is None

    def map_npsc(entry):
        text = entry["text"].lower().translate(char_table)
        # Each whitespace character becomes a plain space (runs are not
        # collapsed here, unlike in map_nst).
        text = re.sub(r"\s", " ", text)
        # Replace the known markup tags with letter sequences.
        for tag, replacement in (
            ("<ee>", "eee"),
            ("<qq>", "qqq"),
            ("<mm>", "mmm"),
            ("<inaudible>", "xxx"),
        ):
            text = text.replace(tag, replacement)
        if "<" in text:
            # Any remaining "<" means there is a tag we do not handle.
            raise ValueError(text)
        return {"text": text}

    nst = datasets.load_dataset("NbAiLab/NST", "no-close")
    npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
    # TODO NST_hesitate

    # Use same train/val ratio as NPSC when carving a validation set out of NST.
    train_fraction = len(npsc["train"]) / (len(npsc["train"]) + len(npsc["validation"]))
    nst_parts = nst["train"].train_test_split(train_size=train_fraction, seed=seed)
    nst["train"] = nst_parts["train"]
    nst["validation"] = nst_parts["test"]

    # Filtering runs on the raw text, normalisation afterwards.
    nst = nst.filter(filter_nst).map(map_nst).shuffle(seed=seed)
    npsc = npsc.filter(filter_npsc).map(map_npsc).shuffle(seed=seed)

    # Both corpora need identical columns before they can be interleaved.
    kept_columns = ("text", "audio")
    npsc_base = npsc.remove_columns(
        [col for col in npsc["train"].column_names if col not in kept_columns]
    )
    nst_base = nst.remove_columns(
        [col for col in nst["train"].column_names if col not in kept_columns]
    )

    combined = {}
    for split_name in ("train", "validation", "test"):
        # Weight by number of examples
        sizes = np.array([len(nst_base[split_name]), len(npsc_base[split_name])])
        probs = (sizes / sizes.sum()).tolist()
        combined[split_name] = datasets.interleave_datasets(
            [nst_base[split_name], npsc_base[split_name]], probabilities=probs, seed=seed
        )

    return datasets.DatasetDict(**combined)
423
+
424
+
425
  def main():
426
  # See all possible arguments in src/transformers/training_args.py
427
  # or by passing the --help flag to this script.
 
471
  # Set seed before initializing model.
472
  set_seed(training_args.seed)
473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  # 1. First, let's load the dataset
475
+ raw_datasets = make_dataset(seed=training_args.seed)
476
 
477
  if training_args.do_train:
 
 
 
 
 
 
 
 
 
478
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
479
  raise ValueError(
480
  f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
 
493
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
494
 
495
  if training_args.do_eval:
 
 
 
 
 
 
 
 
 
496
  if data_args.max_eval_samples is not None:
497
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
498
 
 
499
  # 2. We remove some special characters from the datasets
500
  # that make training complicated and do not help in transcribing the speech
501
  # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
502
  # that could be easily picked up by the model
503
+ # chars_to_ignore_regex = (
504
  # f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
505
+ # )
506
  chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/]'
507
+
508
  text_column_name = data_args.text_column_name
509
 
510
  def remove_special_characters(batch):