ironjr commited on
Commit
82bf0c3
Β·
1 Parent(s): 2a39a67

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .ipynb_checkpoints/*
README.md CHANGED
@@ -1,13 +1,16 @@
1
  ---
2
- title: SemanticPaletteXL
3
- emoji: πŸ“‰
4
  colorFrom: red
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.22.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: SemanticPalette X Animagine XL 3.1
3
+ emoji: πŸ”₯🧠🎨πŸ”₯
4
  colorFrom: red
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.21.0
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
+ models:
12
+ - cagliostrolab/animagine-xl-3.1
13
+ - ByteDance/SDXL-Lightning
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,873 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import sys
22
+
23
+ sys.path.append('../../src')
24
+
25
+ import argparse
26
+ import random
27
+ import time
28
+ import json
29
+ import os
30
+ import glob
31
+ import pathlib
32
+ from functools import partial
33
+ from pprint import pprint
34
+
35
+ import numpy as np
36
+ from PIL import Image
37
+ import torch
38
+
39
+ import gradio as gr
40
+ from huggingface_hub import snapshot_download
41
+
42
+ from model import StableMultiDiffusionSDXLPipeline
43
+ from util import seed_everything
44
+ from prompt_util import preprocess_prompts, _quality_dict, _style_dict
45
+ from share_btn import community_icon_html, loading_icon_html, share_js
46
+
47
+
48
+ ### Utils
49
+
50
+
51
+
52
+
53
+ def log_state(state):
54
+ pprint(vars(opt))
55
+ if isinstance(state, gr.State):
56
+ state = state.value
57
+ pprint(vars(state))
58
+
59
+
60
+ def is_empty_image(im: Image.Image) -> bool:
61
+ if im is None:
62
+ return True
63
+ im = np.array(im)
64
+ has_alpha = (im.shape[2] == 4)
65
+ if not has_alpha:
66
+ return False
67
+ elif im.sum() == 0:
68
+ return True
69
+ else:
70
+ return False
71
+
72
+
73
+ ### Argument passing
74
+
75
+ parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SDXL support.')
76
+ parser.add_argument('-H', '--height', type=int, default=1024)
77
+ parser.add_argument('-W', '--width', type=int, default=2560)
78
+ parser.add_argument('--model', type=str, default=None, help='Hugging face model repository or local path for a SD1.5 model checkpoint to run.')
79
+ parser.add_argument('--bootstrap_steps', type=int, default=1)
80
+ parser.add_argument('--seed', type=int, default=-1)
81
+ parser.add_argument('--device', type=int, default=0)
82
+ parser.add_argument('--port', type=int, default=8000)
83
+ opt = parser.parse_args()
84
+
85
+
86
+ ### Global variables and data structures
87
+
88
+ device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
89
+
90
+
91
+ if opt.model is None:
92
+ model_dict = {
93
+ 'Animagine XL 3.1': 'cagliostrolab/animagine-xl-3.1',
94
+ }
95
+ else:
96
+ if opt.model.endswith('.safetensors'):
97
+ opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
98
+ model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
99
+
100
+ models = {
101
+ k: StableMultiDiffusionSDXLPipeline(device, hf_key=v, has_i2t=False)
102
+ for k, v in model_dict.items()
103
+ }
104
+
105
+
106
+ prompt_suggestions = [
107
+ '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
108
+ '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
109
+ '1girl, arima kana, oshi no ko, solo, upper body, from behind',
110
+ ]
111
+
112
+ opt.max_palettes = 5
113
+ opt.default_prompt_strength = 1.0
114
+ opt.default_mask_strength = 1.0
115
+ opt.default_mask_std = 0.0
116
+ opt.default_negative_prompt = (
117
+ 'nsfw, worst quality, bad quality, normal quality, cropped, framed'
118
+ )
119
+ opt.verbose = True
120
+ opt.colors = [
121
+ '#000000',
122
+ '#2692F3',
123
+ '#F89E12',
124
+ '#16C232',
125
+ '#F92F6C',
126
+ '#AC6AEB',
127
+ # '#92C62C',
128
+ # '#92C6EC',
129
+ # '#FECAC0',
130
+ ]
131
+
132
+
133
+ ### Event handlers
134
+
135
+ def add_palette(state):
136
+ old_actives = state.active_palettes
137
+ state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
138
+
139
+ if opt.verbose:
140
+ log_state(state)
141
+
142
+ if state.active_palettes != old_actives:
143
+ return [state] + [
144
+ gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
145
+ ] + [
146
+ gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
147
+ for i in range(opt.max_palettes)
148
+ ]
149
+ else:
150
+ return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
151
+
152
+
153
+ def select_palette(state, button, idx):
154
+ if idx < 0 or idx > opt.max_palettes:
155
+ idx = 0
156
+ old_idx = state.current_palette
157
+ if old_idx == idx:
158
+ return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
159
+
160
+ state.current_palette = idx
161
+
162
+ if opt.verbose:
163
+ log_state(state)
164
+
165
+ updates = [state] + [
166
+ gr.update() if i not in (idx, old_idx) else
167
+ gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
168
+ for i in range(opt.max_palettes + 1)
169
+ ]
170
+ label = 'Background' if idx == 0 else f'Palette {idx}'
171
+ updates.extend([
172
+ gr.update(value=button, interactive=(idx > 0)),
173
+ gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
174
+ gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
175
+ (
176
+ gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
177
+ gr.update(value=opt.default_mask_strength, interactive=False)
178
+ ),
179
+ (
180
+ gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
181
+ gr.update(value=opt.default_prompt_strength, interactive=False)
182
+ ),
183
+ (
184
+ gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
185
+ gr.update(value=opt.default_mask_std, interactive=False)
186
+ ),
187
+ ])
188
+ return updates
189
+
190
+
191
+ def change_prompt_strength(state, strength):
192
+ if state.current_palette == 0:
193
+ return state
194
+
195
+ state.prompt_strengths[state.current_palette - 1] = strength
196
+ if opt.verbose:
197
+ log_state(state)
198
+
199
+ return state
200
+
201
+
202
+ def change_std(state, std):
203
+ if state.current_palette == 0:
204
+ return state
205
+
206
+ state.mask_stds[state.current_palette - 1] = std
207
+ if opt.verbose:
208
+ log_state(state)
209
+
210
+ return state
211
+
212
+
213
+ def change_mask_strength(state, strength):
214
+ if state.current_palette == 0:
215
+ return state
216
+
217
+ state.mask_strengths[state.current_palette - 1] = strength
218
+ if opt.verbose:
219
+ log_state(state)
220
+
221
+ return state
222
+
223
+
224
+ def reset_seed(state, seed):
225
+ state.seed = seed
226
+ if opt.verbose:
227
+ log_state(state)
228
+
229
+ return state
230
+
231
+ def rename_prompt(state, name):
232
+ state.prompt_names[state.current_palette] = name
233
+ if opt.verbose:
234
+ log_state(state)
235
+
236
+ return [state] + [
237
+ gr.update() if i != state.current_palette else gr.update(value=name)
238
+ for i in range(opt.max_palettes + 1)
239
+ ]
240
+
241
+
242
+ def change_prompt(state, prompt):
243
+ state.prompts[state.current_palette] = prompt
244
+ if opt.verbose:
245
+ log_state(state)
246
+
247
+ return state
248
+
249
+
250
+ def change_neg_prompt(state, neg_prompt):
251
+ state.neg_prompts[state.current_palette] = neg_prompt
252
+ if opt.verbose:
253
+ log_state(state)
254
+
255
+ return state
256
+
257
+
258
+ def select_model(state, model_id):
259
+ state.model_id = model_id
260
+ if opt.verbose:
261
+ log_state(state)
262
+
263
+ return state
264
+
265
+
266
+ def select_style(state, style_name):
267
+ state.style_name = style_name
268
+ if opt.verbose:
269
+ log_state(state)
270
+
271
+ return state
272
+
273
+
274
+ def select_quality(state, quality_name):
275
+ state.quality_name = quality_name
276
+ if opt.verbose:
277
+ log_state(state)
278
+
279
+ return state
280
+
281
+
282
+ def import_state(state, json_text):
283
+ current_palette = state.current_palette
284
+ # active_palettes = state.active_palettes
285
+ state = argparse.Namespace(**json.loads(json_text))
286
+ state.active_palettes = opt.max_palettes
287
+ return [state] + [
288
+ gr.update(value=v, visible=True) for v in state.prompt_names
289
+ ] + [
290
+ state.model_id,
291
+ state.style_name,
292
+ state.quality_name,
293
+ state.prompts[current_palette],
294
+ state.prompt_names[current_palette],
295
+ state.neg_prompts[current_palette],
296
+ state.prompt_strengths[current_palette - 1],
297
+ state.mask_strengths[current_palette - 1],
298
+ state.mask_stds[current_palette - 1],
299
+ state.seed,
300
+ ]
301
+
302
+
303
+ ### Main worker
304
+
305
+ def generate(state, *args, **kwargs):
306
+ return models[state.model_id](*args, **kwargs)
307
+
308
+
309
+
310
+ def run(state, drawpad):
311
+ seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
312
+ print('Generate!')
313
+
314
+ background = drawpad['background'].convert('RGBA')
315
+ inpainting_mode = np.asarray(background).sum() != 0
316
+ print('Inpainting mode: ', inpainting_mode)
317
+
318
+ user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
319
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
320
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
321
+
322
+ palette = torch.tensor([
323
+ tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
324
+ for s in opt.colors[1:]
325
+ ]) # (N, 3)
326
+ masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
327
+ has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
328
+ print('Has mask: ', has_masks)
329
+ masks = masks * foreground_mask
330
+ masks = masks[has_masks]
331
+
332
+ if inpainting_mode:
333
+ prompts = [state.prompts[v + 1] for v in has_masks]
334
+ negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
335
+ mask_strengths = [state.mask_strengths[v] for v in has_masks]
336
+ mask_stds = [state.mask_stds[v] for v in has_masks]
337
+ prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
338
+ else:
339
+ masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
340
+ prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
341
+ negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
342
+ mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
343
+ mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
344
+ prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
345
+
346
+ prompts, negative_prompts = preprocess_prompts(
347
+ prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
348
+
349
+ return generate(
350
+ state,
351
+ prompts,
352
+ negative_prompts,
353
+ masks=masks,
354
+ mask_strengths=mask_strengths,
355
+ mask_stds=mask_stds,
356
+ prompt_strengths=prompt_strengths,
357
+ background=background.convert('RGB'),
358
+ background_prompt=state.prompts[0],
359
+ background_negative_prompt=state.neg_prompts[0],
360
+ height=opt.height,
361
+ width=opt.width,
362
+ bootstrap_steps=2,
363
+ guidance_scale=0,
364
+ )
365
+
366
+
367
+
368
+ ### Load examples
369
+
370
+
371
+ root = pathlib.Path(__file__).parent
372
+ print(root)
373
+ example_root = os.path.join(root, 'examples')
374
+ example_images = glob.glob(os.path.join(example_root, '*.png'))
375
+ example_images = [Image.open(i) for i in example_images]
376
+
377
+ with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
378
+ prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
379
+
380
+ with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
381
+ prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
382
+
383
+ with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
384
+ prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
385
+
386
+ with open(os.path.join(example_root, 'prompt_props.txt')) as f:
387
+ prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
388
+ prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
389
+
390
+ prompt_background = lambda: random.choice(prompts_background)
391
+ prompt_girl = lambda: random.choice(prompts_girl)
392
+ prompt_boy = lambda: random.choice(prompts_boy)
393
+ prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
394
+
395
+
396
+ ### Main application
397
+
398
+ css = f"""
399
+ #run-button {{
400
+ font-size: 30pt;
401
+ background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
402
+ margin: 0;
403
+ padding: 15px 45px;
404
+ text-align: center;
405
+ text-transform: uppercase;
406
+ transition: 0.5s;
407
+ background-size: 200% auto;
408
+ color: white;
409
+ box-shadow: 0 0 20px #eee;
410
+ border-radius: 10px;
411
+ display: block;
412
+ background-position: right center;
413
+ }}
414
+
415
+ #run-button:hover {{
416
+ background-position: left center;
417
+ color: #fff;
418
+ text-decoration: none;
419
+ }}
420
+
421
+ #semantic-palette {{
422
+ border-style: solid;
423
+ border-width: 0.2em;
424
+ border-color: #eee;
425
+ }}
426
+
427
+ #semantic-palette:hover {{
428
+ box-shadow: 0 0 20px #eee;
429
+ }}
430
+
431
+ #output-screen {{
432
+ width: 100%;
433
+ aspect-ratio: {opt.width} / {opt.height};
434
+ }}
435
+
436
+ .layer-wrap {{
437
+ display: none;
438
+ }}
439
+ """
440
+
441
+ for i in range(opt.max_palettes + 1):
442
+ css = css + f"""
443
+ .secondary#semantic-palette-{i} {{
444
+ background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
445
+ color: white;
446
+ }}
447
+
448
+ .primary#semantic-palette-{i} {{
449
+ background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
450
+ color: white;
451
+ }}
452
+ """
453
+
454
+
455
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
456
+
457
+ iface = argparse.Namespace()
458
+
459
+ def _define_state():
460
+ state = argparse.Namespace()
461
+
462
+ # Cursor.
463
+ state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
464
+ state.model_id = list(model_dict.keys())[0]
465
+ state.style_name = '(None)'
466
+ state.quality_name = 'Standard v3.1'
467
+
468
+ # State variables (one-hot).
469
+ state.active_palettes = 1
470
+
471
+ # Front-end initialized to the default values.
472
+ prompt_props_ = prompt_props()
473
+ state.prompt_names = [
474
+ 'πŸŒ„ Background',
475
+ 'πŸ‘§ Girl',
476
+ 'πŸ‘¦ Boy',
477
+ ] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
478
+ state.prompts = [
479
+ prompt_background(),
480
+ prompt_girl(),
481
+ prompt_boy(),
482
+ ] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
483
+ state.neg_prompts = [
484
+ opt.default_negative_prompt
485
+ + (', humans, humans, humans' if i == 0 else '')
486
+ for i in range(opt.max_palettes + 1)
487
+ ]
488
+ state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
489
+ state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
490
+ state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
491
+ state.seed = opt.seed
492
+ return state
493
+
494
+ state = gr.State(value=_define_state)
495
+
496
+
497
+ ### Demo user interface
498
+
499
+ gr.HTML(
500
+ """
501
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
502
+ <div>
503
+ <h1>🧠 Semantic Paint X Animagine XL 3.1 🎨</h1>
504
+ <h5 style="margin: 0;">powered by</h5>
505
+ <h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
506
+ <h5 style="margin: 0;">and</h5>
507
+ <h3>Animagine XL 3.1 by Cagliostro Research Lab</h3>
508
+ <h5 style="margin: 0;">If you ❀️ our project, please visit our Github and give us a 🌟!</h5>
509
+ </br>
510
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
511
+ <a href='https://arxiv.org/abs/2403.09055'>
512
+ <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
513
+ </a>
514
+ &nbsp;
515
+ <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
516
+ <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
517
+ </a>
518
+ &nbsp;
519
+ <a href='https://github.com/ironjr/StreamMultiDiffusion'>
520
+ <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
521
+ </a>
522
+ &nbsp;
523
+ <a href='https://twitter.com/_ironjr_'>
524
+ <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
525
+ </a>
526
+ &nbsp;
527
+ <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
528
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
529
+ </a>
530
+ &nbsp;
531
+ <a href='https://huggingface.co/papers/2403.09055'>
532
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper-yellow'>
533
+ </a>
534
+ &nbsp;
535
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
536
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-v1.5-yellow'>
537
+ </a>
538
+ &nbsp;
539
+ <a href='https://huggingface.co/cagliostrolab/animagine-xl-3.1'>
540
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Model-AnimagineXL3.1-yellow'>
541
+ </a>
542
+ </div>
543
+ </div>
544
+ </div>
545
+ <div>
546
+ </br>
547
+ </div>
548
+ """
549
+ )
550
+
551
+ with gr.Row():
552
+
553
+ iface.image_slot = gr.Image(
554
+ interactive=False,
555
+ show_label=False,
556
+ show_download_button=True,
557
+ type='pil',
558
+ label='Generated Result',
559
+ elem_id='output-screen',
560
+ value=lambda: random.choice(example_images),
561
+ )
562
+
563
+ with gr.Row():
564
+
565
+ with gr.Column(scale=1):
566
+
567
+ with gr.Group(elem_id='semantic-palette'):
568
+
569
+ gr.HTML(
570
+ """
571
+ <div style="justify-content: center; align-items: center;">
572
+ <br/>
573
+ <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
574
+ <br/>
575
+ </div>
576
+ """
577
+ )
578
+
579
+ iface.btn_semantics = [gr.Button(
580
+ value=state.value.prompt_names[0],
581
+ variant='primary',
582
+ elem_id='semantic-palette-0',
583
+ )]
584
+ for i in range(opt.max_palettes):
585
+ iface.btn_semantics.append(gr.Button(
586
+ value=state.value.prompt_names[i + 1],
587
+ variant='secondary',
588
+ visible=(i < state.value.active_palettes),
589
+ elem_id=f'semantic-palette-{i + 1}'
590
+ ))
591
+
592
+ iface.btn_add_palette = gr.Button(
593
+ value='Create New Semantic Brush',
594
+ variant='primary',
595
+ )
596
+
597
+ with gr.Accordion(label='Import/Export Semantic Palette', open=False):
598
+ iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
599
+ iface.json_state_export = gr.JSON(label='Exported Palette')
600
+ iface.btn_export_state = gr.Button("Export Palette ➑️ JSON", variant='primary')
601
+ iface.btn_import_state = gr.Button("Import JSON ➑️ Palette", variant='secondary')
602
+
603
+ gr.HTML(
604
+ """
605
+ <div>
606
+ </br>
607
+ </div>
608
+ <div style="justify-content: center; align-items: center;">
609
+ <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
610
+ </br>
611
+ <div style="justify-content: center; align-items: left; text-align: left;">
612
+ <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
613
+ <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
614
+ <p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
615
+ <p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
616
+ <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
617
+ <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
618
+ </div>
619
+ </div>
620
+ """
621
+ )
622
+
623
+ gr.HTML(
624
+ """
625
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
626
+ <h5 style="margin: 0;"><b>... or run in your own πŸ€— space!</b></h5>
627
+ </div>
628
+ """
629
+ )
630
+
631
+ gr.DuplicateButton()
632
+
633
+ with gr.Column(scale=4):
634
+
635
+ with gr.Row():
636
+
637
+ with gr.Column(scale=3):
638
+
639
+ iface.ctrl_semantic = gr.ImageEditor(
640
+ image_mode='RGBA',
641
+ sources=['upload', 'clipboard', 'webcam'],
642
+ transforms=['crop'],
643
+ crop_size=(opt.width, opt.height),
644
+ brush=gr.Brush(
645
+ colors=opt.colors[1:],
646
+ color_mode="fixed",
647
+ ),
648
+ type='pil',
649
+ label='Semantic Drawpad',
650
+ elem_id='drawpad',
651
+ )
652
+
653
+ with gr.Column(scale=1):
654
+
655
+ iface.btn_generate = gr.Button(
656
+ value='Generate!',
657
+ variant='primary',
658
+ # scale=1,
659
+ elem_id='run-button'
660
+ )
661
+ with gr.Group(elem_id="share-btn-container"):
662
+ gr.HTML(community_icon_html)
663
+ gr.HTML(loading_icon_html)
664
+ iface.btn_share = gr.Button("Share with Community", elem_id="share-btn")
665
+
666
+ iface.model_select = gr.Radio(
667
+ list(model_dict.keys()),
668
+ label='Stable Diffusion Checkpoint',
669
+ info='Choose your favorite style.',
670
+ value=state.value.model_id,
671
+ )
672
+
673
+ with gr.Accordion(label='Prompt Engineering', open=True):
674
+ iface.quality_select = gr.Dropdown(
675
+ label='Quality Presets',
676
+ interactive=True,
677
+ choices=list(_quality_dict.keys()),
678
+ value='Standard v3.1',
679
+ )
680
+ iface.style_select = gr.Radio(
681
+ label='Style Preset',
682
+ container=True,
683
+ interactive=True,
684
+ choices=list(_style_dict.keys()),
685
+ value='(None)',
686
+ )
687
+
688
+ with gr.Group(elem_id='control-panel'):
689
+
690
+ with gr.Row():
691
+ iface.tbox_prompt = gr.Textbox(
692
+ label='Edit Prompt for Background',
693
+ info='What do you want to draw?',
694
+ value=state.value.prompts[0],
695
+ placeholder=lambda: random.choice(prompt_suggestions),
696
+ scale=2,
697
+ )
698
+
699
+ iface.tbox_name = gr.Textbox(
700
+ label='Edit Brush Name',
701
+ info='Just for your convenience.',
702
+ value=state.value.prompt_names[0],
703
+ placeholder='πŸŒ„ Background',
704
+ scale=1,
705
+ )
706
+
707
+ with gr.Row():
708
+ iface.tbox_neg_prompt = gr.Textbox(
709
+ label='Edit Negative Prompt for Background',
710
+ info='Add unwanted objects for this semantic brush.',
711
+ value=opt.default_negative_prompt,
712
+ scale=2,
713
+ )
714
+
715
+ iface.slider_strength = gr.Slider(
716
+ label='Prompt Strength',
717
+ info='Blends fg & bg in the prompt level, >0.8 Preferred.',
718
+ minimum=0.5,
719
+ maximum=1.0,
720
+ value=opt.default_prompt_strength,
721
+ scale=1,
722
+ )
723
+
724
+ with gr.Row():
725
+ iface.slider_alpha = gr.Slider(
726
+ label='Mask Alpha',
727
+ info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
728
+ minimum=0.5,
729
+ maximum=1.0,
730
+ value=opt.default_mask_strength,
731
+ )
732
+
733
+ iface.slider_std = gr.Slider(
734
+ label='Mask Blur STD',
735
+ info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
736
+ minimum=0.0001,
737
+ maximum=100.0,
738
+ value=opt.default_mask_std,
739
+ )
740
+
741
+ iface.slider_seed = gr.Slider(
742
+ label='Seed',
743
+ info='The global seed.',
744
+ minimum=-1,
745
+ maximum=2147483647,
746
+ step=1,
747
+ value=opt.seed,
748
+ )
749
+
750
+ ### Attach event handlers
751
+
752
+ for idx, btn in enumerate(iface.btn_semantics):
753
+ btn.click(
754
+ fn=partial(select_palette, idx=idx),
755
+ inputs=[state, btn],
756
+ outputs=[state] + iface.btn_semantics + [
757
+ iface.tbox_name,
758
+ iface.tbox_prompt,
759
+ iface.tbox_neg_prompt,
760
+ iface.slider_alpha,
761
+ iface.slider_strength,
762
+ iface.slider_std,
763
+ ],
764
+ api_name=f'select_palette_{idx}',
765
+ )
766
+
767
+ iface.btn_add_palette.click(
768
+ fn=add_palette,
769
+ inputs=state,
770
+ outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
771
+ api_name='create_new',
772
+ )
773
+
774
+ iface.btn_generate.click(
775
+ fn=run,
776
+ inputs=[state, iface.ctrl_semantic],
777
+ outputs=iface.image_slot,
778
+ api_name='run',
779
+ )
780
+
781
+ iface.slider_alpha.input(
782
+ fn=change_mask_strength,
783
+ inputs=[state, iface.slider_alpha],
784
+ outputs=state,
785
+ api_name='change_alpha',
786
+ )
787
+ iface.slider_std.input(
788
+ fn=change_std,
789
+ inputs=[state, iface.slider_std],
790
+ outputs=state,
791
+ api_name='change_std',
792
+ )
793
+ iface.slider_strength.input(
794
+ fn=change_prompt_strength,
795
+ inputs=[state, iface.slider_strength],
796
+ outputs=state,
797
+ api_name='change_strength',
798
+ )
799
+ iface.slider_seed.input(
800
+ fn=reset_seed,
801
+ inputs=[state, iface.slider_seed],
802
+ outputs=state,
803
+ api_name='reset_seed',
804
+ )
805
+
806
+ iface.tbox_name.input(
807
+ fn=rename_prompt,
808
+ inputs=[state, iface.tbox_name],
809
+ outputs=[state] + iface.btn_semantics,
810
+ api_name='prompt_rename',
811
+ )
812
+ iface.tbox_prompt.input(
813
+ fn=change_prompt,
814
+ inputs=[state, iface.tbox_prompt],
815
+ outputs=state,
816
+ api_name='prompt_edit',
817
+ )
818
+ iface.tbox_neg_prompt.input(
819
+ fn=change_neg_prompt,
820
+ inputs=[state, iface.tbox_neg_prompt],
821
+ outputs=state,
822
+ api_name='neg_prompt_edit',
823
+ )
824
+
825
+ iface.model_select.change(
826
+ fn=select_model,
827
+ inputs=[state, iface.model_select],
828
+ outputs=state,
829
+ api_name='model_select',
830
+ )
831
+ iface.style_select.change(
832
+ fn=select_style,
833
+ inputs=[state, iface.style_select],
834
+ outputs=state,
835
+ api_name='style_select',
836
+ )
837
+ iface.quality_select.change(
838
+ fn=select_quality,
839
+ inputs=[state, iface.quality_select],
840
+ outputs=state,
841
+ api_name='quality_select',
842
+ )
843
+
844
+ iface.btn_share.click(None, [], [], _js=share_js)
845
+
846
+ iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
847
+ iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
848
+ state,
849
+ *iface.btn_semantics,
850
+ iface.model_select,
851
+ iface.style_select,
852
+ iface.quality_select,
853
+ iface.tbox_prompt,
854
+ iface.tbox_name,
855
+ iface.tbox_neg_prompt,
856
+ iface.slider_strength,
857
+ iface.slider_alpha,
858
+ iface.slider_std,
859
+ iface.slider_seed,
860
+ ])
861
+
862
+ gr.HTML(
863
+ """
864
+ <div class="footer">
865
+ <p>We thank <a href="https://cagliostrolab.net/">Cagliostro Research Lab</a> for their permission to use <a href="https://huggingface.co/cagliostrolab/animagine-xl-3.1">Animagine XL 3.1</a> model under academic purpose.
866
+ Note that the MIT license only applies to StreamMultiDiffusion and Semantic Palette demo app, but not Animagine XL 3.1 model, which is distributed under <a href="https://freedevproject.org/faipl-1.0-sd/">Fair AI Public License 1.0-SD</a>.
867
+ </p>
868
+ </div>
869
+ """
870
+ )
871
+
872
+ if __name__ == '__main__':
873
+ demo.launch(server_port=opt.port)
examples/prompt_background.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Maximalism, best quality, high quality, no humans, background, clear sky, γ… black sky, starry universe, planets
2
+ Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
3
+ Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
4
+ Maximalism, best quality, high quality, no humans, background, galaxy
5
+ Maximalism, best quality, high quality, no humans, background, sky, daylight
6
+ Maximalism, best quality, high quality, no humans, background, skyscrappers, rooftop, city of light, helicopters, bright night, sky
7
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
8
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/prompt_boy.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1boy, looking at viewer, brown hair, blue shirt
2
+ 1boy, looking at viewer, brown hair, red shirt
3
+ 1boy, looking at viewer, brown hair, purple shirt
4
+ 1boy, looking at viewer, brown hair, orange shirt
5
+ 1boy, looking at viewer, brown hair, yellow shirt
6
+ 1boy, looking at viewer, brown hair, green shirt
7
+ 1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
8
+ 1boy, looking back, short hair, renaissance cloths, noble boy
9
+ 1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
10
+ 1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
11
+ 1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
12
+ 1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
13
+ 1boy, looking at viewer, black haired, old eastern cloth
14
+ 1boy, looking back, messy hair, suit, short beard, noir
15
+ 1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
2
+ 1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
3
+ 1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
4
+ 1girl, looking at viewer, fantasy adventurer, backpack
5
+ 1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
6
+ 1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
7
+ 1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
8
+ 1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
9
+ 1girl, looking at viewer, evil smile, very short hair, suit, evil genius
10
+ 1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
11
+ 1girl, looking at viewer, purple hair, happy face, black leather jacket
12
+ 1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
13
+ 1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
14
+ 1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
15
+ 1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
16
+ 1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 🏯 Palace, Gyeongbokgung palace
2
+ 🌳 Garden, Chinese garden
3
+ πŸ›οΈ Rome, Ancient city of Rome
4
+ 🧱 Wall, Castle wall
5
+ πŸ”΄ Mars, Martian desert, Red rocky desert
6
+ 🌻 Grassland, Grasslands
7
+ 🏑 Village, A fantasy village
8
+ πŸ‰ Dragon, a flying chinese dragon
9
+ 🌏 Earth, Earth seen from ISS
10
+ πŸš€ Space Station, the international space station
11
+ πŸͺ» Grassland, Rusty grassland with flowers
12
+ πŸ–ΌοΈ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
13
+ πŸ™οΈ City Ruin, city, ruins, ruins, ruins, deserted
14
+ πŸ™οΈ Renaissance City, renaissance city, renaissance city, renaissance city
15
+ 🌷 Flowers, Flower garden
16
+ 🌼 Flowers, Flower garden, spring garden
17
+ 🌹 Flowers, Flowers flowers, flowers
18
+ ⛰️ Dolomites Mountains, Dolomites
19
+ ⛰️ Himalayas Mountains, Himalayas
20
+ ⛰️ Alps Mountains, Alps
21
+ ⛰️ Mountains, Mountains
22
+ ❄️⛰️ Mountains, Winter mountains
23
+ πŸŒ·β›°οΈ Mountains, Spring mountains
24
+ πŸŒžβ›°οΈ Mountains, Summer mountains
25
+ 🌡 Desert, A sandy desert, dunes
26
+ πŸͺ¨πŸŒ΅ Desert, A rocky desert
27
+ πŸ’¦ Waterfall, A giant waterfall
28
+ 🌊 Ocean, Ocean
29
+ ⛱️ Seashore, Seashore
30
+ πŸŒ… Sea Horizon, Sea horizon
31
+ 🌊 Lake, Clear blue lake
32
+ πŸ’» Computer, A giant supecomputer
33
+ 🌳 Tree, A giant tree
34
+ 🌳 Forest, A forest
35
+ 🌳🌳 Forest, A dense forest
36
+ 🌲 Forest, Winter forest
37
+ 🌴 Forest, Summer forest, tropical forest
38
+ πŸ‘’ Hat, A hat
39
+ 🐢 Dog, Doggy body parts
40
+ 😻 Cat, A cat
41
+ πŸ¦‰ Owl, A small sitting owl
42
+ πŸ¦… Eagle, A small sitting eagle
43
+ πŸš€ Rocket, A flying rocket
model.py ADDED
@@ -0,0 +1,1410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
22
+ from diffusers import (
23
+ AutoencoderTiny,
24
+ StableDiffusionXLPipeline,
25
+ UNet2DConditionModel,
26
+ EulerDiscreteScheduler,
27
+ )
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ FusedAttnProcessor2_0,
31
+ LoRAAttnProcessor2_0,
32
+ LoRAXFormersAttnProcessor,
33
+ XFormersAttnProcessor,
34
+ )
35
+ from diffusers.loaders import (
36
+ StableDiffusionXLLoraLoaderMixin,
37
+ TextualInversionLoaderMixin,
38
+ )
39
+ from diffusers.utils import (
40
+ USE_PEFT_BACKEND,
41
+ logging,
42
+ )
43
+ from huggingface_hub import hf_hub_download
44
+ from safetensors.torch import load_file
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+ import torch
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+ import torchvision.transforms as T
52
+ from einops import rearrange
53
+
54
+ from typing import Tuple, List, Literal, Optional, Union
55
+ from tqdm import tqdm
56
+ from PIL import Image
57
+
58
+ from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
59
+
60
+
61
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
62
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
63
+ """
64
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
65
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
66
+ """
67
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
68
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
69
+ # rescale the results from guidance (fixes overexposure)
70
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
71
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
72
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
73
+ return noise_cfg
74
+
75
+
76
+ class StableMultiDiffusionSDXLPipeline(nn.Module):
77
+ def __init__(
78
+ self,
79
+ device: torch.device,
80
+ dtype: torch.dtype = torch.float16,
81
+ hf_key: Optional[str] = None,
82
+ lora_key: Optional[str] = None,
83
+ load_from_local: bool = False, # Turn on if you have already downloaed LoRA & Hugging Face hub is down.
84
+ default_mask_std: float = 1.0, # 8.0
85
+ default_mask_strength: float = 1.0,
86
+ default_prompt_strength: float = 1.0, # 8.0
87
+ default_bootstrap_steps: int = 1,
88
+ default_boostrap_mix_steps: float = 1.0,
89
+ default_bootstrap_leak_sensitivity: float = 0.2,
90
+ default_preprocess_mask_cover_alpha: float = 0.3,
91
+ t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # # [0, 12, 25, 37], # Magic number.
92
+ mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
93
+ has_i2t: bool = True,
94
+ lora_weight: float = 1.0,
95
+ ) -> None:
96
+ r"""Stabilized MultiDiffusion for fast sampling.
97
+
98
+ Accelrated region-based text-to-image synthesis with Latent Consistency
99
+ Model while preserving mask fidelity and quality.
100
+
101
+ Args:
102
+ device (torch.device): Specify CUDA device.
103
+ hf_key (Optional[str]): Custom StableDiffusion checkpoint for
104
+ stylized generation.
105
+ lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
106
+ load_from_local (bool): Turn on if you have already downloaed LoRA
107
+ & Hugging Face hub is down.
108
+ default_mask_std (float): Preprocess mask with Gaussian blur with
109
+ specified standard deviation.
110
+ default_mask_strength (float): Preprocess mask by multiplying it
111
+ globally with the specified variable. Caution: extremely
112
+ sensitive. Recommended range: 0.98-1.
113
+ default_prompt_strength (float): Preprocess foreground prompts
114
+ globally by linearly interpolating its embedding with the
115
+ background prompt embeddint with specified mix ratio. Useful
116
+ control handle for foreground blending. Recommended range:
117
+ 0.5-1.
118
+ default_bootstrap_steps (int): Bootstrapping stage steps to
119
+ encourage region separation. Recommended range: 1-3.
120
+ default_boostrap_mix_steps (float): Bootstrapping background is a
121
+ linear interpolation between background latent and the white
122
+ image latent. This handle controls the mix ratio. Available
123
+ range: 0-(number of bootstrapping inference steps). For
124
+ example, 2.3 means that for the first two steps, white image
125
+ is used as a bootstrapping background and in the third step,
126
+ mixture of white (0.3) and registered background (0.7) is used
127
+ as a bootstrapping background.
128
+ default_bootstrap_leak_sensitivity (float): Postprocessing at each
129
+ inference step by masking away the remaining bootstrap
130
+ backgrounds t Recommended range: 0-1.
131
+ default_preprocess_mask_cover_alpha (float): Optional preprocessing
132
+ where each mask covered by other masks is reduced in its alpha
133
+ value by this specified factor.
134
+ t_index_list (List[int]): The default scheduling for LCM scheduler.
135
+ mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
136
+ defines the mask quantization modes. Details in the codes of
137
+ `self.process_mask`. Basically, this (subtly) controls the
138
+ smoothness of foreground-background blending. More continuous
139
+ means more blending, but smaller generated patch depending on
140
+ the mask standard deviation.
141
+ has_i2t (bool): Automatic background image to text prompt con-
142
+ version with BLIP-2 model. May not be necessary for the non-
143
+ streaming application.
144
+ lora_weight (float): Adjusts weight of the LCM/Lightning LoRA.
145
+ Heavily affects the overall quality!
146
+ """
147
+ super().__init__()
148
+
149
+ self.device = device
150
+ self.dtype = dtype
151
+
152
+ self.default_mask_std = default_mask_std
153
+ self.default_mask_strength = default_mask_strength
154
+ self.default_prompt_strength = default_prompt_strength
155
+ self.default_t_list = t_index_list
156
+ self.default_bootstrap_steps = default_bootstrap_steps
157
+ self.default_boostrap_mix_steps = default_boostrap_mix_steps
158
+ self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
159
+ self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
160
+ self.mask_type = mask_type
161
+
162
+ # Create model.
163
+ print(f'[INFO] Loading Stable Diffusion...')
164
+ variant = None
165
+ model_ckpt = None
166
+ lora_ckpt = None
167
+ lightning_repo = 'ByteDance/SDXL-Lightning'
168
+ if hf_key is not None:
169
+ print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
170
+ model_key = hf_key
171
+ lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
172
+
173
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, variant=variant, torch_dtype=self.dtype).to(self.device)
174
+ self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
175
+ self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
176
+ self.pipe.fuse_lora()
177
+ else:
178
+ model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
179
+ variant = 'fp16'
180
+ model_ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
181
+
182
+ unet = UNet2DConditionModel.from_config(model_key, subfolder='unet').to(self.device, self.dtype)
183
+ unet.load_state_dict(load_file(hf_hub_download(lightning_repo, model_ckpt), device=self.device))
184
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, unet=unet, torch_dtype=self.dtype, variant=variant).to(self.device)
185
+
186
+ # Create model
187
+ if has_i2t:
188
+ self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
189
+ self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
190
+
191
+ # Use SDXL-Lightning LoRA by default.
192
+ self.pipe.scheduler = EulerDiscreteScheduler.from_config(
193
+ self.pipe.scheduler.config, timestep_spacing="trailing")
194
+ self.scheduler = self.pipe.scheduler
195
+ self.default_num_inference_steps = 4
196
+ self.default_guidance_scale = 0.0
197
+
198
+ if t_index_list is None:
199
+ self.prepare_lightning_schedule(
200
+ list(range(self.default_num_inference_steps)),
201
+ self.default_num_inference_steps,
202
+ )
203
+ else:
204
+ self.prepare_lightning_schedule(t_index_list, 50)
205
+
206
+ self.vae = self.pipe.vae
207
+ self.tokenizer = self.pipe.tokenizer
208
+ self.tokenizer_2 = self.pipe.tokenizer_2
209
+ self.text_encoder = self.pipe.text_encoder
210
+ self.text_encoder_2 = self.pipe.text_encoder_2
211
+ self.unet = self.pipe.unet
212
+ self.vae_scale_factor = self.pipe.vae_scale_factor
213
+
214
+ # Prepare white background for bootstrapping.
215
+ self.get_white_background(1024, 1024)
216
+
217
+ print(f'[INFO] Model is loaded!')
218
+
219
+ def prepare_lightning_schedule(
220
+ self,
221
+ t_index_list: Optional[List[int]] = None,
222
+ num_inference_steps: Optional[int] = None,
223
+ s_churn: float = 0.0,
224
+ s_tmin: float = 0.0,
225
+ s_tmax: float = float("inf"),
226
+ ) -> None:
227
+ r"""Set up different inference schedule for the diffusion model.
228
+
229
+ You do not have to run this explicitly if you want to use the default
230
+ setting, but if you want other time schedules, run this function
231
+ between the module initialization and the main call.
232
+
233
+ Note:
234
+ - Recommended t_index_lists for LCMs:
235
+ - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
236
+ panorama. Not recommended if you want to use bootstrapping.
237
+ Because bootstrapping stage affects the initial structuring
238
+ of the generated image & in this four step LCM, this is done
239
+ with only at the first step, the structure may be distorted.
240
+ - [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
241
+ strapping. Default initialization in this implementation.
242
+ - [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
243
+ bootstrapping.
244
+ - Due to the characteristic of SD1.5 LCM LoRA, setting
245
+ `num_inference_steps` larger than 20 may results in overly blurry
246
+ and unrealistic images. Beware!
247
+
248
+ Args:
249
+ t_index_list (Optional[List[int]]): The specified scheduling step
250
+ regarding the maximum timestep as `num_inference_steps`, which
251
+ is by default, 50. That means that
252
+ `t_index_list=[0, 12, 25, 37]` is a relative time indices basd
253
+ on the full scale of 50. If None, reinitialize the module with
254
+ the default value.
255
+ num_inference_steps (Optional[int]): The maximum timestep of the
256
+ sampler. Defines relative scale of the `t_index_list`. Rarely
257
+ used in practice. If None, reinitialize the module with the
258
+ default value.
259
+ """
260
+ if t_index_list is None:
261
+ t_index_list = self.default_t_list
262
+ if num_inference_steps is None:
263
+ num_inference_steps = self.default_num_inference_steps
264
+
265
+ self.scheduler.set_timesteps(num_inference_steps)
266
+ self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)]
267
+
268
+ # EulerDiscreteScheduler
269
+
270
+ self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)]
271
+ self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:]
272
+ sigma_mask = torch.logical_and(s_tmin <= self.sigmas, self.sigmas <= s_tmax)
273
+ # self.gammas = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) * sigma_mask
274
+ self.gammas = min(s_churn / (num_inference_steps - 1), 2**0.5 - 1) * sigma_mask
275
+ self.sigma_hats = self.sigmas * (self.gammas + 1)
276
+ self.dt = self.sigmas_next - self.sigma_hats
277
+
278
+ noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
279
+ self.noise_lvs = noise_lvs[None, :, None, None, None]
280
+ self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
281
+
282
+ def upcast_vae(self):
283
+ dtype = self.vae.dtype
284
+ self.vae.to(dtype=torch.float32)
285
+ use_torch_2_0_or_xformers = isinstance(
286
+ self.vae.decoder.mid_block.attentions[0].processor,
287
+ (
288
+ AttnProcessor2_0,
289
+ XFormersAttnProcessor,
290
+ LoRAXFormersAttnProcessor,
291
+ LoRAAttnProcessor2_0,
292
+ FusedAttnProcessor2_0,
293
+ ),
294
+ )
295
+ # if xformers or torch_2_0 is used attention block does not need
296
+ # to be in float32 which can save lots of memory
297
+ if use_torch_2_0_or_xformers:
298
+ self.vae.post_quant_conv.to(dtype)
299
+ self.vae.decoder.conv_in.to(dtype)
300
+ self.vae.decoder.mid_block.to(dtype)
301
+
302
+ def _get_add_time_ids(
303
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
304
+ ):
305
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
306
+
307
+ passed_add_embed_dim = (
308
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
309
+ )
310
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
311
+
312
+ if expected_add_embed_dim != passed_add_embed_dim:
313
+ raise ValueError(
314
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
315
+ )
316
+
317
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
318
+ return add_time_ids
319
+
320
+ def encode_prompt(
321
+ self,
322
+ prompt: str,
323
+ prompt_2: Optional[str] = None,
324
+ device: Optional[torch.device] = None,
325
+ num_images_per_prompt: int = 1,
326
+ do_classifier_free_guidance: bool = True,
327
+ negative_prompt: Optional[str] = None,
328
+ negative_prompt_2: Optional[str] = None,
329
+ prompt_embeds: Optional[torch.FloatTensor] = None,
330
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
331
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
332
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ lora_scale: Optional[float] = None,
334
+ clip_skip: Optional[int] = None,
335
+ ):
336
+ r"""
337
+ Encodes the prompt into text encoder hidden states.
338
+
339
+ Args:
340
+ prompt (`str` or `List[str]`, *optional*):
341
+ prompt to be encoded
342
+ prompt_2 (`str` or `List[str]`, *optional*):
343
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
344
+ used in both text-encoders
345
+ device: (`torch.device`):
346
+ torch device
347
+ num_images_per_prompt (`int`):
348
+ number of images that should be generated per prompt
349
+ do_classifier_free_guidance (`bool`):
350
+ whether to use classifier free guidance or not
351
+ negative_prompt (`str` or `List[str]`, *optional*):
352
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
353
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
354
+ less than `1`).
355
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
356
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
357
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
358
+ prompt_embeds (`torch.FloatTensor`, *optional*):
359
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
360
+ provided, text embeddings will be generated from `prompt` input argument.
361
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
362
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
363
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
364
+ argument.
365
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
366
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
367
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
368
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
369
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
370
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
371
+ input argument.
372
+ lora_scale (`float`, *optional*):
373
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
374
+ clip_skip (`int`, *optional*):
375
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
376
+ the output of the pre-final layer will be used for computing the prompt embeddings.
377
+ """
378
+ device = device or self._execution_device
379
+
380
+ # set lora scale so that monkey patched LoRA
381
+ # function of text encoder can correctly access it
382
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
383
+ self._lora_scale = lora_scale
384
+
385
+ # dynamically adjust the LoRA scale
386
+ if self.text_encoder is not None:
387
+ if not USE_PEFT_BACKEND:
388
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
389
+ else:
390
+ scale_lora_layers(self.text_encoder, lora_scale)
391
+
392
+ if self.text_encoder_2 is not None:
393
+ if not USE_PEFT_BACKEND:
394
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
395
+ else:
396
+ scale_lora_layers(self.text_encoder_2, lora_scale)
397
+
398
+ prompt = [prompt] if isinstance(prompt, str) else prompt
399
+
400
+ if prompt is not None:
401
+ batch_size = len(prompt)
402
+ else:
403
+ batch_size = prompt_embeds.shape[0]
404
+
405
+ # Define tokenizers and text encoders
406
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
407
+ text_encoders = (
408
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
409
+ )
410
+
411
+ if prompt_embeds is None:
412
+ prompt_2 = prompt_2 or prompt
413
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
414
+
415
+ # textual inversion: process multi-vector tokens if necessary
416
+ prompt_embeds_list = []
417
+ prompts = [prompt, prompt_2]
418
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
419
+ if isinstance(self, TextualInversionLoaderMixin):
420
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
421
+
422
+ text_inputs = tokenizer(
423
+ prompt,
424
+ padding="max_length",
425
+ max_length=tokenizer.model_max_length,
426
+ truncation=True,
427
+ return_tensors="pt",
428
+ )
429
+
430
+ text_input_ids = text_inputs.input_ids
431
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
432
+
433
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
434
+ text_input_ids, untruncated_ids
435
+ ):
436
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
437
+ logger.warning(
438
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
439
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
440
+ )
441
+
442
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
443
+
444
+ # We are only ALWAYS interested in the pooled output of the final text encoder
445
+ pooled_prompt_embeds = prompt_embeds[0]
446
+ if clip_skip is None:
447
+ prompt_embeds = prompt_embeds.hidden_states[-2]
448
+ else:
449
+ # "2" because SDXL always indexes from the penultimate layer.
450
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
451
+
452
+ prompt_embeds_list.append(prompt_embeds)
453
+
454
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
455
+
456
+ # get unconditional embeddings for classifier free guidance
457
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
458
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
459
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
460
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
461
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
462
+ negative_prompt = negative_prompt or ""
463
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
464
+
465
+ # normalize str to list
466
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
467
+ negative_prompt_2 = (
468
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
469
+ )
470
+
471
+ uncond_tokens: List[str]
472
+ if prompt is not None and type(prompt) is not type(negative_prompt):
473
+ raise TypeError(
474
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
475
+ f" {type(prompt)}."
476
+ )
477
+ elif batch_size != len(negative_prompt):
478
+ raise ValueError(
479
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
480
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
481
+ " the batch size of `prompt`."
482
+ )
483
+ else:
484
+ uncond_tokens = [negative_prompt, negative_prompt_2]
485
+
486
+ negative_prompt_embeds_list = []
487
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
488
+ if isinstance(self, TextualInversionLoaderMixin):
489
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
490
+
491
+ max_length = prompt_embeds.shape[1]
492
+ uncond_input = tokenizer(
493
+ negative_prompt,
494
+ padding="max_length",
495
+ max_length=max_length,
496
+ truncation=True,
497
+ return_tensors="pt",
498
+ )
499
+
500
+ negative_prompt_embeds = text_encoder(
501
+ uncond_input.input_ids.to(device),
502
+ output_hidden_states=True,
503
+ )
504
+ # We are only ALWAYS interested in the pooled output of the final text encoder
505
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
506
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
507
+
508
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
509
+
510
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
511
+
512
+ if self.text_encoder_2 is not None:
513
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
514
+ else:
515
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
516
+
517
+ bs_embed, seq_len, _ = prompt_embeds.shape
518
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
519
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
520
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
521
+
522
+ if do_classifier_free_guidance:
523
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
524
+ seq_len = negative_prompt_embeds.shape[1]
525
+
526
+ if self.text_encoder_2 is not None:
527
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
528
+ else:
529
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
530
+
531
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
532
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
533
+
534
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
535
+ bs_embed * num_images_per_prompt, -1
536
+ )
537
+ if do_classifier_free_guidance:
538
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
539
+ bs_embed * num_images_per_prompt, -1
540
+ )
541
+
542
+ if self.text_encoder is not None:
543
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
544
+ # Retrieve the original scale by scaling back the LoRA layers
545
+ unscale_lora_layers(self.text_encoder, lora_scale)
546
+
547
+ if self.text_encoder_2 is not None:
548
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
549
+ # Retrieve the original scale by scaling back the LoRA layers
550
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
551
+
552
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
553
+
554
+ @torch.no_grad()
555
+ def get_text_prompts(self, image: Image.Image) -> str:
556
+ r"""A convenient method to extract text prompt from an image.
557
+
558
+ This is called if the user does not provide background prompt but only
559
+ the background image. We use BLIP-2 to automatically generate prompts.
560
+
561
+ Args:
562
+ image (Image.Image): A PIL image.
563
+
564
+ Returns:
565
+ A single string of text prompt.
566
+ """
567
+ if hasattr(self, 'i2t_model'):
568
+ question = 'Question: What are in the image? Answer:'
569
+ inputs = self.i2t_processor(image, question, return_tensors='pt')
570
+ out = self.i2t_model.generate(**inputs, max_new_tokens=77)
571
+ prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
572
+ return prompt
573
+ else:
574
+ return ''
575
+
576
+ @torch.no_grad()
577
+ def encode_imgs(
578
+ self,
579
+ imgs: torch.Tensor,
580
+ generator: Optional[torch.Generator] = None,
581
+ vae: Optional[nn.Module] = None,
582
+ ) -> torch.Tensor:
583
+ r"""A wrapper function for VAE encoder of the latent diffusion model.
584
+
585
+ Args:
586
+ imgs (torch.Tensor): An image to get StableDiffusion latents.
587
+ Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
588
+ generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
589
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
590
+ the demo application with TinyVAE).
591
+
592
+ Returns:
593
+ An image latent embedding with 1/8 size (depending on the auto-
594
+ encoder. Shape: (B, 4, H//8, W//8).
595
+ """
596
+ def _retrieve_latents(
597
+ encoder_output: torch.Tensor,
598
+ generator: Optional[torch.Generator] = None,
599
+ sample_mode: str = 'sample',
600
+ ):
601
+ if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
602
+ return encoder_output.latent_dist.sample(generator)
603
+ elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
604
+ return encoder_output.latent_dist.mode()
605
+ elif hasattr(encoder_output, 'latents'):
606
+ return encoder_output.latents
607
+ else:
608
+ raise AttributeError('Could not access latents of provided encoder_output')
609
+
610
+ vae = self.vae if vae is None else vae
611
+ imgs = 2 * imgs - 1
612
+ latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
613
+ return latents
614
+
615
+ @torch.no_grad()
616
+ def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
617
+ r"""A wrapper function for VAE decoder of the latent diffusion model.
618
+
619
+ Args:
620
+ latents (torch.Tensor): An image latent to get associated images.
621
+ Expected shape: (B, 4, H//8, W//8).
622
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
623
+ the demo application with TinyVAE).
624
+
625
+ Returns:
626
+ An image latent embedding with 1/8 size (depending on the auto-
627
+ encoder. Shape: (B, 3, H, W).
628
+ """
629
+ vae = self.vae if vae is None else vae
630
+ latents = 1 / vae.config.scaling_factor * latents
631
+ imgs = vae.decode(latents).sample
632
+ imgs = (imgs / 2 + 0.5).clip_(0, 1)
633
+ return imgs
634
+
635
+ @torch.no_grad()
636
+ def get_white_background(self, height: int, width: int) -> torch.Tensor:
637
+ r"""White background image latent for bootstrapping or in case of
638
+ absent background.
639
+
640
+ Additionally stores the maximally-sized white latent for fast retrieval
641
+ in the future. By default, we initially call this with 1024x1024 sized
642
+ white image, so the function is rarely visited twice.
643
+
644
+ Args:
645
+ height (int): The height of the white *image*, not its latent.
646
+ width (int): The width of the white *image*, not its latent.
647
+
648
+ Returns:
649
+ A white image latent of size (1, 4, height//8, width//8). A cropped
650
+ version of the stored white latent is returned if the requested
651
+ size is smaller than what we already have created.
652
+ """
653
+ if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
654
+ white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
655
+ self.white = self.encode_imgs(white)
656
+ return self.white
657
+ return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
658
+
659
+ @torch.no_grad()
660
+ def process_mask(
661
+ self,
662
+ masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
663
+ strength: Optional[Union[torch.Tensor, float]] = None,
664
+ std: Optional[Union[torch.Tensor, float]] = None,
665
+ height: int = 1024,
666
+ width: int = 1024,
667
+ use_boolean_mask: bool = True,
668
+ timesteps: Optional[torch.Tensor] = None,
669
+ preprocess_mask_cover_alpha: Optional[float] = None,
670
+ ) -> Tuple[torch.Tensor]:
671
+ r"""Fast preprocess of masks for region-based generation with fine-
672
+ grained controls.
673
+
674
+ Mask preprocessing is done in four steps:
675
+ 1. Resizing: Resize the masks into the specified width and height by
676
+ nearest neighbor interpolation.
677
+ 2. (Optional) Ordering: Masks with higher indices are considered to
678
+ cover the masks with smaller indices. Covered masks are decayed
679
+ in its alpha value by the specified factor of
680
+ `preprocess_mask_cover_alpha`.
681
+ 3. Blurring: Gaussian blur is applied to the mask with the specified
682
+ standard deviation (isotropic). This results in gradual increase of
683
+ masked region as the timesteps evolve, naturally blending fore-
684
+ ground and the predesignated background. Not strictly required if
685
+ you want to produce images from scratch withoout background.
686
+ 4. Quantization: Split the real-numbered masks of value between [0, 1]
687
+ into predefined noise levels for each quantized scheduling step of
688
+ the diffusion sampler. For example, if the diffusion model sampler
689
+ has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
690
+ is the default noise level of this module with schedule [0, 4, 12,
691
+ 25, 37], the masks are split into binary masks whose values are
692
+ greater than these levels. This results in tradual increase of mask
693
+ region as the timesteps increase. Details are described in our
694
+ paper at https://arxiv.org/pdf/2403.09055.pdf.
695
+
696
+ On the Three Modes of `mask_type`:
697
+ `self.mask_type` is predefined at the initialization stage of this
698
+ pipeline. Three possible modes are available: 'discrete', 'semi-
699
+ continuous', and 'continuous'. These define the mask quantization
700
+ modes we use. Basically, this (subtly) controls the smoothness of
701
+ foreground-background blending. Continuous modes produces nonbinary
702
+ masks to further blend foreground and background latents by linear-
703
+ ly interpolating between them. Semi-continuous masks only applies
704
+ continuous mask at the last step of the LCM sampler. Due to the
705
+ large step size of the LCM scheduler, we find that our continuous
706
+ blending helps generating seamless inpainting and editing results.
707
+
708
+ Args:
709
+ masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
710
+ strength (Optional[Union[torch.Tensor, float]]): Mask strength that
711
+ overrides the default value. A globally multiplied factor to
712
+ the mask at the initial stage of processing. Can be applied
713
+ seperately for each mask.
714
+ std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
715
+ kernel's standard deviation. Overrides the default value. Can
716
+ be applied seperately for each mask.
717
+ height (int): The height of the expected generation. Mask is
718
+ resized to (height//8, width//8) with nearest neighbor inter-
719
+ polation.
720
+ width (int): The width of the expected generation. Mask is resized
721
+ to (height//8, width//8) with nearest neighbor interpolation.
722
+ use_boolean_mask (bool): Specify this to treat the mask image as
723
+ a boolean tensor. The retion with dark part darker than 0.5 of
724
+ the maximal pixel value (that is, 127.5) is considered as the
725
+ designated mask.
726
+ timesteps (Optional[torch.Tensor]): Defines the scheduler noise
727
+ levels that acts as bins of mask quantization.
728
+ preprocess_mask_cover_alpha (Optional[float]): Optional pre-
729
+ processing where each mask covered by other masks is reduced in
730
+ its alpha value by this specified factor. Overrides the default
731
+ value.
732
+
733
+ Returns: A tuple of tensors.
734
+ - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
735
+ binary masks (see the explanation on `mask_type` above) for
736
+ region-based image synthesis.
737
+ - masks_blurred: Gaussian blurred masks. Used for optionally
738
+ specified foreground-background blending after image
739
+ generation.
740
+ - std: Mask blur standard deviation. Used for optionally specified
741
+ foreground-background blending after image generation.
742
+ """
743
+ if isinstance(masks, Image.Image):
744
+ masks = [masks]
745
+ if isinstance(masks, (tuple, list)):
746
+ # Assumes white background for Image.Image;
747
+ # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
748
+ if use_boolean_mask:
749
+ proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
750
+ else:
751
+ proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
752
+ masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
753
+ masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
754
+ masks = masks.to(self.device)
755
+
756
+ # Background mask alpha is decayed by the specified factor where foreground masks covers it.
757
+ if preprocess_mask_cover_alpha is None:
758
+ preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
759
+ if preprocess_mask_cover_alpha > 0:
760
+ masks = torch.stack([
761
+ torch.where(
762
+ masks[i + 1:].sum(dim=0) > 0,
763
+ mask * preprocess_mask_cover_alpha,
764
+ mask,
765
+ ) if i < len(masks) - 1 else mask
766
+ for i, mask in enumerate(masks)
767
+ ], dim=0)
768
+
769
+ # Scheduler noise levels for mask quantization.
770
+ if timesteps is None:
771
+ noise_lvs = self.noise_lvs
772
+ next_noise_lvs = self.next_noise_lvs
773
+ else:
774
+ noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
775
+ # noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
776
+ noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
777
+ next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
778
+
779
+ # Mask preprocessing parameters are fetched from the default settings.
780
+ if std is None:
781
+ std = self.default_mask_std
782
+ if isinstance(std, (int, float)):
783
+ std = [std] * len(masks)
784
+ if isinstance(std, (list, tuple)):
785
+ std = torch.as_tensor(std, dtype=torch.float, device=self.device)
786
+
787
+ if strength is None:
788
+ strength = self.default_mask_strength
789
+ if isinstance(strength, (int, float)):
790
+ strength = [strength] * len(masks)
791
+ if isinstance(strength, (list, tuple)):
792
+ strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
793
+
794
+ if (std > 0).any():
795
+ std = torch.where(std > 0, std, 1e-5)
796
+ masks = gaussian_lowpass(masks, std)
797
+ masks_blurred = masks
798
+
799
+ # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
800
+ # gives unpleasant results.
801
+ masks = masks * strength[:, None, None, None]
802
+ masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
803
+
804
+ # Mask is quantized according to the current noise levels specified by the scheduler.
805
+ if self.mask_type == 'discrete':
806
+ # Discrete mode.
807
+ masks = masks > noise_lvs
808
+ elif self.mask_type == 'semi-continuous':
809
+ # Semi-continuous mode (continuous at the last step only).
810
+ masks = torch.cat((
811
+ masks[:, :-1] > noise_lvs[:, :-1],
812
+ (
813
+ (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
814
+ ).clip_(0, 1),
815
+ ), dim=1)
816
+ elif self.mask_type == 'continuous':
817
+ # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
818
+ # decreases continuously after the discrete mode boundary to become `0` at the
819
+ # next lower threshold.
820
+ masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
821
+
822
+ # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
823
+ # fine-grained mask alpha channel tuning is available with this form.
824
+ # masks = masks * strength[None, :, None, None, None]
825
+
826
+ h = height // self.vae_scale_factor
827
+ w = width // self.vae_scale_factor
828
+ masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
829
+ masks = F.interpolate(masks, size=(h, w), mode='nearest')
830
+ masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
831
+ return masks, masks_blurred, std
832
+
833
+ def scheduler_scale_model_input(
834
+ self,
835
+ latent: torch.FloatTensor,
836
+ idx: int,
837
+ ) -> torch.FloatTensor:
838
+ """
839
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
840
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
841
+
842
+ Args:
843
+ sample (`torch.FloatTensor`):
844
+ The input sample.
845
+ timestep (`int`, *optional*):
846
+ The current timestep in the diffusion chain.
847
+
848
+ Returns:
849
+ `torch.FloatTensor`:
850
+ A scaled input sample.
851
+ """
852
+ latent = latent / ((self.sigmas[idx]**2 + 1) ** 0.5)
853
+ return latent
854
+
855
+ def scheduler_step(
856
+ self,
857
+ noise_pred: torch.Tensor,
858
+ idx: int,
859
+ latent: torch.Tensor,
860
+ ) -> torch.Tensor:
861
+ r"""Denoise-only step for reverse diffusion scheduler.
862
+
863
+ Designed to match the interface of the original `pipe.scheduler.step`,
864
+ which is a combination of this method and the following
865
+ `scheduler_add_noise`.
866
+
867
+ Args:
868
+ noise_pred (torch.Tensor): Noise prediction results from the U-Net.
869
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
870
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
871
+ latent (torch.Tensor): Noisy latent.
872
+
873
+ Returns:
874
+ A denoised tensor with the same size as latent.
875
+ """
876
+ # Upcast to avoid precision issues when computing prev_sample.
877
+ latent = latent.to(torch.float32)
878
+
879
+ # 1. Compute predicted original sample (x_0) from sigma-scaled predicted noise.
880
+ assert self.scheduler.config.prediction_type == 'epsilon', 'Only supports `prediction_type` of `epsilon` for now.'
881
+ # pred_original_sample = latent - self.sigma_hats[idx] * noise_pred
882
+ # prev_sample = pred_original_sample + noise_pred * (self.dt[i] + self.sigma_hats[i])
883
+ # return pred_original_sample.to(self.dtype)
884
+
885
+ # 2. Convert to an ODE derivative.
886
+ prev_sample = latent + noise_pred * self.dt[idx]
887
+ return prev_sample.to(self.dtype)
888
+
889
+ def scheduler_add_noise(
890
+ self,
891
+ latent: torch.Tensor,
892
+ noise: Optional[torch.Tensor],
893
+ idx: int,
894
+ s_noise: float = 1.0,
895
+ initial: bool = False,
896
+ ) -> torch.Tensor:
897
+ r"""Separated noise-add step for the reverse diffusion scheduler.
898
+
899
+ Designed to match the interface of the original
900
+ `pipe.scheduler.add_noise`.
901
+
902
+ Args:
903
+ latent (torch.Tensor): Denoised latent.
904
+ noise (torch.Tensor): Added noise. Can be None. If None, a random
905
+ noise is newly sampled for addition.
906
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
907
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
908
+
909
+ Returns:
910
+ A noisy tensor with the same size as latent.
911
+ """
912
+ if initial:
913
+ if idx < len(self.sigmas) and idx >= 0:
914
+ noise = torch.randn_like(latent) if noise is None else noise
915
+ return latent + self.sigmas[idx] * noise
916
+ else:
917
+ return latent
918
+ else:
919
+ # 3. Post-add noise.
920
+ noise_lv = (self.sigma_hats[idx]**2 - self.sigmas[idx]**2) ** 0.5
921
+ if self.gammas[idx] > 0 and noise_lv > 0 and s_noise > 0 and idx < len(self.sigmas) and idx >= 0:
922
+ noise = torch.randn_like(latent) if noise is None else noise
923
+ eps = noise * s_noise * noise_lv
924
+ latent = latent + eps
925
+ # pred_original_sample = pred_original_sample + eps
926
+ return latent
927
+
928
+ @torch.no_grad()
929
+ def __call__(
930
+ self,
931
+ prompts: Optional[Union[str, List[str]]] = None,
932
+ negative_prompts: Union[str, List[str]] = '',
933
+ suffix: Optional[str] = None, #', background is ',
934
+ background: Optional[Union[torch.Tensor, Image.Image]] = None,
935
+ background_prompt: Optional[str] = None,
936
+ background_negative_prompt: str = '',
937
+ height: int = 1024,
938
+ width: int = 1024,
939
+ num_inference_steps: Optional[int] = None,
940
+ guidance_scale: Optional[float] = None,
941
+ prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
942
+ masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
943
+ mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
944
+ mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
945
+ use_boolean_mask: bool = True,
946
+ do_blend: bool = True,
947
+ tile_size: int = 1024,
948
+ bootstrap_steps: Optional[int] = None,
949
+ boostrap_mix_steps: Optional[float] = None,
950
+ bootstrap_leak_sensitivity: Optional[float] = None,
951
+ preprocess_mask_cover_alpha: Optional[float] = None,
952
+ ) -> Image.Image:
953
+ r"""Arbitrary-size image generation from multiple pairs of (regional)
954
+ text prompt-mask pairs.
955
+
956
+ This is a main routine for this pipeline.
957
+
958
+ Example:
959
+ >>> device = torch.device('cuda:0')
960
+ >>> smd = StableMultiDiffusionPipeline(device)
961
+ >>> prompts = {... specify prompts}
962
+ >>> masks = {... specify mask tensors}
963
+ >>> height, width = masks.shape[-2:]
964
+ >>> image = smd(
965
+ >>> prompts, masks=masks.float(), height=height, width=width)
966
+ >>> image.save('my_beautiful_creation.png')
967
+
968
+ Args:
969
+ prompts (Union[str, List[str]]): A text prompt.
970
+ negative_prompts (Union[str, List[str]]): A negative text prompt.
971
+ suffix (Optional[str]): One option for blending foreground prompts
972
+ with background prompts by simply appending background prompt
973
+ to the end of each foreground prompt with this `middle word` in
974
+ between. For example, if you set this as `, background is`,
975
+ then the foreground prompt will be changed into
976
+ `(fg), background is (bg)` before conditional generation.
977
+ background (Optional[Union[torch.Tensor, Image.Image]]): a
978
+ background image, if the user wants to draw in front of the
979
+ specified image. Background prompt will automatically generated
980
+ with a BLIP-2 model.
981
+ background_prompt (Optional[str]): The background prompt is used
982
+ for preprocessing foreground prompt embeddings to blend
983
+ foreground and background.
984
+ background_negative_prompt (Optional[str]): The negative background
985
+ prompt.
986
+ height (int): Height of a generated image. It is tiled if larger
987
+ than `tile_size`.
988
+ width (int): Width of a generated image. It is tiled if larger
989
+ than `tile_size`.
990
+ num_inference_steps (Optional[int]): Number of inference steps.
991
+ Default inference scheduling is used if none is specified.
992
+ guidance_scale (Optional[float]): Classifier guidance scale.
993
+ Default value is used if none is specified.
994
+ prompt_strength (float): Overrides default value. Preprocess
995
+ foreground prompts globally by linearly interpolating its
996
+ embedding with the background prompt embeddint with specified
997
+ mix ratio. Useful control handle for foreground blending.
998
+ Recommended range: 0.5-1.
999
+ masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
1000
+ mask images. Each mask associates with each of the text prompts
1001
+ and each of the negative prompts. If specified as an image, it
1002
+ regards the image as a boolean mask. Also accepts torch.Tensor
1003
+ masks, which can have nonbinary values for fine-grained
1004
+ controls in mixing regional generations.
1005
+ mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
1006
+ Overrides the default value. an be assigned for each mask
1007
+ separately. Preprocess mask by multiplying it globally with the
1008
+ specified variable. Caution: extremely sensitive. Recommended
1009
+ range: 0.98-1.
1010
+ mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
1011
+ Overrides the default value. Can be assigned for each mask
1012
+ separately. Preprocess mask with Gaussian blur with specified
1013
+ standard deviation. Recommended range: 0-64.
1014
+ use_boolean_mask (bool): Turn this off if you want to treat the
1015
+ mask image as nonbinary one. The module will use the last
1016
+ channel of the given image in `masks` as the mask value.
1017
+ do_blend (bool): Blend the generated foreground and the optionally
1018
+ predefined background by smooth boundary obtained from Gaussian
1019
+ blurs of the foreground `masks` with the given `mask_stds`.
1020
+ tile_size (Optional[int]): Tile size of the panorama generation.
1021
+ Works best with the default training size of the Stable-
1022
+ Diffusion model, i.e., 1024 or 1024 for SD1.5 and 1024 for SDXL.
1023
+ bootstrap_steps (int): Overrides the default value. Bootstrapping
1024
+ stage steps to encourage region separation. Recommended range:
1025
+ 1-3.
1026
+ boostrap_mix_steps (float): Overrides the default value.
1027
+ Bootstrapping background is a linear interpolation between
1028
+ background latent and the white image latent. This handle
1029
+ controls the mix ratio. Available range: 0-(number of
1030
+ bootstrapping inference steps). For example, 2.3 means that for
1031
+ the first two steps, white image is used as a bootstrapping
1032
+ background and in the third step, mixture of white (0.3) and
1033
+ registered background (0.7) is used as a bootstrapping
1034
+ background.
1035
+ bootstrap_leak_sensitivity (float): Overrides the default value.
1036
+ Postprocessing at each inference step by masking away the
1037
+ remaining bootstrap backgrounds t Recommended range: 0-1.
1038
+ preprocess_mask_cover_alpha (float): Overrides the default value.
1039
+ Optional preprocessing where each mask covered by other masks
1040
+ is reduced in its alpha value by this specified factor.
1041
+
1042
+ Returns: A PIL.Image image of a panorama (large-size) image.
1043
+ """
1044
+
1045
+ ### Simplest cases
1046
+
1047
+ # prompts is None: return background.
1048
+ # masks is None but prompts is not None: return prompts
1049
+ # masks is not None and prompts is not None: Do StableMultiDiffusion.
1050
+
1051
+ if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
1052
+ if background is None and background_prompt is not None:
1053
+ return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
1054
+ return background
1055
+ elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
1056
+ return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
1057
+
1058
+
1059
+ ### Prepare generation
1060
+
1061
+ if num_inference_steps is not None:
1062
+ # self.prepare_lcm_schedule(list(range(num_inference_steps)), num_inference_steps)
1063
+ self.prepare_lightning_schedule(list(range(num_inference_steps)), num_inference_steps)
1064
+
1065
+ if guidance_scale is None:
1066
+ guidance_scale = self.default_guidance_scale
1067
+ do_classifier_free_guidance = guidance_scale > 1.0
1068
+
1069
+
1070
+ ### Prompts & Masks
1071
+
1072
+ # asserts #m > 0 and #p > 0.
1073
+ # #m == #p == #n > 0: We happily generate according to the prompts & masks.
1074
+ # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
1075
+ # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
1076
+
1077
+ if isinstance(masks, Image.Image):
1078
+ masks = [masks]
1079
+ if isinstance(prompts, str):
1080
+ prompts = [prompts]
1081
+ if isinstance(negative_prompts, str):
1082
+ negative_prompts = [negative_prompts]
1083
+ num_masks = len(masks)
1084
+ num_prompts = len(prompts)
1085
+ num_nprompts = len(negative_prompts)
1086
+ assert num_prompts in (num_masks, 1), \
1087
+ f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
1088
+ assert num_nprompts in (num_prompts, 1), \
1089
+ f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
1090
+
1091
+ fg_masks, masks_g, std = self.process_mask(
1092
+ masks,
1093
+ mask_strengths,
1094
+ mask_stds,
1095
+ height=height,
1096
+ width=width,
1097
+ use_boolean_mask=use_boolean_mask,
1098
+ timesteps=self.timesteps,
1099
+ preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
1100
+ ) # (p, t, 1, H, W)
1101
+ bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
1102
+ has_background = bg_masks.sum() > 0
1103
+
1104
+ h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
1105
+ w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
1106
+
1107
+
1108
+ ### Background
1109
+
1110
+ # background == None && background_prompt == None: Initialize with white background.
1111
+ # background == None && background_prompt != None: Generate background *along with other prompts*.
1112
+ # background != None && background_prompt == None: Retrieve text prompt using BLIP.
1113
+ # background != None && background_prompt != None: Use the given arguments.
1114
+
1115
+ # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
1116
+ # has_background && prompt_strength != 1: mix only for this case.
1117
+
1118
+ bg_latent = None
1119
+ if has_background:
1120
+ if background is None and background_prompt is not None:
1121
+ fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
1122
+ if suffix is not None:
1123
+ prompts = [p + suffix + background_prompt for p in prompts]
1124
+ prompts = [background_prompt] + prompts
1125
+ negative_prompts = [background_negative_prompt] + negative_prompts
1126
+ has_background = False # Regard that background does not exist.
1127
+ else:
1128
+ if background is None and background_prompt is None:
1129
+ background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
1130
+ background_prompt = 'simple white background image'
1131
+ elif background is not None and background_prompt is None:
1132
+ background_prompt = self.get_text_prompts(background)
1133
+ if suffix is not None:
1134
+ prompts = [p + suffix + background_prompt for p in prompts]
1135
+ prompts = [background_prompt] + prompts
1136
+ negative_prompts = [background_negative_prompt] + negative_prompts
1137
+ if isinstance(background, Image.Image):
1138
+ background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
1139
+ background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
1140
+ bg_latent = self.encode_imgs(background)
1141
+
1142
+ # Bootstrapping stage preparation.
1143
+
1144
+ if bootstrap_steps is None:
1145
+ bootstrap_steps = self.default_bootstrap_steps
1146
+ if boostrap_mix_steps is None:
1147
+ boostrap_mix_steps = self.default_boostrap_mix_steps
1148
+ if bootstrap_leak_sensitivity is None:
1149
+ bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
1150
+ if bootstrap_steps > 0:
1151
+ height_ = min(height, tile_size)
1152
+ width_ = min(width, tile_size)
1153
+ white = self.get_white_background(height, width) # (1, 4, h, w)
1154
+
1155
+
1156
+ ### Prepare text embeddings (optimized for the minimal encoder batch size)
1157
+
1158
+ # SDXL pipeline settings.
1159
+ batch_size = 1
1160
+ output_type = 'pil'
1161
+
1162
+ guidance_rescale = 0.7
1163
+
1164
+ prompt_2 = None
1165
+ device = self.device
1166
+ num_images_per_prompt = 1
1167
+ negative_prompt_2 = None
1168
+
1169
+ original_size = (height, width)
1170
+ target_size = (height, width)
1171
+ crops_coords_top_left = (0, 0)
1172
+ negative_crops_coords_top_left = (0, 0)
1173
+ negative_original_size = None
1174
+ negative_target_size = None
1175
+ pooled_prompt_embeds = None
1176
+ negative_pooled_prompt_embeds = None
1177
+ text_encoder_lora_scale = None
1178
+
1179
+ prompt_embeds = None
1180
+ negative_prompt_embeds = None
1181
+
1182
+ (
1183
+ prompt_embeds,
1184
+ negative_prompt_embeds,
1185
+ pooled_prompt_embeds,
1186
+ negative_pooled_prompt_embeds,
1187
+ ) = self.encode_prompt(
1188
+ prompt=prompts,
1189
+ prompt_2=prompt_2,
1190
+ device=device,
1191
+ num_images_per_prompt=num_images_per_prompt,
1192
+ do_classifier_free_guidance=do_classifier_free_guidance,
1193
+ negative_prompt=negative_prompts,
1194
+ negative_prompt_2=negative_prompt_2,
1195
+ prompt_embeds=prompt_embeds,
1196
+ negative_prompt_embeds=negative_prompt_embeds,
1197
+ pooled_prompt_embeds=pooled_prompt_embeds,
1198
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1199
+ lora_scale=text_encoder_lora_scale,
1200
+ )
1201
+
1202
+ add_text_embeds = pooled_prompt_embeds
1203
+ if self.text_encoder_2 is None:
1204
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1205
+ else:
1206
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1207
+
1208
+ add_time_ids = self._get_add_time_ids(
1209
+ original_size,
1210
+ crops_coords_top_left,
1211
+ target_size,
1212
+ dtype=prompt_embeds.dtype,
1213
+ text_encoder_projection_dim=text_encoder_projection_dim,
1214
+ )
1215
+ if negative_original_size is not None and negative_target_size is not None:
1216
+ negative_add_time_ids = self._get_add_time_ids(
1217
+ negative_original_size,
1218
+ negative_crops_coords_top_left,
1219
+ negative_target_size,
1220
+ dtype=prompt_embeds.dtype,
1221
+ text_encoder_projection_dim=text_encoder_projection_dim,
1222
+ )
1223
+ else:
1224
+ negative_add_time_ids = add_time_ids
1225
+
1226
+ if has_background:
1227
+ # First channel is background prompt text embeds. Background prompt itself is not used for generation.
1228
+ s = prompt_strengths
1229
+ if prompt_strengths is None:
1230
+ s = self.default_prompt_strength
1231
+ if isinstance(s, (int, float)):
1232
+ s = [s] * num_prompts
1233
+ if isinstance(s, (list, tuple)):
1234
+ assert len(s) == num_prompts, \
1235
+ f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
1236
+ s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
1237
+ s = s[:, None, None]
1238
+
1239
+ be = prompt_embeds[:1]
1240
+ fe = prompt_embeds[1:]
1241
+ prompt_embeds = torch.lerp(be, fe, s) # (p, 77, 1024)
1242
+
1243
+ if negative_prompt_embeds is not None:
1244
+ bu = negative_prompt_embeds[:1]
1245
+ fu = negative_prompt_embeds[1:]
1246
+ if num_prompts > num_nprompts:
1247
+ # # negative prompts = 1; # prompts > 1.
1248
+ assert fu.shape[0] == 1 and fe.shape == num_prompts
1249
+ fu = fu.repeat(num_prompts, 1, 1)
1250
+ negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
1251
+ elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
1252
+ # # negative prompts = 1; # prompts > 1.
1253
+ assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
1254
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
1255
+ # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
1256
+ if num_masks > num_prompts:
1257
+ assert masks.shape[0] == num_masks and num_prompts == 1
1258
+ prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
1259
+ if negative_prompt_embeds is not None:
1260
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
1261
+
1262
+ # SDXL pipeline settings.
1263
+ if do_classifier_free_guidance:
1264
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1265
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1266
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1267
+ del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
1268
+
1269
+ prompt_embeds = prompt_embeds.to(device)
1270
+ add_text_embeds = add_text_embeds.to(device)
1271
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1272
+
1273
+
1274
+ ### Run
1275
+
1276
+ # Latent initialization.
1277
+ if self.timesteps[0] < 999 and has_background:
1278
+ latents = self.scheduler_add_noise(bg_latents, None, 0, initial=True)
1279
+ else:
1280
+ latents = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
1281
+ latents = latents * self.scheduler.init_noise_sigma
1282
+
1283
+ # Tiling (if needed).
1284
+ if height > tile_size or width > tile_size:
1285
+ t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
1286
+ views, tile_masks = get_panorama_views(h, w, t)
1287
+ tile_masks = tile_masks.to(self.device)
1288
+ else:
1289
+ views = [(0, h, 0, w)]
1290
+ tile_masks = latents.new_ones((1, 1, h, w))
1291
+ value = torch.zeros_like(latents)
1292
+ count_all = torch.zeros_like(latents)
1293
+
1294
+ with torch.autocast('cuda'):
1295
+ for i, t in enumerate(tqdm(self.timesteps)):
1296
+ fg_mask = fg_masks[:, i]
1297
+ bg_mask = bg_masks[i:i + 1]
1298
+
1299
+ value.zero_()
1300
+ count_all.zero_()
1301
+ for j, (h_start, h_end, w_start, w_end) in enumerate(views):
1302
+ fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
1303
+ latents_ = latents[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
1304
+
1305
+ # Additional arguments for the SDXL pipeline.
1306
+ add_time_ids_input = add_time_ids.clone()
1307
+ add_time_ids_input[:, 2] = h_start * self.vae_scale_factor
1308
+ add_time_ids_input[:, 3] = w_start * self.vae_scale_factor
1309
+ add_time_ids_input = add_time_ids_input.repeat_interleave(num_prompts, dim=0)
1310
+
1311
+ # Bootstrap for tight background.
1312
+ if i < bootstrap_steps:
1313
+ mix_ratio = min(1, max(0, boostrap_mix_steps - i))
1314
+ # Treat the first foreground latent as the background latent if one does not exist.
1315
+ bg_latents_ = bg_latents[..., h_start:h_end, w_start:w_end] if has_background else latents_[:1]
1316
+ white_ = white[..., h_start:h_end, w_start:w_end]
1317
+ white_ = self.scheduler_add_noise(white_, None, i, initial=True)
1318
+ bg_latents_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latents_
1319
+ latents_ = (1.0 - fg_mask_) * bg_latents_ + fg_mask_ * latents_
1320
+
1321
+ # Centering.
1322
+ latents_ = shift_to_mask_bbox_center(latents_, fg_mask_, reverse=True)
1323
+
1324
+ latent_model_input = torch.cat([latents_] * 2) if do_classifier_free_guidance else latents_
1325
+ latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
1326
+
1327
+ # Perform one step of the reverse diffusion.
1328
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_input}
1329
+ noise_pred = self.unet(
1330
+ latent_model_input,
1331
+ t,
1332
+ encoder_hidden_states=prompt_embeds,
1333
+ timestep_cond=None,
1334
+ cross_attention_kwargs=None,
1335
+ added_cond_kwargs=added_cond_kwargs,
1336
+ return_dict=False,
1337
+ )[0]
1338
+
1339
+ if do_classifier_free_guidance:
1340
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
1341
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
1342
+
1343
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1344
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1345
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
1346
+
1347
+ latents_ = self.scheduler_step(noise_pred, i, latents_)
1348
+
1349
+ if i < bootstrap_steps:
1350
+ # Uncentering.
1351
+ latents_ = shift_to_mask_bbox_center(latents_, fg_mask_)
1352
+
1353
+ # Remove leakage (optional).
1354
+ leak = (latents_ - bg_latents_).pow(2).mean(dim=1, keepdim=True)
1355
+ leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
1356
+ fg_mask_ = fg_mask_ * leak_sigmoid
1357
+
1358
+ # Mix the latents.
1359
+ fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
1360
+ value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latents_).sum(dim=0, keepdim=True)
1361
+ count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
1362
+
1363
+ latents = torch.where(count_all > 0, value / count_all, value)
1364
+ bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
1365
+ if has_background:
1366
+ latents = (1 - bg_mask) * latents + bg_mask * bg_latents
1367
+
1368
+ # Noise is added after mixing.
1369
+ if i < len(self.timesteps) - 1:
1370
+ latents = self.scheduler_add_noise(latents, None, i + 1)
1371
+
1372
+ if not output_type == "latent":
1373
+ # make sure the VAE is in float32 mode, as it overflows in float16
1374
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1375
+
1376
+ if needs_upcasting:
1377
+ self.upcast_vae()
1378
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1379
+
1380
+ # unscale/denormalize the latents
1381
+ # denormalize with the mean and std if available and not None
1382
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1383
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1384
+ if has_latents_mean and has_latents_std:
1385
+ latents_mean = (
1386
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1387
+ )
1388
+ latents_std = (
1389
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1390
+ )
1391
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1392
+ else:
1393
+ latents = latents / self.vae.config.scaling_factor
1394
+
1395
+ image = self.vae.decode(latents, return_dict=False)[0]
1396
+
1397
+ # cast back to fp16 if needed
1398
+ if needs_upcasting:
1399
+ self.vae.to(dtype=torch.float16)
1400
+ else:
1401
+ image = latents
1402
+
1403
+ # Return PIL Image.
1404
+ image = image[0].clip_(-1, 1) * 0.5 + 0.5
1405
+ if has_background and do_blend:
1406
+ fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
1407
+ image = blend(image, background[0], fg_mask)
1408
+ else:
1409
+ image = T.ToPILImage()(image)
1410
+ return image
prompt_util.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+
4
+ quality_prompt_list = [
5
+ {
6
+ "name": "(None)",
7
+ "prompt": "{prompt}",
8
+ "negative_prompt": "nsfw, lowres",
9
+ },
10
+ {
11
+ "name": "Standard v3.0",
12
+ "prompt": "{prompt}, masterpiece, best quality",
13
+ "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
14
+ },
15
+ {
16
+ "name": "Standard v3.1",
17
+ "prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
18
+ "negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
19
+ },
20
+ {
21
+ "name": "Light v3.1",
22
+ "prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
23
+ "negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
24
+ },
25
+ {
26
+ "name": "Heavy v3.1",
27
+ "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
28
+ "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
29
+ },
30
+ ]
31
+
32
+ style_list = [
33
+ {
34
+ "name": "(None)",
35
+ "prompt": "{prompt}",
36
+ "negative_prompt": "",
37
+ },
38
+ {
39
+ "name": "Cinematic",
40
+ "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
41
+ "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
42
+ },
43
+ {
44
+ "name": "Photographic",
45
+ "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
46
+ "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
47
+ },
48
+ {
49
+ "name": "Anime",
50
+ "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
51
+ "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
52
+ },
53
+ {
54
+ "name": "Manga",
55
+ "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
56
+ "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
57
+ },
58
+ {
59
+ "name": "Digital Art",
60
+ "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
61
+ "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
62
+ },
63
+ {
64
+ "name": "Pixel art",
65
+ "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
66
+ "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
67
+ },
68
+ {
69
+ "name": "Fantasy art",
70
+ "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
71
+ "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
72
+ },
73
+ {
74
+ "name": "Neonpunk",
75
+ "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
76
+ "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
77
+ },
78
+ {
79
+ "name": "3D Model",
80
+ "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
81
+ "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
82
+ },
83
+ ]
84
+
85
+
86
+ _style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
87
+ _quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
88
+
89
+
90
+ def preprocess_prompt(
91
+ positive: str,
92
+ negative: str = "",
93
+ style_dict: Dict[str, dict] = _quality_dict,
94
+ style_name: str = "Standard v3.1", # "Heavy v3.1"
95
+ add_style: bool = True,
96
+ ) -> Tuple[str, str]:
97
+ p, n = style_dict.get(style_name, style_dict["(None)"])
98
+
99
+ if add_style and positive.strip():
100
+ formatted_positive = p.format(prompt=positive)
101
+ else:
102
+ formatted_positive = positive
103
+
104
+ combined_negative = n
105
+ if negative.strip():
106
+ if combined_negative:
107
+ combined_negative += ", " + negative
108
+ else:
109
+ combined_negative = negative
110
+
111
+ return formatted_positive, combined_negative
112
+
113
+
114
+ def preprocess_prompts(
115
+ positives: List[str],
116
+ negatives: List[str] = None,
117
+ style_dict = _style_dict,
118
+ style_name: str = "Manga", # "(None)"
119
+ quality_dict = _quality_dict,
120
+ quality_name: str = "Standard v3.1", # "Heavy v3.1"
121
+ add_style: bool = True,
122
+ add_quality_tags = True,
123
+ ) -> Tuple[List[str], List[str]]:
124
+ if negatives is None:
125
+ negatives = ['' for _ in positives]
126
+
127
+ positives_ = []
128
+ negatives_ = []
129
+ for pos, neg in zip(positives, negatives):
130
+ pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
131
+ pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
132
+ positives_.append(pos)
133
+ negatives_.append(neg)
134
+ return positives_, negatives_
135
+
136
+
137
+ def print_prompts(
138
+ positives: Union[str, List[str]],
139
+ negatives: Union[str, List[str]],
140
+ has_background: bool = False,
141
+ ) -> None:
142
+ if isinstance(positives, str):
143
+ positives = [positives]
144
+ if isinstance(negatives, str):
145
+ negatives = [negatives]
146
+
147
+ for i, prompt in enumerate(positives):
148
+ prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
149
+ if has_background else f'Prompt{i + 1}')
150
+ print(prefix + ': ' + prompt)
151
+ for i, prompt in enumerate(negatives):
152
+ prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
153
+ if has_background else f'Negative Prompt{i + 1}')
154
+ print(prefix + ': ' + prompt)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision
3
+ xformers==0.0.22
4
+ einops
5
+ diffusers
6
+ transformers
7
+ huggingface_hub[torch]
8
+ gradio
9
+ Pillow
10
+ emoji
11
+ numpy
12
+ tqdm
13
+ jupyterlab
14
+ spaces
share_btn.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
2
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
3
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
4
+ </svg>"""
5
+
6
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
7
+ style="color: #ffffff;
8
+ "
9
+ xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
10
+
11
+ share_js = """async () => {
12
+ async function uploadFile(file){
13
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
14
+ const response = await fetch(UPLOAD_URL, {
15
+ method: 'POST',
16
+ headers: {
17
+ 'Content-Type': file.type,
18
+ 'X-Requested-With': 'XMLHttpRequest',
19
+ },
20
+ body: file, /// <- File inherits from Blob
21
+ });
22
+ const url = await response.text();
23
+ return url;
24
+ }
25
+ const gradioEl = document.querySelector('body > gradio-app');
26
+ const imgEls = gradioEl.querySelectorAll('#output-screen img');
27
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
28
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
29
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
30
+ if(!imgEls.length){
31
+ return;
32
+ };
33
+ shareBtnEl.style.pointerEvents = 'none';
34
+ shareIconEl.style.display = 'none';
35
+ loadingIconEl.style.removeProperty('display');
36
+ const files = await Promise.all(
37
+ [...imgEls].map(async (imgEl) => {
38
+ const res = await fetch(imgEl.src);
39
+ const blob = await res.blob();
40
+ const imgId = Date.now() % 200;
41
+ const fileName = `diffuse-the-rest-${{imgId}}.jpg`;
42
+ return new File([blob], fileName, { type: 'image/jpeg' });
43
+ })
44
+ );
45
+ const urls = await Promise.all(files.map((f) => uploadFile(f)));
46
+ const htmlImgs = urls.map(url => `<img src='${url}' width='2560' height='1024'>`);
47
+ const descriptionMd = `<div style='display: flex; flex-wrap: wrap; column-gap: 0.75rem;'>
48
+ ${htmlImgs.join(`\n`)}
49
+ </div>`;
50
+ const params = new URLSearchParams({
51
+ title: <p>My creation</p>,
52
+ description: descriptionMd,
53
+ });
54
+ const paramsStr = params.toString();
55
+ window.open(`https://huggingface.co/spaces/ironjr/SemanticPaletteXL/discussions/new?${paramsStr}`, '_blank');
56
+ shareBtnEl.style.removeProperty('pointer-events');
57
+ shareIconEl.style.removeProperty('display');
58
+ loadingIconEl.style.display = 'none';
59
+ }"""
util.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import concurrent.futures
22
+ import time
23
+ from typing import Any, Callable, List, Literal, Tuple, Union
24
+
25
+ from PIL import Image
26
+ import numpy as np
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.cuda.amp as amp
31
+ import torchvision.transforms as T
32
+ import torchvision.transforms.functional as TF
33
+
34
+ from diffusers import (
35
+ DiffusionPipeline,
36
+ StableDiffusionPipeline,
37
+ StableDiffusionXLPipeline,
38
+ )
39
+
40
+
41
+ def seed_everything(seed: int) -> None:
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = True
46
+
47
+
48
+ def load_model(
49
+ model_key: str,
50
+ sd_version: Literal['1.5', 'xl'],
51
+ device: torch.device,
52
+ dtype: torch.dtype,
53
+ ) -> torch.nn.Module:
54
+ if model_key.endswith('.safetensors'):
55
+ if sd_version == '1.5':
56
+ pipeline = StableDiffusionPipeline
57
+ elif sd_version == 'xl':
58
+ pipeline = StableDiffusionXLPipeline
59
+ else:
60
+ raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
61
+ return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
62
+ try:
63
+ return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
64
+ except:
65
+ return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
66
+
67
+
68
+ def get_cutoff(cutoff: float = None, scale: float = None) -> float:
69
+ if cutoff is not None:
70
+ return cutoff
71
+
72
+ if scale is not None and cutoff is None:
73
+ return 0.5 / scale
74
+
75
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
76
+
77
+
78
+ def get_scale(cutoff: float = None, scale: float = None) -> float:
79
+ if scale is not None:
80
+ return scale
81
+
82
+ if cutoff is not None and scale is None:
83
+ return 0.5 / cutoff
84
+
85
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
86
+
87
+
88
+ def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
89
+ assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
90
+ # assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
91
+
92
+ b, c, h, w = x.shape
93
+ ks = k.shape[-1]
94
+ k = k.view(1, 1, -1).repeat(c, 1, 1)
95
+
96
+ x = x.permute(0, 2, 1, 3)
97
+ x = x.reshape(b * h, c, w)
98
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
99
+ x = F.conv1d(x, k, groups=c)
100
+ x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
101
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
102
+ x = F.conv1d(x, k, groups=c)
103
+ x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
104
+ return x
105
+
106
+
107
+ def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
108
+ assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
109
+
110
+ x = F.pad(x, (
111
+ k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
112
+ k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
113
+ ), mode='replicate')
114
+
115
+ b, c, _, _ = x.shape
116
+ if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
117
+ k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
118
+ x = F.conv2d(x, k, groups=c)
119
+ elif len(k.shape) == 3:
120
+ assert k.shape[0] == b, \
121
+ 'The number of kernels should match the batch size.'
122
+
123
+ k = k.unsqueeze(1)
124
+ x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
125
+ return x
126
+
127
+
128
+ @amp.autocast(False)
129
+ def filter_by_kernel(
130
+ x: torch.Tensor,
131
+ k: torch.Tensor,
132
+ is_batch: bool = False,
133
+ ) -> torch.Tensor:
134
+ k_dim = len(k.shape)
135
+ if k_dim == 1 or k_dim == 2 and is_batch:
136
+ return filter_2d_by_kernel_1d(x, k)
137
+ elif k_dim == 2 or k_dim == 3 and is_batch:
138
+ return filter_2d_by_kernel_2d(x, k)
139
+ else:
140
+ raise ValueError('Kernel size should be one of (1, 2, 3).')
141
+
142
+
143
+ def gen_gauss_lowpass_filter_2d(
144
+ std: torch.Tensor,
145
+ window_size: int = None,
146
+ ) -> torch.Tensor:
147
+ # Gaussian kernel size is odd in order to preserve the center.
148
+ if window_size is None:
149
+ window_size = (
150
+ 2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
151
+
152
+ y = torch.arange(
153
+ window_size, dtype=std.dtype, device=std.device
154
+ ).view(-1, 1).repeat(1, window_size)
155
+ grid = torch.stack((y.t(), y), dim=-1)
156
+ grid -= 0.5 * (window_size - 1) # (W, W)
157
+ var = (std * std).unsqueeze(-1).unsqueeze(-1)
158
+ distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
159
+ k = torch.exp(-0.5 * distsq / var)
160
+ k /= k.sum(dim=(-2, -1), keepdim=True)
161
+ return k
162
+
163
+
164
+ def gaussian_lowpass(
165
+ x: torch.Tensor,
166
+ std: Union[float, Tuple[float], torch.Tensor] = None,
167
+ cutoff: Union[float, torch.Tensor] = None,
168
+ scale: Union[float, torch.Tensor] = None,
169
+ ) -> torch.Tensor:
170
+ if std is None:
171
+ cutoff = get_cutoff(cutoff, scale)
172
+ std = 0.5 / (np.pi * cutoff)
173
+ if isinstance(std, (float, int)):
174
+ std = (std, std)
175
+ if isinstance(std, torch.Tensor):
176
+ """Using nn.functional.conv2d with Gaussian kernels built in runtime is
177
+ 80% faster than transforms.functional.gaussian_blur for individual
178
+ items.
179
+
180
+ (in GPU); However, in CPU, the result is exactly opposite. But you
181
+ won't gonna run this on CPU, right?
182
+ """
183
+ if len(list(s for s in std.shape if s != 1)) >= 2:
184
+ raise NotImplementedError(
185
+ 'Anisotropic Gaussian filter is not currently available.')
186
+
187
+ # k.shape == (B, W, W).
188
+ k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
189
+ if k.shape[0] == 1:
190
+ return filter_by_kernel(x, k[0], False)
191
+ else:
192
+ return filter_by_kernel(x, k, True)
193
+ else:
194
+ # Gaussian kernel size is odd in order to preserve the center.
195
+ window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
196
+ return TF.gaussian_blur(x, window_size, std)
197
+
198
+
199
+ def blend(
200
+ fg: Union[torch.Tensor, Image.Image],
201
+ bg: Union[torch.Tensor, Image.Image],
202
+ mask: Union[torch.Tensor, Image.Image],
203
+ std: float = 0.0,
204
+ ) -> Image.Image:
205
+ if not isinstance(fg, torch.Tensor):
206
+ fg = T.ToTensor()(fg)
207
+ if not isinstance(bg, torch.Tensor):
208
+ bg = T.ToTensor()(bg)
209
+ if not isinstance(mask, torch.Tensor):
210
+ mask = (T.ToTensor()(mask) < 0.5).float()[:1]
211
+ if std > 0:
212
+ mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
213
+ return T.ToPILImage()(fg * mask + bg * (1 - mask))
214
+
215
+
216
+ def get_panorama_views(
217
+ panorama_height: int,
218
+ panorama_width: int,
219
+ window_size: int = 64,
220
+ ) -> tuple[List[Tuple[int]], torch.Tensor]:
221
+ stride = window_size // 2
222
+ is_horizontal = panorama_width > panorama_height
223
+ num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
224
+ num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
225
+ total_num_blocks = num_blocks_height * num_blocks_width
226
+
227
+ half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
228
+ half_rev = half_fwd.flip(0)
229
+ if window_size % 2 == 1:
230
+ half_rev = half_rev[1:]
231
+ c = torch.cat((half_fwd, half_rev))
232
+ one = torch.ones_like(c)
233
+ f = c.clone()
234
+ f[:window_size // 2] = 1
235
+ b = c.clone()
236
+ b[-(window_size // 2):] = 1
237
+
238
+ h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
239
+ w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
240
+
241
+ views = []
242
+ masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
243
+ for i in range(total_num_blocks):
244
+ hi, wi = i // num_blocks_width, i % num_blocks_width
245
+ h_start = hi * stride
246
+ h_end = min(h_start + window_size, panorama_height)
247
+ w_start = wi * stride
248
+ w_end = min(w_start + window_size, panorama_width)
249
+ views.append((h_start, h_end, w_start, w_end))
250
+
251
+ h_width = h_end - h_start
252
+ w_width = w_end - w_start
253
+ masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
254
+
255
+ # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
256
+ return views, masks[None] # (1, n, h, w)
257
+
258
+
259
+ def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
260
+ h, w = mask.shape[-2:]
261
+ device = mask.device
262
+ mask = mask.reshape(-1, h, w)
263
+ # assert mask.shape[0] == im.shape[0]
264
+ h_occupied = mask.sum(dim=-2) > 0
265
+ w_occupied = mask.sum(dim=-1) > 0
266
+ l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
267
+ r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
268
+ t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
269
+ b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
270
+ tb = (t + b + 1) // 2
271
+ lr = (l + r + 1) // 2
272
+ shifts = (tb - (h // 2), lr - (w // 2))
273
+ shifts = torch.cat(shifts, dim=1) # (p, 2)
274
+ if reverse:
275
+ shifts = shifts * -1
276
+ return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
277
+
278
+
279
+ class Streamer:
280
+ def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
281
+ self.fn = fn
282
+ self.ema_alpha = ema_alpha
283
+
284
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
285
+ self.future = self.executor.submit(fn)
286
+ self.image = None
287
+
288
+ self.prev_exec_time = 0
289
+ self.ema_exec_time = 0
290
+
291
+ @property
292
+ def throughput(self) -> float:
293
+ return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
294
+
295
+ def timed_fn(self) -> Any:
296
+ start = time.time()
297
+ res = self.fn()
298
+ end = time.time()
299
+ self.prev_exec_time = end - start
300
+ self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
301
+ return res
302
+
303
+ def __call__(self) -> Any:
304
+ if self.future.done() or self.image is None:
305
+ # get the result (the new image) and start a new task
306
+ image = self.future.result()
307
+ self.future = self.executor.submit(self.timed_fn)
308
+ self.image = image
309
+ return image
310
+ else:
311
+ # if self.fn() is not ready yet, use the previous image
312
+ # NOTE: This assumes that we have access to a previously generated image here.
313
+ # If there's no previous image (i.e., this is the first invocation), you could fall
314
+ # back to some default image or handle it differently based on your requirements.
315
+ return self.image