Spaces · Runtime error

Update app.py

app.py CHANGED
@@ -2,45 +2,195 @@ import gradio as gr
 
 import torch
 from torch import autocast
-from diffusers import StableDiffusionPipeline
-from datasets import load_dataset
-from PIL import Image
-import re
 
-
+import gc
+import io
+import math
+import sys
+
+from PIL import Image, ImageOps
+import requests
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.notebook import tqdm
+
+import numpy as np
+
+from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults, classifier_defaults, create_classifier
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+from einops import rearrange
+from math import log2, sqrt
+
+import argparse
+import pickle
+
+import os
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+def fetch(url_or_path):
+    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+        r = requests.get(url_or_path)
+        r.raise_for_status()
+        fd = io.BytesIO()
+        fd.write(r.content)
+        fd.seek(0)
+        return fd
+    return open(url_or_path, 'rb')
+
 device = "cuda"
 
-#
-[… old lines 14-43 not captured in this view …]
+#model_state_dict = torch.load('diffusion.pt', map_location='cpu')
+model_state_dict = torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/default/diffusion-1.4.pt'), map_location='cpu')
+
+model_params = {
+    'attention_resolutions': '32,16,8',
+    'class_cond': False,
+    'diffusion_steps': 1000,
+    'rescale_timesteps': True,
+    'timestep_respacing': 'ddim100',
+    'image_size': 32,
+    'learn_sigma': False,
+    'noise_schedule': 'linear',
+    'num_channels': 320,
+    'num_heads': 8,
+    'num_res_blocks': 2,
+    'resblock_updown': False,
+    'use_fp16': True,
+    'use_scale_shift_norm': False,
+    'clip_embed_dim': None,
+    'image_condition': False,
+    'super_res_condition': False,
+}
+
+model_config = model_and_diffusion_defaults()
+model_config.update(model_params)
+
+# Load models
+model, diffusion = create_model_and_diffusion(**model_config)
+model.load_state_dict(model_state_dict, strict=True)
+model.requires_grad_(False).eval().to(device)
+
+if model_config['use_fp16']:
+    model.convert_to_fp16()
+else:
+    model.convert_to_fp32()
+
+def set_requires_grad(model, value):
+    for param in model.parameters():
+        param.requires_grad = value
+
+# vae
+kl_config = OmegaConf.load('kl.yaml')
+kl_sd = torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/default/kl-1.4.pt'), map_location="cpu")
+
+ldm = instantiate_from_config(kl_config.model)
+ldm.load_state_dict(kl_sd, strict=True)
+
+ldm.to(device)
+ldm.eval()
+ldm.requires_grad_(False)
+set_requires_grad(ldm, False)
+
+# clip
+clip_version = 'openai/clip-vit-large-patch14'
+clip_tokenizer = CLIPTokenizer.from_pretrained(clip_version)
+clip_transformer = CLIPTextModel.from_pretrained(clip_version)
+clip_transformer.eval().requires_grad_(False).to(device)
+
+# classifier
+# load classifier
+classifier_config = classifier_defaults()
+classifier_config['classifier_width'] = 128
+classifier_config['classifier_depth'] = 4
+classifier_config['classifier_attention_resolutions'] = '64,32,16,8'
+
+classifier_photo = create_classifier(**classifier_config)
+classifier_photo.load_state_dict(
+    torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/classifier_photo/model060000.pt'), map_location="cpu")
+)
+classifier_photo.to(device)
+classifier_photo.convert_to_fp16()
+classifier_photo.eval()
+
+classifier_art = create_classifier(**classifier_config)
+classifier_art.load_state_dict(
+    torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/classifier_art/model110000.pt'), map_location="cpu")
+)
+classifier_art.to(device)
+classifier_art.convert_to_fp16()
+classifier_art.eval()
+
+def infer(prompt, style, scale, classifier_scale, seed):
+    torch.manual_seed(seed)
+
+    # clip context
+    text = clip_tokenizer([prompt], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+    text_blank = clip_tokenizer([''], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+    text_tokens = text["input_ids"].to(device)
+    text_blank_tokens = text_blank["input_ids"].to(device)
+
+    text_emb = clip_transformer(input_ids=text_tokens).last_hidden_state
+    text_emb_blank = clip_transformer(input_ids=text_blank_tokens).last_hidden_state
+
+    kwargs = {
+        "context": torch.cat([text_emb, text_emb_blank], dim=0).half(),
+        "clip_embed": None,
+        "image_embed": None,
+    }
+
+    def model_fn(x_t, ts, **kwargs):
+        half = x_t[: len(x_t) // 2]
+        combined = torch.cat([half, half], dim=0)
+        model_out = model(combined, ts, **kwargs)
+        eps, rest = model_out[:, :3], model_out[:, 3:]
+        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
+        eps = torch.cat([half_eps, half_eps], dim=0)
+        return torch.cat([eps, rest], dim=1)
+
+    def cond_fn(x, t, context=None, clip_embed=None, image_embed=None):
+        with torch.enable_grad():
+            x_in = x[:x.shape[0]//2].detach().requires_grad_(True)
+            if style == 'photo':
+                logits = classifier_photo(x_in, t)
+            elif style == 'digital art':
+                logits = classifier_art(x_in, t)
+            else:
+                return 0
+
+            log_probs = F.log_softmax(logits, dim=-1)
+            selected = log_probs[range(len(logits)), torch.ones(x_in.shape[0], dtype=torch.long)]
+            return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
+
+    samples = diffusion.ddim_sample_loop_progressive(
+        model_fn,
+        (2, 4, 64, 64),
+        clip_denoised=False,
+        model_kwargs=kwargs,
+        cond_fn=cond_fn,
+        device=device,
+        progress=True,
+        init_image=None,
+        skip_timesteps=0,
+    )
+
+    for j, sample in enumerate(samples):
+        pass
+
+    emb = sample['pred_xstart'][0]
+    emb /= 0.18215
+    im = emb.unsqueeze(0)
+    im = ldm.decode(im)
+
+    im = TF.to_pil_image(im.squeeze(0).add(1).div(2).clamp(0, 1))
+
+    return [im]
 
 css = """
 .gradio-container {
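The heart of this commit is the new infer(): instead of calling a diffusers pipeline, it runs guided_diffusion's DDIM sampler with two guidance signals layered on top. model_fn applies classifier-free guidance, blending the prompt-conditioned and unconditional noise predictions with the CFG scale, while cond_fn applies classifier guidance, adding the gradient of the chosen style classifier's log-probability scaled by the classifier scale. The sketch below is not code from the Space; it only mirrors that arithmetic with dummy stand-in tensors.

import torch
import torch.nn.functional as F

scale = 7.5              # "CFG Scale" slider in the UI
classifier_scale = 100   # "Classifier Scale" slider in the UI

# classifier-free guidance, as in model_fn(): blend the conditional and
# unconditional noise predictions, pushing away from the unconditional one
cond_eps = torch.randn(1, 4, 64, 64)    # stand-in for eps predicted with the prompt
uncond_eps = torch.randn(1, 4, 64, 64)  # stand-in for eps predicted with the empty prompt
guided_eps = uncond_eps + scale * (cond_eps - uncond_eps)

# classifier guidance, as in cond_fn(): gradient of the chosen class's
# log-probability with respect to the noisy latent, scaled by classifier_scale
x_in = torch.randn(1, 4, 64, 64, requires_grad=True)
logits = torch.randn(1, 2) * x_in.mean()   # differentiable stand-in for classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), torch.ones(1, dtype=torch.long)]
grad = torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale

Raising the CFG scale pushes samples to follow the prompt more closely, and raising the classifier scale pushes them harder toward the selected style, typically at some cost to fidelity when either is set very high.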
@@ -97,8 +247,7 @@ css = """
     padding: 2px 8px;
     border-radius: 14px !important;
 }
-#advanced-options {
-    display: none;
+#advanced-options, #style-options {
     margin-bottom: 20px;
 }
 .footer {
@@ -213,19 +362,11 @@ with block:
                 <rect x="23" y="69" width="23" height="23" fill="black"></rect>
               </svg>
               <h1 style="font-weight: 900; margin-bottom: 7px;">
-                Stable Diffusion
+                Classifier Guided Stable Diffusion
              </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 94%">
-              […]
-              images from text.<br>For faster generation and forthcoming API
-              access you can try
-              <a
-                href="http://beta.dreamstudio.ai/"
-                style="text-decoration: underline;"
-                target="_blank"
-              >DreamStudio Beta</a
-              >
+              a custom version of stable diffusion with classifier guidance
            </p>
          </div>
        """
@@ -252,13 +393,18 @@ with block:
             label="Generated images", show_label=False, elem_id="gallery"
         ).style(grid=[2], height="auto")
 
-        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+        #advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
 
+        with gr.Row(elem_id="style-options"):
+            style = gr.Radio(["none","photo","digital art","anime"], label="Image style")
         with gr.Row(elem_id="advanced-options"):
-            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
-            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
+            #samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
+            #steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
             scale = gr.Slider(
-                label=" […]
+                label="CFG Scale", minimum=0, maximum=50, value=7.5, step=0.1
+            )
+            classifier_scale = gr.Slider(
+                label="Classifier Scale", minimum=0, maximum=1000, value=100, step=1
             )
             seed = gr.Slider(
                 label="Seed",
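The images that reach the gallery come from the tail of infer() in the first hunk: the sampler's final pred_xstart latent is divided by the Stable Diffusion latent scaling factor 0.18215, decoded by the KL autoencoder loaded from kl.yaml, and mapped from [-1, 1] into a PIL image. A minimal sketch of that step; the autoencoder argument stands in for the Space's ldm object:

import torch
from torchvision.transforms import functional as TF

def latent_to_pil(pred_xstart: torch.Tensor, autoencoder):
    emb = pred_xstart / 0.18215                  # undo Stable Diffusion's latent scaling
    img = autoencoder.decode(emb.unsqueeze(0))   # decoded image, roughly in [-1, 1]
    return TF.to_pil_image(img.squeeze(0).add(1).div(2).clamp(0, 1))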
@@ -268,22 +414,13 @@ with block:
                 randomize=True,
             )
 
-        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, […]
+        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery, cache_examples=True)
         ex.dataset.headers = [""]
 
 
-        text.submit(infer, inputs=[text, […]
-        btn.click(infer, inputs=[text, […]
-        advanced_button.click(
-            None,
-            [],
-            text,
-            _js="""
-            () => {
-                const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
-                options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
-            }""",
-        )
+        text.submit(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)
+        btn.click(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)
+
         gr.HTML(
             """
             <div class="footer">
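Taken together, the last two hunks swap the old advanced-options toggle for a style radio plus two guidance sliders and pass all of them straight into infer(). The self-contained sketch below shows the resulting wiring; the prompt box, generate button, and seed range are outside the diff, so their settings here are assumptions:

import gradio as gr

def infer(prompt, style, scale, classifier_scale, seed):
    return []  # stand-in for the real sampler

with gr.Blocks() as demo:
    text = gr.Textbox(label="Enter your prompt")          # assumed, not shown in the diff
    btn = gr.Button("Generate image")                      # assumed, not shown in the diff
    gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery")
    with gr.Row(elem_id="style-options"):
        style = gr.Radio(["none", "photo", "digital art", "anime"], label="Image style")
    with gr.Row(elem_id="advanced-options"):
        scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, value=7.5, step=0.1)
        classifier_scale = gr.Slider(label="Classifier Scale", minimum=0, maximum=1000, value=100, step=1)
        seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)  # assumed range

    text.submit(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)
    btn.click(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)

# demo.launch()  # uncomment to serve the UI

Every trigger (the cached Examples, Enter in the textbox, the button) calls infer with the same five inputs, so the style choice and both guidance scales always travel together.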