Update README.md
Browse files
README.md
CHANGED
@@ -53,6 +53,8 @@ This weights here are intended to be used with the 🧨 Diffusers library. If yo
|
|
53 |
|
54 |
We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
|
55 |
|
|
|
|
|
56 |
```bash
|
57 |
pip install --upgrade diffusers transformers scipy
|
58 |
```
|
@@ -119,6 +121,75 @@ with autocast("cuda"):
|
|
119 |
image.save("astronaut_rides_horse.png")
|
120 |
```
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
# Uses
|
123 |
|
124 |
## Direct Use
|
|
|
53 |
|
54 |
We recommend using [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run Stable Diffusion.
|
55 |
|
56 |
+
### PyTorch
|
57 |
+
|
58 |
```bash
|
59 |
pip install --upgrade diffusers transformers scipy
|
60 |
```
|
|
|
121 |
image.save("astronaut_rides_horse.png")
|
122 |
```
|
123 |
|
124 |
+
### JAX/Flax
|
125 |
+
|
126 |
+
To use StableDiffusion on TPUs and GPUs for faster inference you can leverage JAX/Flax.
|
127 |
+
|
128 |
+
Running the pipeline with default PNDMScheduler
|
129 |
+
|
130 |
+
```python
|
131 |
+
import jax
|
132 |
+
import numpy as np
|
133 |
+
from flax.jax_utils import replicate
|
134 |
+
from flax.training.common_utils import shard
|
135 |
+
|
136 |
+
from diffusers import FlaxStableDiffusionPipeline
|
137 |
+
|
138 |
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
139 |
+
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
|
140 |
+
)
|
141 |
+
|
142 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
143 |
+
|
144 |
+
prng_seed = jax.random.PRNGKey(0)
|
145 |
+
num_inference_steps = 50
|
146 |
+
|
147 |
+
num_samples = jax.device_count()
|
148 |
+
prompt = num_samples * [prompt]
|
149 |
+
prompt_ids = pipeline.prepare_inputs(prompt)
|
150 |
+
|
151 |
+
# shard inputs and rng
|
152 |
+
params = replicate(params)
|
153 |
+
prng_seed = jax.random.split(prng_seed, 8)
|
154 |
+
prompt_ids = shard(prompt_ids)
|
155 |
+
|
156 |
+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
157 |
+
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
158 |
+
```
|
159 |
+
|
160 |
+
**Note**:
|
161 |
+
If you are limited by TPU memory, please make sure to load the `FlaxStableDiffusionPipeline` in `bfloat16` precision instead of the default `float32` precision as done above. You can do so by telling diffusers to load the weights from "bf16" branch.
|
162 |
+
|
163 |
+
```python
|
164 |
+
import jax
|
165 |
+
import numpy as np
|
166 |
+
from flax.jax_utils import replicate
|
167 |
+
from flax.training.common_utils import shard
|
168 |
+
|
169 |
+
from diffusers import FlaxStableDiffusionPipeline
|
170 |
+
|
171 |
+
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
172 |
+
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
|
173 |
+
)
|
174 |
+
|
175 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
176 |
+
|
177 |
+
prng_seed = jax.random.PRNGKey(0)
|
178 |
+
num_inference_steps = 50
|
179 |
+
|
180 |
+
num_samples = jax.device_count()
|
181 |
+
prompt = num_samples * [prompt]
|
182 |
+
prompt_ids = pipeline.prepare_inputs(prompt)
|
183 |
+
|
184 |
+
# shard inputs and rng
|
185 |
+
params = replicate(params)
|
186 |
+
prng_seed = jax.random.split(prng_seed, 8)
|
187 |
+
prompt_ids = shard(prompt_ids)
|
188 |
+
|
189 |
+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
190 |
+
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
191 |
+
```
|
192 |
+
|
193 |
# Uses
|
194 |
|
195 |
## Direct Use
|