Jack000 committed on
Commit 4a3ecd0 · 1 Parent(s): 742d897

Update app.py
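
This commit swaps the diffusers StableDiffusionPipeline demo for a glid-3-xl-stable setup: a guided_diffusion UNet and DDIM sampler, a KL VAE decoder loaded through ldm, a CLIP text encoder from transformers, and two noise-aware style classifiers ("photo" and "digital art"). The new infer() combines classifier-free guidance (model_fn blends the conditional and unconditional noise predictions using the CFG scale) with classifier guidance (cond_fn returns the gradient of the selected classifier's log-probability, scaled by classifier_scale). A minimal sketch of the two update rules, assuming the standard guided_diffusion DDIM conditioning that shifts the predicted noise by -sqrt(1 - alpha_bar_t) times the cond_fn gradient; the function names below are illustrative, not from app.py:

def classifier_free_eps(cond_eps, uncond_eps, scale):
    # classifier-free guidance, as in the new model_fn: blend the conditional
    # and unconditional noise predictions using the "CFG Scale" value
    return uncond_eps + scale * (cond_eps - uncond_eps)

def classifier_guided_eps(eps, grad_log_p, alpha_bar_t):
    # classifier guidance: cond_fn returns classifier_scale * grad of
    # log p(style | x_t); the sampler then shifts eps by that gradient,
    # weighted by the noise level (sketch of guided_diffusion's behaviour)
    return eps - (1.0 - alpha_bar_t) ** 0.5 * grad_log_p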

Files changed (1)
  1. app.py +202 -65
app.py CHANGED
Old version (left side of the split diff; removed lines prefixed with -):

@@ -2,45 +2,195 @@ import gradio as gr

import torch
from torch import autocast
- from diffusers import StableDiffusionPipeline
- from datasets import load_dataset
- from PIL import Image
- import re

- model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

- #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=False, revision="fp16", torch_dtype=torch.float16)
- pipe = pipe.to(device)
- #When running locally, you won`t have access to this, so you can remove this part
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
- word_list = word_list_dataset["train"]['text']
-
- def infer(prompt, samples, steps, scale, seed):
-     #When running locally you can also remove this filter
-     for filter in word_list:
-         if re.search(rf"\b{filter}\b", prompt):
-             raise gr.Error("Unsafe content found. Please try again with different prompts.")
-
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     #If you are running locally with CPU, you can remove the `with autocast("cuda")`
-     with autocast("cuda"):
-         images_list = pipe(
-             [prompt] * samples,
-             num_inference_steps=steps,
-             guidance_scale=scale,
-             generator=generator,
-         )
-     images = []
-     safe_image = Image.open(r"unsafe.png")
-     for i, image in enumerate(images_list["sample"]):
-         if(images_list["nsfw_content_detected"][i]):
-             images.append(safe_image)
-         else:
-             images.append(image)
-     return images

css = """
.gradio-container {
@@ -97,8 +247,7 @@ css = """
padding: 2px 8px;
border-radius: 14px !important;
}
- #advanced-options {
- display: none;
margin-bottom: 20px;
}
.footer {
@@ -213,19 +362,11 @@ with block:
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
</svg>
<h1 style="font-weight: 900; margin-bottom: 7px;">
- Stable Diffusion Demo
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
- Stable Diffusion is a state of the art text-to-image model that generates
- images from text.<br>For faster generation and forthcoming API
- access you can try
- <a
- href="http://beta.dreamstudio.ai/"
- style="text-decoration: underline;"
- target="_blank"
- >DreamStudio Beta</a
- >
</p>
</div>
"""
@@ -252,13 +393,18 @@ with block:
label="Generated images", show_label=False, elem_id="gallery"
).style(grid=[2], height="auto")

- advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

with gr.Row(elem_id="advanced-options"):
- samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
scale = gr.Slider(
- label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
)
seed = gr.Slider(
label="Seed",
@@ -268,22 +414,13 @@ with block:
randomize=True,
)

- ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale, seed], outputs=gallery, cache_examples=True)
ex.dataset.headers = [""]


- text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
- btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
- advanced_button.click(
- None,
- [],
- text,
- _js="""
- () => {
- const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
- options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
- }""",
- )
gr.HTML(
"""
<div class="footer">
 
New version (right side of the split diff; added lines prefixed with +):

@@ -2,45 +2,195 @@ import gradio as gr

import torch
from torch import autocast

+ import gc
+ import io
+ import math
+ import sys
+
+ from PIL import Image, ImageOps
+ import requests
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm.notebook import tqdm
+
+ import numpy as np
+
+ from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults, classifier_defaults, create_classifier
+
+ from omegaconf import OmegaConf
+ from ldm.util import instantiate_from_config
+
+ from einops import rearrange
+ from math import log2, sqrt
+
+ import argparse
+ import pickle
+
+ import os
+
+ from transformers import CLIPTokenizer, CLIPTextModel
+
+ def fetch(url_or_path):
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+
device = "cuda"

+ #model_state_dict = torch.load('diffusion.pt', map_location='cpu')
+ model_state_dict = torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/default/diffusion-1.4.pt'), map_location='cpu')
+
+ model_params = {
+     'attention_resolutions': '32,16,8',
+     'class_cond': False,
+     'diffusion_steps': 1000,
+     'rescale_timesteps': True,
+     'timestep_respacing': 'ddim100',
+     'image_size': 32,
+     'learn_sigma': False,
+     'noise_schedule': 'linear',
+     'num_channels': 320,
+     'num_heads': 8,
+     'num_res_blocks': 2,
+     'resblock_updown': False,
+     'use_fp16': True,
+     'use_scale_shift_norm': False,
+     'clip_embed_dim': None,
+     'image_condition': False,
+     'super_res_condition': False,
+ }
+
+ model_config = model_and_diffusion_defaults()
+ model_config.update(model_params)
+
+ # Load models
+ model, diffusion = create_model_and_diffusion(**model_config)
+ model.load_state_dict(model_state_dict, strict=True)
+ model.requires_grad_(False).eval().to(device)
+
+ if model_config['use_fp16']:
+     model.convert_to_fp16()
+ else:
+     model.convert_to_fp32()
+
+ def set_requires_grad(model, value):
+     for param in model.parameters():
+         param.requires_grad = value
+
+ # vae
+ kl_config = OmegaConf.load('kl.yaml')
+ kl_sd = torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/default/kl-1.4.pt'), map_location="cpu")
+
+ ldm = instantiate_from_config(kl_config.model)
+ ldm.load_state_dict(kl_sd, strict=True)
+
+ ldm.to(device)
+ ldm.eval()
+ ldm.requires_grad_(False)
+ set_requires_grad(ldm, False)
+
+ # clip
+ clip_version = 'openai/clip-vit-large-patch14'
+ clip_tokenizer = CLIPTokenizer.from_pretrained(clip_version)
+ clip_transformer = CLIPTextModel.from_pretrained(clip_version)
+ clip_transformer.eval().requires_grad_(False).to(device)
+
+ # classifier
+ # load classifier
+ classifier_config = classifier_defaults()
+ classifier_config['classifier_width'] = 128
+ classifier_config['classifier_depth'] = 4
+ classifier_config['classifier_attention_resolutions'] = '64,32,16,8'
+
+ classifier_photo = create_classifier(**classifier_config)
+ classifier_photo.load_state_dict(
+     torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/classifier_photo/model060000.pt'), map_location="cpu")
+ )
+ classifier_photo.to(device)
+ classifier_photo.convert_to_fp16()
+ classifier_photo.eval()
+
+ classifier_art = create_classifier(**classifier_config)
+ classifier_art.load_state_dict(
+     torch.load(fetch('https://huggingface.co/Jack000/glid-3-xl-stable/resolve/main/classifier_art/model110000.pt'), map_location="cpu")
+ )
+ classifier_art.to(device)
+ classifier_art.convert_to_fp16()
+ classifier_art.eval()
+
+ def infer(prompt, style, scale, classifier_scale, seed):
+     torch.manual_seed(seed)
+
+     # clip context
+     text = clip_tokenizer([prompt], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+     text_blank = clip_tokenizer([''], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+     text_tokens = text["input_ids"].to(device)
+     text_blank_tokens = text_blank["input_ids"].to(device)
+
+     text_emb = clip_transformer(input_ids=text_tokens).last_hidden_state
+     text_emb_blank = clip_transformer(input_ids=text_blank_tokens).last_hidden_state
+
+     kwargs = {
+         "context": torch.cat([text_emb, text_emb_blank], dim=0).half(),
+         "clip_embed": None,
+         "image_embed": None,
+     }
+
+     def model_fn(x_t, ts, **kwargs):
+         half = x_t[: len(x_t) // 2]
+         combined = torch.cat([half, half], dim=0)
+         model_out = model(combined, ts, **kwargs)
+         eps, rest = model_out[:, :3], model_out[:, 3:]
+         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+         half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
+         eps = torch.cat([half_eps, half_eps], dim=0)
+         return torch.cat([eps, rest], dim=1)
+
+     def cond_fn(x, t, context=None, clip_embed=None, image_embed=None):
+         with torch.enable_grad():
+             x_in = x[:x.shape[0]//2].detach().requires_grad_(True)
+             if style == 'photo':
+                 logits = classifier_photo(x_in, t)
+             elif style == 'digital art':
+                 logits = classifier_art(x_in, t)
+             else:
+                 return 0
+
+             log_probs = F.log_softmax(logits, dim=-1)
+             selected = log_probs[range(len(logits)), torch.ones(x_in.shape[0], dtype=torch.long)]
+             return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
+
+     samples = diffusion.ddim_sample_loop_progressive(
+         model_fn,
+         (2, 4, 64, 64),
+         clip_denoised=False,
+         model_kwargs=kwargs,
+         cond_fn=cond_fn,
+         device=device,
+         progress=True,
+         init_image=None,
+         skip_timesteps=0,
+     )
+
+     for j, sample in enumerate(samples):
+         pass
+
+     emb = sample['pred_xstart'][0]
+     emb /= 0.18215
+     im = emb.unsqueeze(0)
+     im = ldm.decode(im)
+
+     im = TF.to_pil_image(im.squeeze(0).add(1).div(2).clamp(0, 1))
+
+     return [im]

css = """
.gradio-container {

@@ -97,8 +247,7 @@ css = """
padding: 2px 8px;
border-radius: 14px !important;
}
+ #advanced-options, #style-options {
margin-bottom: 20px;
}
.footer {

@@ -213,19 +362,11 @@ with block:
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
</svg>
<h1 style="font-weight: 900; margin-bottom: 7px;">
+ Classifier Guided Stable Diffusion
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
+ a custom version of stable diffusion with classifier guidance
</p>
</div>
"""
 
@@ -252,13 +393,18 @@ with block:
label="Generated images", show_label=False, elem_id="gallery"
).style(grid=[2], height="auto")

+ #advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

+ with gr.Row(elem_id="style-options"):
+ style = gr.Radio(["none","photo","digital art","anime"], label="Image style")
with gr.Row(elem_id="advanced-options"):
+ #samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
+ #steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
scale = gr.Slider(
+ label="CFG Scale", minimum=0, maximum=50, value=7.5, step=0.1
+ )
+ classifier_scale = gr.Slider(
+ label="Classifier Scale", minimum=0, maximum=1000, value=100, step=1
)
seed = gr.Slider(
label="Seed",

@@ -268,22 +414,13 @@ with block:
randomize=True,
)

+ ex = gr.Examples(examples=examples, fn=infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery, cache_examples=True)
ex.dataset.headers = [""]


+ text.submit(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)
+ btn.click(infer, inputs=[text, style, scale, classifier_scale, seed], outputs=gallery)
+
gr.HTML(
"""
<div class="footer">
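
The Examples, submit, and click callbacks above all route the same five inputs into infer(prompt, style, scale, classifier_scale, seed), which returns a one-element list holding the decoded PIL image. A direct call would look roughly like this; the prompt string is a made-up example and the numeric values are the slider defaults from the diff:

images = infer(
    prompt="a lighthouse on a cliff at sunset",  # hypothetical example prompt
    style="photo",             # gr.Radio choices: "none", "photo", "digital art", "anime"
    scale=7.5,                 # "CFG Scale" slider default
    classifier_scale=100,      # "Classifier Scale" slider default
    seed=42,                   # any integer; the UI slider randomizes this
)
# images is a single-element list containing a PIL image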