m3hrdadfi commited on
Commit
5696c96
1 Parent(s): da9f194

Fix some bugs

Browse files
Files changed (1) hide show
  1. src/run_wav2vec2_pretrain_flax.py +1 -1
src/run_wav2vec2_pretrain_flax.py CHANGED
@@ -49,7 +49,7 @@ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_in
49
 
50
  logger = logging.getLogger(__name__)
51
 
52
- print(f"TPU: {jax.devices())")
53
 
54
  @flax.struct.dataclass
55
  class ModelArguments:
 
49
 
50
  logger = logging.getLogger(__name__)
51
 
52
+ print(f"TPU: {jax.devices()}")
53
 
54
  @flax.struct.dataclass
55
  class ModelArguments: