Update runner
Browse files
src/run_wav2vec2_pretrain_flax.py
CHANGED
@@ -200,11 +200,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
200 |
)
|
201 |
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
# sample randomly masked indices
|
204 |
batch["mask_time_indices"] = _compute_mask_indices(
|
205 |
-
(
|
206 |
self.model.config.mask_time_prob,
|
207 |
self.model.config.mask_time_length,
|
|
|
208 |
min_masks=2,
|
209 |
)
|
210 |
|
@@ -216,7 +228,6 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
216 |
|
217 |
return batch
|
218 |
|
219 |
-
|
220 |
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
|
221 |
logging.basicConfig(
|
222 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
@@ -348,7 +359,7 @@ def main():
|
|
348 |
do_normalize=True
|
349 |
)
|
350 |
|
351 |
-
target_sampling_rate =
|
352 |
def prepare_dataset(batch):
|
353 |
# check that all files have the correct sampling rate
|
354 |
# batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
|
|
|
200 |
)
|
201 |
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
202 |
|
203 |
+
batch_size = batch["input_values"].shape[0]
|
204 |
+
|
205 |
+
if batch["attention_mask"] is not None:
|
206 |
+
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
207 |
+
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
208 |
+
|
209 |
+
# these two operations make sure that all values
|
210 |
+
# before the output lengths indices are attended to
|
211 |
+
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
|
212 |
+
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
213 |
+
|
214 |
# sample randomly masked indices
|
215 |
batch["mask_time_indices"] = _compute_mask_indices(
|
216 |
+
(batch_size, mask_indices_seq_length),
|
217 |
self.model.config.mask_time_prob,
|
218 |
self.model.config.mask_time_length,
|
219 |
+
attention_mask=attention_mask,
|
220 |
min_masks=2,
|
221 |
)
|
222 |
|
|
|
228 |
|
229 |
return batch
|
230 |
|
|
|
231 |
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
|
232 |
logging.basicConfig(
|
233 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
|
359 |
do_normalize=True
|
360 |
)
|
361 |
|
362 |
+
target_sampling_rate = feature_extractor.sampling_rate
|
363 |
def prepare_dataset(batch):
|
364 |
# check that all files have the correct sampling rate
|
365 |
# batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
|