add model
events.out.tfevents.1626217020.t1v-n-278acf21-w-0.60949.3.v2 → events.out.tfevents.1626420112.t1v-n-278acf21-w-0.561381.3.v2
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:d8ee56f09af3471c76f8991f10060e5a16d7121ab1cf0cba9d7959393bb5c223
+size 220634
events.out.tfevents.1626448850.t1v-n-278acf21-w-0.590260.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fbe385b41508eae766e3ae9763a6bf8a20b0dad2a36c5058b526b6884a8433a
+size 662195
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd33994b480ef0a93c7821a12df82c34656dc30539b623c1fb2050b1ba03be19
+size 190539834
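The three files above are Git LFS pointer files; the TensorBoard event logs and the ~190 MB Flax weight file themselves live in LFS storage. A minimal sketch (not part of this commit) of loading the checkpoint added here, assuming a local clone of this repository that also contains a `config.json` not shown in this diff:

```python
# Minimal sketch: load the flax_model.msgpack checkpoint added above.
# Assumes a local clone that also contains a config.json; the model class
# matches the one used by the pretraining script below.
from transformers import FlaxWav2Vec2ForPreTraining

model = FlaxWav2Vec2ForPreTraining.from_pretrained("./")  # reads flax_model.msgpack
print(model.config.model_type)  # -> "wav2vec2"
```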
src/run_persian.sh
CHANGED
@@ -19,7 +19,7 @@ export PER_DEVICE_EVAL_BATCH_SIZE=8
 #export GRADIENT_ACCUMULATION_STEPS=2
 export NUM_TRAIN_EPOCHS=5.0
 export LEARNING_RATE=5e-4
-export WARMUP_STEPS=
+export WARMUP_STEPS=2000
 export LOGGING_STEPS=500
 #export EVAL_STEPS=2500
 #export SAVE_STEPS=2500
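This change sets `WARMUP_STEPS=2000`. In the Hugging Face Flax example scripts, a value like this typically drives a linear warmup of the learning rate followed by a linear decay. A hypothetical sketch of that schedule with optax; `total_steps` and the variable names are illustrative, not taken from this repository:

```python
# Hypothetical sketch of how WARMUP_STEPS is commonly consumed in the Flax
# examples: linear warmup to LEARNING_RATE, then linear decay to zero.
# total_steps is a placeholder; the real value depends on dataset size,
# batch size, and NUM_TRAIN_EPOCHS.
import optax

learning_rate = 5e-4   # LEARNING_RATE from run_persian.sh
warmup_steps = 2000    # WARMUP_STEPS set by this commit
total_steps = 100_000  # placeholder value

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=total_steps - warmup_steps,
)
lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)
```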
src/run_wav2vec2_pretrain_flax.py
CHANGED
@@ -26,6 +26,7 @@ from typing import Dict, List, Optional, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset
+from datasets import load_from_disk
 from tqdm import tqdm
 
 import flax
@@ -370,29 +371,33 @@ def main():
         return batch
 
     # load audio files into numpy arrays
-    vectorized_datasets = datasets.map(
-        prepare_dataset,
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=datasets["train"].column_names
-    )
+    # vectorized_datasets = datasets.map(
+    #     prepare_dataset,
+    #     num_proc=data_args.preprocessing_num_workers,
+    #     remove_columns=datasets["train"].column_names
+    # )
 
     # filter audio files that are too long
-    vectorized_datasets = vectorized_datasets.filter(
-        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
-    )
+    # vectorized_datasets = vectorized_datasets.filter(
+    #     lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
+    # )
 
-    def normalize(batch):
-        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
+    # def normalize(batch):
+    #     return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
 
     # normalize and transform to `BatchFeatures`
-    vectorized_datasets = vectorized_datasets.map(
-        normalize,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        load_from_cache_file=not data_args.overwrite_cache,
-        remove_columns=vectorized_datasets["train"].column_names,
-    )
-    vectorized_datasets.save_to_disk(model_args.cache_dir)
+    # vectorized_datasets = vectorized_datasets.map(
+    #     normalize,
+    #     batched=True,
+    #     num_proc=data_args.preprocessing_num_workers,
+    #     load_from_cache_file=not data_args.overwrite_cache,
+    #     remove_columns=vectorized_datasets["train"].column_names,
+    # )
+    # vectorized_datasets.save_to_disk(model_args.cache_dir)
+
+    logger.info(f"Loading from {model_args.cache_dir}")
+    vectorized_datasets = load_from_disk(model_args.cache_dir)
+    logger.info(f"vectorized_datasets: {vectorized_datasets}")
 
     # pretraining is only supported for "newer" stable layer norm architecture
     # apply_spec_augment has to be True, mask_feature_prob has to be 0.0