valhalla committed on
Commit
f15bc76
1 Parent(s): 52b46db

Update README.md

  1. README.md +71 -0
README.md CHANGED
@@ -53,6 +53,8 @@ This weights here are intended to be used with the 🧨 Diffusers library. If yo
 
  We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
 
  ```bash
  pip install --upgrade diffusers transformers scipy
  ```
@@ -119,6 +121,75 @@ with autocast("cuda"):
  image.save("astronaut_rides_horse.png")
  ```
 
  # Uses
 
  ## Direct Use
 
 
  We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
 
+ ### PyTorch
+
  ```bash
  pip install --upgrade diffusers transformers scipy
  ```
 
  image.save("astronaut_rides_horse.png")
  ```
 
+ ### JAX/Flax
+
+ To run Stable Diffusion on TPUs and GPUs with faster inference, you can use JAX/Flax.
+
+ Running the pipeline with the default `PNDMScheduler`:
+
+ ```python
+ import jax
+ import numpy as np
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+
+ from diffusers import FlaxStableDiffusionPipeline
+
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+     "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
+ )
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+
+ prng_seed = jax.random.PRNGKey(0)
+ num_inference_steps = 50
+
+ # generate one image per available device
+ num_samples = jax.device_count()
+ prompt = num_samples * [prompt]
+ prompt_ids = pipeline.prepare_inputs(prompt)
+
+ # shard inputs and rng
+ params = replicate(params)
+ # one PRNG key per device (was hard-coded to 8)
+ prng_seed = jax.random.split(prng_seed, num_samples)
+ prompt_ids = shard(prompt_ids)
+
+ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+ # merge the device and per-device batch axes before converting to PIL images
+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+ ```
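
The `shard` call and the final `reshape` in the example above both rearrange the leading batch axis. The following is a minimal NumPy sketch of that shape bookkeeping, assuming `shard` splits the first axis into `(devices, per_device, ...)`. The helper `shard_like`, the device count of 8, the token length of 77, and the 64x64 image size are illustrative stand-ins, not values read from the pipeline.

```python
import numpy as np


def shard_like(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Hypothetical stand-in for `shard`: split axis 0 into (num_devices, per_device, ...)."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the device count")
    return batch.reshape((num_devices, -1) + batch.shape[1:])


num_devices = 8   # assumed device count (e.g. one TPU host)
seq_len = 77      # assumed tokenized prompt length

# One prompt per device, as in the example above.
prompt_ids = np.zeros((num_devices, seq_len), dtype=np.int32)
sharded = shard_like(prompt_ids, num_devices)
print(sharded.shape)  # leading axis is now (devices, per_device)

# Stand-in for the pipeline output: (devices, per_device, height, width, channels).
images = np.zeros((num_devices, 1, 64, 64, 3), dtype=np.float32)
num_samples = num_devices

# Same reshape as the last line of the example: merge the two leading axes.
flat = images.reshape((num_samples,) + images.shape[-3:])
print(flat.shape)
```

Because the sharded layout needs one slice per device, the number of prompts has to be a multiple of the device count. The example guarantees this by repeating the prompt `jax.device_count()` times.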
+
+ **Note**:
+ If you are limited by TPU memory, load the `FlaxStableDiffusionPipeline` weights in `bfloat16` precision instead of the default `float32` precision used above. To do so, tell diffusers to load the weights from the `bf16` branch.
+
+ ```python
+ import jax
+ import numpy as np
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+
+ from diffusers import FlaxStableDiffusionPipeline
+
+ # load the half-precision weights from the "bf16" branch
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+     "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
+ )
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+
+ prng_seed = jax.random.PRNGKey(0)
+ num_inference_steps = 50
+
+ # generate one image per available device
+ num_samples = jax.device_count()
+ prompt = num_samples * [prompt]
+ prompt_ids = pipeline.prepare_inputs(prompt)
+
+ # shard inputs and rng
+ params = replicate(params)
+ # one PRNG key per device (was hard-coded to 8)
+ prng_seed = jax.random.split(prng_seed, num_samples)
+ prompt_ids = shard(prompt_ids)
+
+ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+ # merge the device and per-device batch axes before converting to PIL images
+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+ ```
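
To make the memory note concrete, here is a rough back-of-the-envelope sketch of why `bfloat16` weights help. The only facts it relies on are that `float32` stores 4 bytes per parameter and `bfloat16` stores 2. The parameter count `NUM_PARAMS` is a hypothetical round number chosen for illustration, not the actual size of this model, and `weight_bytes` is an illustrative helper.

```python
def weight_bytes(num_params: int, bytes_per_param: int) -> int:
    """Memory needed to hold the weights alone (no activations or compiler buffers)."""
    return num_params * bytes_per_param


NUM_PARAMS = 1_000_000_000  # hypothetical parameter count, for illustration only

fp32_bytes = weight_bytes(NUM_PARAMS, 4)  # float32: 4 bytes per parameter
bf16_bytes = weight_bytes(NUM_PARAMS, 2)  # bfloat16: 2 bytes per parameter

print(f"float32 weights:  {fp32_bytes / 2**30:.2f} GiB")
print(f"bfloat16 weights: {bf16_bytes / 2**30:.2f} GiB")
```

Since `replicate(params)` places a full copy of the weights on every device, the halving applies per device rather than once overall.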
+
  # Uses
 
  ## Direct Use