Fix some bugs
Browse files
src/run_wav2vec2_pretrain_flax.py
CHANGED
@@ -49,7 +49,7 @@ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_in
|
|
49 |
|
50 |
logger = logging.getLogger(__name__)
|
51 |
|
52 |
-
print(f"TPU: {jax.devices()
|
53 |
|
54 |
@flax.struct.dataclass
|
55 |
class ModelArguments:
|
|
|
49 |
|
50 |
logger = logging.getLogger(__name__)
|
51 |
|
52 |
+
print(f"TPU: {jax.devices()}")
|
53 |
|
54 |
@flax.struct.dataclass
|
55 |
class ModelArguments:
|