Commit
•
bd73f2a
1
Parent(s):
114c79c
Do not assume 8 devices in JAX (#154)
Browse files- Do not assume 8 devices in JAX (e124bbdca2dab1af0cdce19d575f8043eab9341e)
Co-authored-by: Pedro Cuenca <[email protected]>
README.md
CHANGED
@@ -154,7 +154,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
154 |
|
155 |
# shard inputs and rng
|
156 |
params = replicate(params)
|
157 |
-
prng_seed = jax.random.split(prng_seed,
|
158 |
prompt_ids = shard(prompt_ids)
|
159 |
|
160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
@@ -187,7 +187,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
187 |
|
188 |
# shard inputs and rng
|
189 |
params = replicate(params)
|
190 |
-
prng_seed = jax.random.split(prng_seed,
|
191 |
prompt_ids = shard(prompt_ids)
|
192 |
|
193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
154 |
|
155 |
# shard inputs and rng
|
156 |
params = replicate(params)
|
157 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
158 |
prompt_ids = shard(prompt_ids)
|
159 |
|
160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
187 |
|
188 |
# shard inputs and rng
|
189 |
params = replicate(params)
|
190 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
191 |
prompt_ids = shard(prompt_ids)
|
192 |
|
193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|