Spaces:
Runtime error
Runtime error
Update
Browse files- flux/math.py +1 -1
flux/math.py
CHANGED
@@ -11,7 +11,7 @@ def check_tpu():
|
|
11 |
return any('TPU' in d.device_kind for d in jax.devices())
|
12 |
|
13 |
# from torch import Tensor
|
14 |
-
if
|
15 |
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
|
16 |
# q, # [batch_size, num_heads, q_seq_len, d_model]
|
17 |
# k, # [batch_size, num_heads, kv_seq_len, d_model]
|
|
|
11 |
return any('TPU' in d.device_kind for d in jax.devices())
|
12 |
|
13 |
# from torch import Tensor
|
14 |
+
if False:
|
15 |
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
|
16 |
# q, # [batch_size, num_heads, q_seq_len, d_model]
|
17 |
# k, # [batch_size, num_heads, kv_seq_len, d_model]
|