lnyan commited on
Commit
cbf7ffd
·
1 Parent(s): da9606c
Files changed (1) hide show
  1. flux/math.py +1 -1
flux/math.py CHANGED
@@ -11,7 +11,7 @@ def check_tpu():
11
  return any('TPU' in d.device_kind for d in jax.devices())
12
 
13
  # from torch import Tensor
14
- if check_tpu():
15
  from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
16
  # q, # [batch_size, num_heads, q_seq_len, d_model]
17
  # k, # [batch_size, num_heads, kv_seq_len, d_model]
 
11
  return any('TPU' in d.device_kind for d in jax.devices())
12
 
13
  # from torch import Tensor
14
+ if False:
15
  from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
16
  # q, # [batch_size, num_heads, q_seq_len, d_model]
17
  # k, # [batch_size, num_heads, kv_seq_len, d_model]