Fix some bugs
Browse files- src/run_persian.sh +1 -1
- src/run_wav2vec2_pretrain_flax.py +1 -0
src/run_persian.sh
CHANGED
@@ -4,7 +4,7 @@ export LC_ALL=C.UTF-8
|
|
4 |
export LANG=C.UTF-8
|
5 |
|
6 |
export OUTPUT_DIR=/home/m3hrdadfi/code/wav2vec2-base-persian
|
7 |
-
export
|
8 |
export MODEL_NAME_OR_PATH=/home/m3hrdadfi/code/wav2vec2-base-persian
|
9 |
|
10 |
|
|
|
4 |
export LANG=C.UTF-8
|
5 |
|
6 |
export OUTPUT_DIR=/home/m3hrdadfi/code/wav2vec2-base-persian
|
7 |
+
export CACHE_DIR=/home/m3hrdadfi/data_cache/
|
8 |
export MODEL_NAME_OR_PATH=/home/m3hrdadfi/code/wav2vec2-base-persian
|
9 |
|
10 |
|
src/run_wav2vec2_pretrain_flax.py
CHANGED
@@ -49,6 +49,7 @@ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_in
|
|
49 |
|
50 |
logger = logging.getLogger(__name__)
|
51 |
|
|
|
52 |
|
53 |
@flax.struct.dataclass
|
54 |
class ModelArguments:
|
|
|
49 |
|
50 |
logger = logging.getLogger(__name__)
|
51 |
|
52 |
+
print(f"TPU: {jax.devices())")
|
53 |
|
54 |
@flax.struct.dataclass
|
55 |
class ModelArguments:
|