Update for the revision
Browse files
events.out.tfevents.1626448850.t1v-n-278acf21-w-0.590260.3.v2
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0f8f2848f118433d3ae3412ed5ed0df7242cdf899879357f922313aeaf0b7b5d
|
3 |
+
size 809333
|
flax_model.msgpack
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:dd33994b480ef0a93c7821a12df82c34656dc30539b623c1fb2050b1ba03be19
|
3 |
-
size 190539834
|
|
|
|
|
|
|
|
src/run_wav2vec2_pretrain_flax.py
CHANGED
@@ -160,7 +160,6 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
160 |
"""
|
161 |
Data collator that will dynamically pad the inputs received and prepare masked indices
|
162 |
for self-supervised pretraining.
|
163 |
-
|
164 |
Args:
|
165 |
model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
|
166 |
The Wav2Vec2 model used for pretraining. The data collator needs to have access
|
@@ -203,6 +202,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
203 |
|
204 |
batch_size = batch["input_values"].shape[0]
|
205 |
|
|
|
206 |
if batch["attention_mask"] is not None:
|
207 |
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
208 |
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
@@ -225,9 +225,11 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
225 |
batch["sampled_negative_indices"] = _sample_negative_indices(
|
226 |
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
227 |
self.model.config.num_negatives,
|
|
|
228 |
)
|
229 |
|
230 |
return batch
|
|
|
231 |
|
232 |
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
|
233 |
logging.basicConfig(
|
|
|
160 |
"""
|
161 |
Data collator that will dynamically pad the inputs received and prepare masked indices
|
162 |
for self-supervised pretraining.
|
|
|
163 |
Args:
|
164 |
model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
|
165 |
The Wav2Vec2 model used for pretraining. The data collator needs to have access
|
|
|
202 |
|
203 |
batch_size = batch["input_values"].shape[0]
|
204 |
|
205 |
+
attention_mask = None
|
206 |
if batch["attention_mask"] is not None:
|
207 |
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
208 |
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
|
|
225 |
batch["sampled_negative_indices"] = _sample_negative_indices(
|
226 |
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
227 |
self.model.config.num_negatives,
|
228 |
+
attention_mask=attention_mask,
|
229 |
)
|
230 |
|
231 |
return batch
|
232 |
+
|
233 |
|
234 |
def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
|
235 |
logging.basicConfig(
|