thejagstudio committed
Commit 510ee71
1 Parent(s): d9087f2

Upload 61 files

This view is limited to 50 of the 61 changed files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50)
  1. __init__.py +0 -0
  2. app.py +513 -0
  3. app_settings.py +94 -0
  4. backend/__init__.py +0 -0
  5. backend/annotators/canny_control.py +15 -0
  6. backend/annotators/control_interface.py +12 -0
  7. backend/annotators/depth_control.py +15 -0
  8. backend/annotators/image_control_factory.py +31 -0
  9. backend/annotators/lineart_control.py +11 -0
  10. backend/annotators/mlsd_control.py +10 -0
  11. backend/annotators/normal_control.py +10 -0
  12. backend/annotators/pose_control.py +10 -0
  13. backend/annotators/shuffle_control.py +10 -0
  14. backend/annotators/softedge_control.py +10 -0
  15. backend/api/models/response.py +16 -0
  16. backend/api/web.py +103 -0
  17. backend/base64_image.py +21 -0
  18. backend/controlnet.py +90 -0
  19. backend/device.py +23 -0
  20. backend/image_saver.py +60 -0
  21. backend/lcm_text_to_image.py +386 -0
  22. backend/lora.py +136 -0
  23. backend/models/device.py +9 -0
  24. backend/models/gen_images.py +16 -0
  25. backend/models/lcmdiffusion_setting.py +64 -0
  26. backend/models/upscale.py +9 -0
  27. backend/openvino/custom_ov_model_vae_decoder.py +21 -0
  28. backend/openvino/pipelines.py +75 -0
  29. backend/pipelines/lcm.py +100 -0
  30. backend/pipelines/lcm_lora.py +82 -0
  31. backend/tiny_decoder.py +32 -0
  32. backend/upscale/aura_sr.py +834 -0
  33. backend/upscale/aura_sr_upscale.py +9 -0
  34. backend/upscale/edsr_upscale_onnx.py +37 -0
  35. backend/upscale/tiled_upscale.py +238 -0
  36. backend/upscale/upscaler.py +52 -0
  37. constants.py +20 -0
  38. context.py +77 -0
  39. frontend/cli_interactive.py +655 -0
  40. frontend/gui/app_window.py +612 -0
  41. frontend/gui/image_generator_worker.py +37 -0
  42. frontend/gui/ui.py +15 -0
  43. frontend/utils.py +83 -0
  44. frontend/webui/controlnet_ui.py +194 -0
  45. frontend/webui/css/style.css +22 -0
  46. frontend/webui/generation_settings_ui.py +157 -0
  47. frontend/webui/image_to_image_ui.py +120 -0
  48. frontend/webui/image_variations_ui.py +106 -0
  49. frontend/webui/lora_models_ui.py +185 -0
  50. frontend/webui/models_ui.py +85 -0
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,513 @@
import json
from argparse import ArgumentParser

import constants
from backend.controlnet import controlnet_settings_from_dict
from backend.models.gen_images import ImageFormat
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.upscale.tiled_upscale import generate_upscaled_image
from constants import APP_VERSION, DEVICE
from frontend.webui.image_variations_ui import generate_image_variations
from models.interface_types import InterfaceType
from paths import FastStableDiffusionPaths
from PIL import Image
from state import get_context, get_settings
from utils import show_system_info
from backend.device import get_device_name

parser = ArgumentParser(description=f"FAST SD CPU {constants.APP_VERSION}")
parser.add_argument(
    "-s",
    "--share",
    action="store_true",
    help="Create sharable link(Web UI)",
    required=False,
)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
    "-g",
    "--gui",
    action="store_true",
    help="Start desktop GUI",
)
group.add_argument(
    "-w",
    "--webui",
    action="store_true",
    help="Start Web UI",
)
group.add_argument(
    "-a",
    "--api",
    action="store_true",
    help="Start Web API server",
)
group.add_argument(
    "-r",
    "--realtime",
    action="store_true",
    help="Start realtime inference UI(experimental)",
)
group.add_argument(
    "-v",
    "--version",
    action="store_true",
    help="Version",
)

parser.add_argument(
    "-b",
    "--benchmark",
    action="store_true",
    help="Run inference benchmark on the selected device",
)
parser.add_argument(
    "--lcm_model_id",
    type=str,
    help="Model ID or path,Default stabilityai/sd-turbo",
    default="stabilityai/sd-turbo",
)
parser.add_argument(
    "--openvino_lcm_model_id",
    type=str,
    help="OpenVINO Model ID or path,Default rupeshs/sd-turbo-openvino",
    default="rupeshs/sd-turbo-openvino",
)
parser.add_argument(
    "--prompt",
    type=str,
    help="Describe the image you want to generate",
    default="",
)
parser.add_argument(
    "--negative_prompt",
    type=str,
    help="Describe what you want to exclude from the generation",
    default="",
)
parser.add_argument(
    "--image_height",
    type=int,
    help="Height of the image",
    default=512,
)
parser.add_argument(
    "--image_width",
    type=int,
    help="Width of the image",
    default=512,
)
parser.add_argument(
    "--inference_steps",
    type=int,
    help="Number of steps,default : 1",
    default=1,
)
parser.add_argument(
    "--guidance_scale",
    type=float,
    help="Guidance scale,default : 1.0",
    default=1.0,
)

parser.add_argument(
    "--number_of_images",
    type=int,
    help="Number of images to generate ,default : 1",
    default=1,
)
parser.add_argument(
    "--seed",
    type=int,
    help="Seed,default : -1 (disabled) ",
    default=-1,
)
parser.add_argument(
    "--use_openvino",
    action="store_true",
    help="Use OpenVINO model",
)

parser.add_argument(
    "--use_offline_model",
    action="store_true",
    help="Use offline model",
)
parser.add_argument(
    "--use_safety_checker",
    action="store_true",
    help="Use safety checker",
)
parser.add_argument(
    "--use_lcm_lora",
    action="store_true",
    help="Use LCM-LoRA",
)
parser.add_argument(
    "--base_model_id",
    type=str,
    help="LCM LoRA base model ID,Default Lykon/dreamshaper-8",
    default="Lykon/dreamshaper-8",
)
parser.add_argument(
    "--lcm_lora_id",
    type=str,
    help="LCM LoRA model ID,Default latent-consistency/lcm-lora-sdv1-5",
    default="latent-consistency/lcm-lora-sdv1-5",
)
parser.add_argument(
    "-i",
    "--interactive",
    action="store_true",
    help="Interactive CLI mode",
)
parser.add_argument(
    "-t",
    "--use_tiny_auto_encoder",
    action="store_true",
    help="Use tiny auto encoder for SD (TAESD)",
)
parser.add_argument(
    "-f",
    "--file",
    type=str,
    help="Input image for img2img mode",
    default="",
)
parser.add_argument(
    "--img2img",
    action="store_true",
    help="img2img mode; requires input file via -f argument",
)
parser.add_argument(
    "--batch_count",
    type=int,
    help="Number of sequential generations",
    default=1,
)
parser.add_argument(
    "--strength",
    type=float,
    help="Denoising strength for img2img and Image variations",
    default=0.3,
)
parser.add_argument(
    "--sdupscale",
    action="store_true",
    help="Tiled SD upscale,works only for the resolution 512x512,(2x upscale)",
)
parser.add_argument(
    "--upscale",
    action="store_true",
    help="EDSR SD upscale ",
)
parser.add_argument(
    "--custom_settings",
    type=str,
    help="JSON file containing custom generation settings",
    default=None,
)
parser.add_argument(
    "--usejpeg",
    action="store_true",
    help="Images will be saved as JPEG format",
)
parser.add_argument(
    "--noimagesave",
    action="store_true",
    help="Disable image saving",
)
parser.add_argument(
    "--lora",
    type=str,
    help="LoRA model full path e.g D:\lora_models\CuteCartoon15V-LiberteRedmodModel-Cartoon-CuteCartoonAF.safetensors",
    default=None,
)
parser.add_argument(
    "--lora_weight",
    type=float,
    help="LoRA adapter weight [0 to 1.0]",
    default=0.5,
)

args = parser.parse_args()

if args.version:
    print(APP_VERSION)
    exit()

# parser.print_help()
show_system_info()
print(f"Using device : {constants.DEVICE}")

if args.webui:
    app_settings = get_settings()
else:
    app_settings = get_settings()

print(f"Found {len(app_settings.lcm_models)} LCM models in config/lcm-models.txt")
print(
    f"Found {len(app_settings.stable_diffsuion_models)} stable diffusion models in config/stable-diffusion-models.txt"
)
print(
    f"Found {len(app_settings.lcm_lora_models)} LCM-LoRA models in config/lcm-lora-models.txt"
)
print(
    f"Found {len(app_settings.openvino_lcm_models)} OpenVINO LCM models in config/openvino-lcm-models.txt"
)

if args.noimagesave:
    app_settings.settings.generated_images.save_image = False
else:
    app_settings.settings.generated_images.save_image = True

if not args.realtime:
    # To minimize realtime mode dependencies
    from backend.upscale.upscaler import upscale_image
    from frontend.cli_interactive import interactive_mode

if args.gui:
    from frontend.gui.ui import start_gui

    print("Starting desktop GUI mode(Qt)")
    start_gui(
        [],
        app_settings,
    )
elif args.webui:
    from frontend.webui.ui import start_webui

    print("Starting web UI mode")
    start_webui(
        args.share,
    )
elif args.realtime:
    from frontend.webui.realtime_ui import start_realtime_text_to_image

    print("Starting realtime text to image(EXPERIMENTAL)")
    start_realtime_text_to_image(args.share)
elif args.api:
    from backend.api.web import start_web_server

    start_web_server()

else:
    context = get_context(InterfaceType.CLI)
    config = app_settings.settings

    if args.use_openvino:
        config.lcm_diffusion_setting.openvino_lcm_model_id = args.openvino_lcm_model_id
    else:
        config.lcm_diffusion_setting.lcm_model_id = args.lcm_model_id

    config.lcm_diffusion_setting.prompt = args.prompt
    config.lcm_diffusion_setting.negative_prompt = args.negative_prompt
    config.lcm_diffusion_setting.image_height = args.image_height
    config.lcm_diffusion_setting.image_width = args.image_width
    config.lcm_diffusion_setting.guidance_scale = args.guidance_scale
    config.lcm_diffusion_setting.number_of_images = args.number_of_images
    config.lcm_diffusion_setting.inference_steps = args.inference_steps
    config.lcm_diffusion_setting.strength = args.strength
    config.lcm_diffusion_setting.seed = args.seed
    config.lcm_diffusion_setting.use_openvino = args.use_openvino
    config.lcm_diffusion_setting.use_tiny_auto_encoder = args.use_tiny_auto_encoder
    config.lcm_diffusion_setting.use_lcm_lora = args.use_lcm_lora
    config.lcm_diffusion_setting.lcm_lora.base_model_id = args.base_model_id
    config.lcm_diffusion_setting.lcm_lora.lcm_lora_id = args.lcm_lora_id
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
    config.lcm_diffusion_setting.lora.enabled = False
    config.lcm_diffusion_setting.lora.path = args.lora
    config.lcm_diffusion_setting.lora.weight = args.lora_weight
    config.lcm_diffusion_setting.lora.fuse = True
    if config.lcm_diffusion_setting.lora.path:
        config.lcm_diffusion_setting.lora.enabled = True
    if args.usejpeg:
        config.generated_images.format = ImageFormat.JPEG.value.upper()
    if args.seed > -1:
        config.lcm_diffusion_setting.use_seed = True
    else:
        config.lcm_diffusion_setting.use_seed = False
    config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
    config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker

    # Read custom settings from JSON file
    custom_settings = {}
    if args.custom_settings:
        with open(args.custom_settings) as f:
            custom_settings = json.load(f)

    # Basic ControlNet settings; if ControlNet is enabled, an image is
    # required even in txt2img mode
    config.lcm_diffusion_setting.controlnet = None
    controlnet_settings_from_dict(
        config.lcm_diffusion_setting,
        custom_settings,
    )

    # Interactive mode
    if args.interactive:
        # wrapper(interactive_mode, config, context)
        config.lcm_diffusion_setting.lora.fuse = False
        interactive_mode(config, context)

    # Start of non-interactive CLI image generation
    if args.img2img and args.file != "":
        config.lcm_diffusion_setting.init_image = Image.open(args.file)
        config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    elif args.img2img and args.file == "":
        print("Error : You need to specify a file in img2img mode")
        exit()
    elif args.upscale and args.file == "" and args.custom_settings == None:
        print("Error : You need to specify a file in SD upscale mode")
        exit()
    elif (
        args.prompt == ""
        and args.file == ""
        and args.custom_settings == None
        and not args.benchmark
    ):
        print("Error : You need to provide a prompt")
        exit()

    if args.upscale:
        # image = Image.open(args.file)
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            args.file,
            2,
            config.generated_images.format,
        )
        result = upscale_image(
            context,
            args.file,
            output_path,
            2,
        )
    # Perform Tiled SD upscale (EXPERIMENTAL)
    elif args.sdupscale:
        if args.use_openvino:
            config.lcm_diffusion_setting.strength = 0.3
        upscale_settings = None
        if custom_settings != {}:
            upscale_settings = custom_settings
        filepath = args.file
        output_format = config.generated_images.format
        if upscale_settings:
            filepath = upscale_settings["source_file"]
            output_format = upscale_settings["output_format"].upper()
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            filepath,
            2,
            output_format,
        )

        generate_upscaled_image(
            config,
            filepath,
            config.lcm_diffusion_setting.strength,
            upscale_settings=upscale_settings,
            context=context,
            tile_overlap=32 if config.lcm_diffusion_setting.use_openvino else 16,
            output_path=output_path,
            image_format=output_format,
        )
        exit()
    # If img2img argument is set and prompt is empty, use image variations mode
    elif args.img2img and args.prompt == "":
        for i in range(0, args.batch_count):
            generate_image_variations(
                config.lcm_diffusion_setting.init_image, args.strength
            )
    else:

        if args.benchmark:
            print("Initializing benchmark...")
            bench_lcm_setting = config.lcm_diffusion_setting
            bench_lcm_setting.prompt = "a cat"
            bench_lcm_setting.use_tiny_auto_encoder = False
            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            latencies = []

            print("Starting benchmark please wait...")
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency = sum(latencies) / 3

            bench_lcm_setting.use_tiny_auto_encoder = True

            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            latencies = []
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency_taesd = sum(latencies) / 3

            benchmark_name = ""

            if config.lcm_diffusion_setting.use_openvino:
                benchmark_name = "OpenVINO"
            else:
                benchmark_name = "PyTorch"

            bench_model_id = ""
            if bench_lcm_setting.use_openvino:
                bench_model_id = bench_lcm_setting.openvino_lcm_model_id
            elif bench_lcm_setting.use_lcm_lora:
                bench_model_id = bench_lcm_setting.lcm_lora.base_model_id
            else:
                bench_model_id = bench_lcm_setting.lcm_model_id

            benchmark_result = [
                ["Device", f"{DEVICE.upper()},{get_device_name()}"],
                ["Stable Diffusion Model", bench_model_id],
                [
                    "Image Size ",
                    f"{bench_lcm_setting.image_width}x{bench_lcm_setting.image_height}",
                ],
                [
                    "Inference Steps",
                    f"{bench_lcm_setting.inference_steps}",
                ],
                [
                    "Benchmark Passes",
                    3,
                ],
                [
                    "Average Latency",
                    f"{round(avg_latency,3)} sec",
                ],
                [
                    "Average Latency(TAESD* enabled)",
                    f"{round(avg_latency_taesd,3)} sec",
                ],
            ]
            print()
            print(
                f" FastSD Benchmark - {benchmark_name:8} "
            )
            print(f"-" * 80)
            for benchmark in benchmark_result:
                print(f"{benchmark[0]:35} - {benchmark[1]}")
            print(f"-" * 80)
            print("*TAESD - Tiny AutoEncoder for Stable Diffusion")

        else:
            for i in range(0, args.batch_count):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
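For reference, a typical non-interactive invocation of this CLI would look roughly like the line below; the exact entry point and environment setup are assumptions based on the arguments defined above.

python app.py --use_openvino --prompt "a cute cat" --image_width 512 --image_height 512 --inference_steps 1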
app_settings.py ADDED
@@ -0,0 +1,94 @@
import yaml
from os import path, makedirs
from models.settings import Settings
from paths import FastStableDiffusionPaths
from utils import get_models_from_text_file
from constants import (
    OPENVINO_LCM_MODELS_FILE,
    LCM_LORA_MODELS_FILE,
    SD_MODELS_FILE,
    LCM_MODELS_FILE,
)
from copy import deepcopy


class AppSettings:
    def __init__(self):
        self.config_path = FastStableDiffusionPaths().get_app_settings_path()
        self._stable_diffsuion_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(SD_MODELS_FILE)
        )
        self._lcm_lora_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_LORA_MODELS_FILE)
        )
        self._openvino_lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(OPENVINO_LCM_MODELS_FILE)
        )
        self._lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_MODELS_FILE)
        )
        self._config = None

    @property
    def settings(self):
        return self._config

    @property
    def stable_diffsuion_models(self):
        return self._stable_diffsuion_models

    @property
    def openvino_lcm_models(self):
        return self._openvino_lcm_models

    @property
    def lcm_models(self):
        return self._lcm_models

    @property
    def lcm_lora_models(self):
        return self._lcm_lora_models

    def load(self, skip_file=False):
        if skip_file:
            print("Skipping config file")
            settings_dict = self._load_default()
            self._config = Settings.model_validate(settings_dict)
        else:
            if not path.exists(self.config_path):
                base_dir = path.dirname(self.config_path)
                if not path.exists(base_dir):
                    makedirs(base_dir)
                try:
                    print("Settings not found creating default settings")
                    with open(self.config_path, "w") as file:
                        yaml.dump(
                            self._load_default(),
                            file,
                        )
                except Exception as ex:
                    print(f"Error in creating settings : {ex}")
                    exit()
            try:
                with open(self.config_path) as file:
                    settings_dict = yaml.safe_load(file)
                    self._config = Settings.model_validate(settings_dict)
            except Exception as ex:
                print(f"Error in loading settings : {ex}")

    def save(self):
        try:
            with open(self.config_path, "w") as file:
                tmp_cfg = deepcopy(self._config)
                tmp_cfg.lcm_diffusion_setting.init_image = None
                configurations = tmp_cfg.model_dump(
                    exclude=["init_image"],
                )
                if configurations:
                    yaml.dump(configurations, file)
        except Exception as ex:
            print(f"Error in saving settings : {ex}")

    def _load_default(self) -> dict:
        default_config = Settings()
        return default_config.model_dump()
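A brief usage sketch of AppSettings; in the application this object is normally obtained through state.get_settings(), so the direct construction here is only illustrative.

from app_settings import AppSettings

settings = AppSettings()  # reads the model lists from the config/*.txt files
settings.load()           # creates default YAML settings on first run, then loads them
print(settings.settings.lcm_diffusion_setting.lcm_model_id)
settings.save()           # persists the current settings back to the YAML file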
backend/__init__.py ADDED
File without changes
backend/annotators/canny_control.py ADDED
@@ -0,0 +1,15 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from cv2 import Canny
from PIL import Image


class CannyControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        low_threshold = 100
        high_threshold = 200
        image = np.array(image)
        image = Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        return Image.fromarray(image)
backend/annotators/control_interface.py ADDED
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod

from PIL import Image


class ControlInterface(ABC):
    @abstractmethod
    def get_control_image(
        self,
        image: Image,
    ) -> Image:
        pass
backend/annotators/depth_control.py ADDED
@@ -0,0 +1,15 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from PIL import Image
from transformers import pipeline


class DepthControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        depth_estimator = pipeline("depth-estimation")
        image = depth_estimator(image)["depth"]
        image = np.array(image)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)
        return image
backend/annotators/image_control_factory.py ADDED
@@ -0,0 +1,31 @@
from backend.annotators.canny_control import CannyControl
from backend.annotators.depth_control import DepthControl
from backend.annotators.lineart_control import LineArtControl
from backend.annotators.mlsd_control import MlsdControl
from backend.annotators.normal_control import NormalControl
from backend.annotators.pose_control import PoseControl
from backend.annotators.shuffle_control import ShuffleControl
from backend.annotators.softedge_control import SoftEdgeControl


class ImageControlFactory:
    def create_control(self, controlnet_type: str):
        if controlnet_type == "Canny":
            return CannyControl()
        elif controlnet_type == "Pose":
            return PoseControl()
        elif controlnet_type == "MLSD":
            return MlsdControl()
        elif controlnet_type == "Depth":
            return DepthControl()
        elif controlnet_type == "LineArt":
            return LineArtControl()
        elif controlnet_type == "Shuffle":
            return ShuffleControl()
        elif controlnet_type == "NormalBAE":
            return NormalControl()
        elif controlnet_type == "SoftEdge":
            return SoftEdgeControl()
        else:
            print("Error: Control type not implemented!")
            raise Exception("Error: Control type not implemented!")
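A short sketch of how the factory and the annotator classes above fit together; the input and output file names are placeholders.

from PIL import Image
from backend.annotators.image_control_factory import ImageControlFactory

control = ImageControlFactory().create_control("Canny")  # any of the supported names above
control_image = control.get_control_image(Image.open("input.png"))  # placeholder path
control_image.save("canny_control.png")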
backend/annotators/lineart_control.py ADDED
@@ -0,0 +1,11 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import LineartDetector
from PIL import Image


class LineArtControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/annotators/mlsd_control.py ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import MLSDdetector
from PIL import Image


class MlsdControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
        image = mlsd(image)
        return image
backend/annotators/normal_control.py ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import NormalBaeDetector
from PIL import Image


class NormalControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/annotators/pose_control.py ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import OpenposeDetector
from PIL import Image


class PoseControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = openpose(image)
        return image
backend/annotators/shuffle_control.py ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import ContentShuffleDetector
from PIL import Image


class ShuffleControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        shuffle_processor = ContentShuffleDetector()
        image = shuffle_processor(image)
        return image
backend/annotators/softedge_control.py ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import PidiNetDetector
from PIL import Image


class SoftEdgeControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/api/models/response.py ADDED
@@ -0,0 +1,16 @@
from typing import List

from pydantic import BaseModel


class StableDiffusionResponse(BaseModel):
    """
    Stable diffusion response model

    Attributes:
        images (List[str]): List of JPEG image as base64 encoded
        latency (float): Latency in seconds
    """

    images: List[str]
    latency: float
backend/api/web.py ADDED
@@ -0,0 +1,103 @@
import platform

import uvicorn
from backend.api.models.response import StableDiffusionResponse
from backend.models.device import DeviceInfo
from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str
from backend.device import get_device_name
from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting
from constants import APP_VERSION, DEVICE
from context import Context
from fastapi import FastAPI
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app = FastAPI(
    title="FastSD CPU",
    description="Fast stable diffusion on CPU",
    version=APP_VERSION,
    license_info={
        "name": "MIT",
        "identifier": "MIT",
    },
    docs_url="/api/docs",
    redoc_url="/api/redoc",
    openapi_url="/api/openapi.json",
)
print(app_settings.settings.lcm_diffusion_setting)

context = Context(InterfaceType.API_SERVER)


@app.get("/api/")
async def root():
    return {"message": "Welcome to FastSD CPU API"}


@app.get(
    "/api/info",
    description="Get system information",
    summary="Get system information",
)
async def info():
    device_info = DeviceInfo(
        device_type=DEVICE,
        device_name=get_device_name(),
        os=platform.system(),
        platform=platform.platform(),
        processor=platform.processor(),
    )
    return device_info.model_dump()


@app.get(
    "/api/config",
    description="Get current configuration",
    summary="Get configurations",
)
async def config():
    return app_settings.settings


@app.get(
    "/api/models",
    description="Get available models",
    summary="Get available models",
)
async def models():
    return {
        "lcm_lora_models": app_settings.lcm_lora_models,
        "stable_diffusion": app_settings.stable_diffsuion_models,
        "openvino_models": app_settings.openvino_lcm_models,
        "lcm_models": app_settings.lcm_models,
    }


@app.post(
    "/api/generate",
    description="Generate image(Text to image,Image to Image)",
    summary="Generate image(Text to image,Image to Image)",
)
async def generate(diffusion_config: LCMDiffusionSetting) -> StableDiffusionResponse:
    app_settings.settings.lcm_diffusion_setting = diffusion_config
    if diffusion_config.diffusion_task == DiffusionTask.image_to_image:
        app_settings.settings.lcm_diffusion_setting.init_image = base64_image_to_pil(
            diffusion_config.init_image
        )

    images = context.generate_text_to_image(app_settings.settings)

    images_base64 = [pil_image_to_base64_str(img) for img in images]
    return StableDiffusionResponse(
        latency=round(context.latency, 2),
        images=images_base64,
    )


def start_web_server():
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
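A hedged client-side example for the /api/generate endpoint above, using the requests library; the request body mirrors LCMDiffusionSetting (see backend/models/lcmdiffusion_setting.py), only a few fields are shown, and the host/port match the uvicorn defaults in start_web_server().

import base64
import requests

payload = {"prompt": "a cup of coffee", "image_width": 512, "image_height": 512, "inference_steps": 1}
resp = requests.post("http://localhost:8000/api/generate", json=payload).json()
with open("result.jpg", "wb") as f:
    f.write(base64.b64decode(resp["images"][0]))  # images are returned base64 encoded
print(f"Latency: {resp['latency']} s")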
backend/base64_image.py ADDED
@@ -0,0 +1,21 @@
from io import BytesIO
from base64 import b64encode, b64decode
from PIL import Image


def pil_image_to_base64_str(
    image: Image,
    format: str = "JPEG",
) -> str:
    buffer = BytesIO()
    image.save(buffer, format=format)
    buffer.seek(0)
    img_base64 = b64encode(buffer.getvalue()).decode("utf-8")
    return img_base64


def base64_image_to_pil(base64_str) -> Image:
    image_data = b64decode(base64_str)
    image_buffer = BytesIO(image_data)
    image = Image.open(image_buffer)
    return image
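A quick round-trip check of the two helpers above; the sample image path is a placeholder.

from PIL import Image
from backend.base64_image import pil_image_to_base64_str, base64_image_to_pil

encoded = pil_image_to_base64_str(Image.open("sample.jpg"))  # placeholder path
decoded = base64_image_to_pil(encoded)
print(decoded.size)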
backend/controlnet.py ADDED
@@ -0,0 +1,90 @@
import logging
from PIL import Image
from diffusers import ControlNetModel
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    ControlNetSetting,
)


# Prepares ControlNet adapters for use with FastSD CPU
#
# This function loads the ControlNet adapters defined by the
# _lcm_diffusion_setting.controlnet_ object and returns a dictionary
# with the pipeline arguments required to use the loaded adapters
def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    logging.info("Loading ControlNet adapter")
    controlnet_adapter = ControlNetModel.from_single_file(
        lcm_diffusion_setting.controlnet.adapter_path,
        local_files_only=True,
        use_safetensors=True,
    )
    controlnet_args["controlnet"] = controlnet_adapter
    return controlnet_args


# Updates the ControlNet pipeline arguments to use for image generation
#
# This function uses the contents of the _lcm_diffusion_setting.controlnet_
# object to generate a dictionary with the corresponding pipeline arguments
# to be used for image generation; in particular, it sets the ControlNet control
# image and conditioning scale
def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    controlnet_args["controlnet_conditioning_scale"] = (
        lcm_diffusion_setting.controlnet.conditioning_scale
    )
    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
        controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
    elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
        controlnet_args["control_image"] = (
            lcm_diffusion_setting.controlnet._control_image
        )
    return controlnet_args


# Helper function to adjust ControlNet settings from a dictionary
def controlnet_settings_from_dict(
    lcm_diffusion_setting,
    dictionary,
) -> None:
    if lcm_diffusion_setting is None or dictionary is None:
        logging.error("Invalid arguments!")
        return
    if (
        "controlnet" not in dictionary
        or dictionary["controlnet"] is None
        or len(dictionary["controlnet"]) == 0
    ):
        logging.warning("ControlNet settings not found, ControlNet will be disabled")
        lcm_diffusion_setting.controlnet = None
        return

    controlnet = ControlNetSetting()
    controlnet.enabled = dictionary["controlnet"][0]["enabled"]
    controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
    controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
    controlnet._control_image = None
    image_path = dictionary["controlnet"][0]["control_image"]
    if controlnet.enabled:
        try:
            controlnet._control_image = Image.open(image_path)
        except (AttributeError, FileNotFoundError) as err:
            print(err)
        if controlnet._control_image is None:
            logging.error("Wrong ControlNet control image! Disabling ControlNet")
            controlnet.enabled = False
    lcm_diffusion_setting.controlnet = controlnet
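For reference, a custom settings dictionary accepted by controlnet_settings_from_dict would look roughly like the sketch below; the keys follow the lookups in the function above, while the file paths are placeholders.

custom_settings = {
    "controlnet": [
        {
            "enabled": True,
            "conditioning_scale": 0.5,
            "adapter_path": "controlnet_models/control_canny.safetensors",  # placeholder path
            "control_image": "canny_control.png",  # placeholder path
        }
    ]
}
controlnet_settings_from_dict(config.lcm_diffusion_setting, custom_settings)  # config as in app.py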
backend/device.py ADDED
@@ -0,0 +1,23 @@
import platform
from constants import DEVICE
import torch
import openvino as ov

core = ov.Core()


def is_openvino_device() -> bool:
    if DEVICE.lower() == "cpu" or DEVICE.lower()[0] == "g" or DEVICE.lower()[0] == "n":
        return True
    else:
        return False


def get_device_name() -> str:
    if DEVICE == "cuda" or DEVICE == "mps":
        default_gpu_index = torch.cuda.current_device()
        return torch.cuda.get_device_name(default_gpu_index)
    elif platform.system().lower() == "darwin":
        return platform.processor()
    elif is_openvino_device():
        return core.get_property(DEVICE.upper(), "FULL_DEVICE_NAME")
backend/image_saver.py ADDED
@@ -0,0 +1,60 @@
import json
from os import path, mkdir
from typing import Any
from uuid import uuid4
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from utils import get_image_file_extension


def get_exclude_keys():
    exclude_keys = {
        "init_image": True,
        "generated_images": True,
        "lora": {
            "models_dir": True,
            "path": True,
        },
        "dirs": True,
        "controlnet": {
            "adapter_path": True,
        },
    }
    return exclude_keys


class ImageSaver:
    @staticmethod
    def save_images(
        output_path: str,
        images: Any,
        folder_name: str = "",
        format: str = "PNG",
        lcm_diffusion_setting: LCMDiffusionSetting = None,
    ) -> None:
        gen_id = uuid4()

        for index, image in enumerate(images):
            if not path.exists(output_path):
                mkdir(output_path)

            if folder_name:
                out_path = path.join(
                    output_path,
                    folder_name,
                )
            else:
                out_path = output_path

            if not path.exists(out_path):
                mkdir(out_path)
            image_extension = get_image_file_extension(format)
            image.save(path.join(out_path, f"{gen_id}-{index+1}{image_extension}"))
        if lcm_diffusion_setting:
            with open(path.join(out_path, f"{gen_id}.json"), "w") as json_file:
                json.dump(
                    lcm_diffusion_setting.model_dump(
                        exclude=get_exclude_keys(),
                    ),
                    json_file,
                    indent=4,
                )
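A minimal call of the saver above; the output directory is a placeholder, and result_images and config stand for the pipeline output and settings object used elsewhere in this commit.

from backend.image_saver import ImageSaver

ImageSaver.save_images(
    "results",             # placeholder output directory
    images=result_images,  # list of PIL images from a pipeline run (assumed)
    format="PNG",
    lcm_diffusion_setting=config.lcm_diffusion_setting,  # settings used for generation (assumed)
)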
backend/lcm_text_to_image.py ADDED
@@ -0,0 +1,386 @@
import gc
from math import ceil
from typing import Any

import numpy as np
import torch
import logging
from backend.device import is_openvino_device
from backend.lora import load_lora_weight
from backend.controlnet import (
    load_controlnet_adapters,
    update_controlnet_arguments,
)
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    LCMDiffusionSetting,
    LCMLora,
)
from backend.openvino.pipelines import (
    get_ov_image_to_image_pipeline,
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
)
from backend.pipelines.lcm import (
    get_image_to_image_pipeline,
    get_lcm_model_pipeline,
    load_taesd,
)
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
from constants import DEVICE
from diffusers import LCMScheduler
from image_ops import resize_pil_image


class LCMTextToImage:
    def __init__(
        self,
        device: str = "cpu",
    ) -> None:
        self.pipeline = None
        self.use_openvino = False
        self.device = ""
        self.previous_model_id = None
        self.previous_use_tae_sd = False
        self.previous_use_lcm_lora = False
        self.previous_ov_model_id = ""
        self.previous_safety_checker = False
        self.previous_use_openvino = False
        self.img_to_img_pipeline = None
        self.is_openvino_init = False
        self.previous_lora = None
        self.task_type = DiffusionTask.text_to_image
        self.torch_data_type = (
            torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
        )
        print(f"Torch datatype : {self.torch_data_type}")

    def _pipeline_to_device(self):
        print(f"Pipeline device : {DEVICE}")
        print(f"Pipeline dtype : {self.torch_data_type}")
        self.pipeline.to(
            torch_device=DEVICE,
            torch_dtype=self.torch_data_type,
        )

    def _add_freeu(self):
        pipeline_class = self.pipeline.__class__.__name__
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            if pipeline_class == "StableDiffusionPipeline":
                print("Add FreeU - SD")
                self.pipeline.enable_freeu(
                    s1=0.9,
                    s2=0.2,
                    b1=1.2,
                    b2=1.4,
                )
            elif pipeline_class == "StableDiffusionXLPipeline":
                print("Add FreeU - SDXL")
                self.pipeline.enable_freeu(
                    s1=0.6,
                    s2=0.4,
                    b1=1.1,
                    b2=1.2,
                )

    def _enable_vae_tiling(self):
        self.pipeline.vae.enable_tiling()

    def _update_lcm_scheduler_params(self):
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            self.pipeline.scheduler = LCMScheduler.from_config(
                self.pipeline.scheduler.config,
                beta_start=0.001,
                beta_end=0.01,
            )

    def init(
        self,
        device: str = "cpu",
        lcm_diffusion_setting: LCMDiffusionSetting = LCMDiffusionSetting(),
    ) -> None:
        self.device = device
        self.use_openvino = lcm_diffusion_setting.use_openvino
        model_id = lcm_diffusion_setting.lcm_model_id
        use_local_model = lcm_diffusion_setting.use_offline_model
        use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
        use_lora = lcm_diffusion_setting.use_lcm_lora
        lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
        ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id

        if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
            lcm_diffusion_setting.init_image = resize_pil_image(
                lcm_diffusion_setting.init_image,
                lcm_diffusion_setting.image_width,
                lcm_diffusion_setting.image_height,
            )

        if (
            self.pipeline is None
            or self.previous_model_id != model_id
            or self.previous_use_tae_sd != use_tiny_auto_encoder
            or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
            or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
            or self.previous_use_lcm_lora != use_lora
            or self.previous_ov_model_id != ov_model_id
            or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
            or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
            or (
                self.use_openvino
                and (
                    self.previous_task_type != lcm_diffusion_setting.diffusion_task
                    or self.previous_lora != lcm_diffusion_setting.lora
                )
            )
            or lcm_diffusion_setting.rebuild_pipeline
        ):
            if self.use_openvino and is_openvino_device():
                if self.pipeline:
                    del self.pipeline
                    self.pipeline = None
                    gc.collect()
                self.is_openvino_init = True
                if (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.text_to_image.value
                ):
                    print(f"***** Init Text to image (OpenVINO) - {ov_model_id} *****")
                    self.pipeline = get_ov_text_to_image_pipeline(
                        ov_model_id,
                        use_local_model,
                    )
                elif (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.image_to_image.value
                ):
                    print(f"***** Image to image (OpenVINO) - {ov_model_id} *****")
                    self.pipeline = get_ov_image_to_image_pipeline(
                        ov_model_id,
                        use_local_model,
                    )
            else:
                if self.pipeline:
                    del self.pipeline
                    self.pipeline = None
                if self.img_to_img_pipeline:
                    del self.img_to_img_pipeline
                    self.img_to_img_pipeline = None

                controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
                if use_lora:
                    print(
                        f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
                    )
                    self.pipeline = get_lcm_lora_pipeline(
                        lcm_lora.base_model_id,
                        lcm_lora.lcm_lora_id,
                        use_local_model,
                        torch_data_type=self.torch_data_type,
                        pipeline_args=controlnet_args,
                    )

                else:
                    print(f"***** Init LCM Model pipeline - {model_id} *****")
                    self.pipeline = get_lcm_model_pipeline(
                        model_id,
                        use_local_model,
                        controlnet_args,
                    )

                self.img_to_img_pipeline = get_image_to_image_pipeline(self.pipeline)

            if use_tiny_auto_encoder:
                if self.use_openvino and is_openvino_device():
                    print("Using Tiny Auto Encoder (OpenVINO)")
                    ov_load_taesd(
                        self.pipeline,
                        use_local_model,
                    )
                else:
                    print("Using Tiny Auto Encoder")
                    load_taesd(
                        self.pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )
                    load_taesd(
                        self.img_to_img_pipeline,
                        use_local_model,
                        self.torch_data_type,
                    )

            if not self.use_openvino and not is_openvino_device():
                self._pipeline_to_device()

            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and lcm_diffusion_setting.use_openvino
            ):
                self.pipeline.scheduler = LCMScheduler.from_config(
                    self.pipeline.scheduler.config,
                )
            else:
                self._update_lcm_scheduler_params()

            if use_lora:
                self._add_freeu()

            self.previous_model_id = model_id
            self.previous_ov_model_id = ov_model_id
            self.previous_use_tae_sd = use_tiny_auto_encoder
            self.previous_lcm_lora_base_id = lcm_lora.base_model_id
            self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
            self.previous_use_lcm_lora = use_lora
            self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
            self.previous_use_openvino = lcm_diffusion_setting.use_openvino
            self.previous_task_type = lcm_diffusion_setting.diffusion_task
            self.previous_lora = lcm_diffusion_setting.lora.model_copy(deep=True)
            lcm_diffusion_setting.rebuild_pipeline = False
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                print(f"Pipeline : {self.pipeline}")
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                if self.use_openvino and is_openvino_device():
                    print(f"Pipeline : {self.pipeline}")
                else:
                    print(f"Pipeline : {self.img_to_img_pipeline}")
            if self.use_openvino:
                if lcm_diffusion_setting.lora.enabled:
                    print("Warning: Lora models not supported on OpenVINO mode")
            else:
                adapters = self.pipeline.get_active_adapters()
                print(f"Active adapters : {adapters}")

    def _get_timesteps(self):
        time_steps = self.pipeline.scheduler.config.get("timesteps")
        time_steps_value = [int(time_steps)] if time_steps else None
        return time_steps_value

    def generate(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
        reshape: bool = False,
    ) -> Any:
        guidance_scale = lcm_diffusion_setting.guidance_scale
        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
        check_step_value = int(
            lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
        )
        if (
            lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
            and check_step_value < 1
        ):
            img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
            print(
                f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
            )

        if lcm_diffusion_setting.use_seed:
            cur_seed = lcm_diffusion_setting.seed
            if self.use_openvino:
                np.random.seed(cur_seed)
            else:
                torch.manual_seed(cur_seed)

        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
        if is_openvino_pipe:
            print("Using OpenVINO")
            if reshape and not self.is_openvino_init:
                print("Reshape and compile")
                self.pipeline.reshape(
                    batch_size=-1,
                    height=lcm_diffusion_setting.image_height,
                    width=lcm_diffusion_setting.image_width,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                )
                self.pipeline.compile()

            if self.is_openvino_init:
                self.is_openvino_init = False

        if not lcm_diffusion_setting.use_safety_checker:
            self.pipeline.safety_checker = None
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and not is_openvino_pipe
            ):
                self.img_to_img_pipeline.safety_checker = None

        if (
            not lcm_diffusion_setting.use_lcm_lora
            and not lcm_diffusion_setting.use_openvino
            and lcm_diffusion_setting.guidance_scale != 1.0
        ):
            print("Not using LCM-LoRA so setting guidance_scale 1.0")
            guidance_scale = 1.0

        controlnet_args = update_controlnet_arguments(lcm_diffusion_setting)
        if lcm_diffusion_setting.use_openvino:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps * 3,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images

        else:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    timesteps=self._get_timesteps(),
                    **controlnet_args,
                ).images

            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.img_to_img_pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    **controlnet_args,
                ).images
        return result_images
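A condensed sketch of how this class is typically driven; the setting values are illustrative, and in the application the wiring between Context and LCMTextToImage is handled by context.py rather than called directly like this.

from backend.lcm_text_to_image import LCMTextToImage
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

setting = LCMDiffusionSetting(prompt="a scenic mountain lake", inference_steps=1)
lcm = LCMTextToImage()
lcm.init(device="cpu", lcm_diffusion_setting=setting)  # builds or reuses the pipeline
images = lcm.generate(setting)
images[0].save("out.png")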
backend/lora.py ADDED
@@ -0,0 +1,136 @@
import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path


# A basic class to keep track of the currently loaded LoRAs and
# their weights; the diffusers function \c get_active_adapters()
# returns a list of adapter names but not their weights so we need
# a way to keep track of the current LoRA weights to set whenever
# a new LoRA is loaded
class _lora_info:
    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __del__(self):
        self.path = None
        self.adapter_name = None


_loaded_loras = []
_current_pipeline = None


# This function loads a LoRA from the LoRA path setting, so it's
# possible to load multiple LoRAs by calling this function more than
# once with a different LoRA path setting; note that if you plan to
# load multiple LoRAs and dynamically change their weights, you
# might want to set the LoRA fuse option to False
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    if not lcm_diffusion_setting.lora.path:
        raise Exception("Empty lora model path")

    if not path.exists(lcm_diffusion_setting.lora.path):
        raise Exception("Lora model path is invalid")

    # If the pipeline has been rebuilt since the last call, remove all
    # references to previously loaded LoRAs and store the new pipeline
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        for lora in _loaded_loras:
            del lora
        del _loaded_loras
        _loaded_loras = []
        _current_pipeline = pipeline

    current_lora = _lora_info(
        lcm_diffusion_setting.lora.path,
        lcm_diffusion_setting.lora.weight,
    )
    _loaded_loras.append(current_lora)

    if lcm_diffusion_setting.lora.enabled:
        print(f"LoRA adapter name : {current_lora.adapter_name}")
        pipeline.load_lora_weights(
            FastStableDiffusionPaths.get_lora_models_path(),
            weight_name=Path(lcm_diffusion_setting.lora.path).name,
            local_files_only=True,
            adapter_name=current_lora.adapter_name,
        )
        update_lora_weights(
            pipeline,
            lcm_diffusion_setting,
        )

        if lcm_diffusion_setting.lora.fuse:
            pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    lora_models = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
    lora_models_map = {}
    for file_path in lora_models:
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map


# This function returns a list of (adapter_name, weight) tuples for the
# currently loaded LoRAs
def get_active_lora_weights():
    active_loras = []
    for lora_info in _loaded_loras:
        active_loras.append(
            (
                lora_info.adapter_name,
                lora_info.weight,
            )
        )
    return active_loras


# This function receives a pipeline, an lcm_diffusion_setting object and
# an optional list of updated (adapter_name, weight) tuples
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    adapter_weights = zip(adapter_names, adapter_weights)
    print(f"Adapters: {list(adapter_weights)}")
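A short sketch of adjusting LoRA weights at runtime with the helpers above; it assumes a LoRA was already loaded with load_lora_weight and with the fuse option disabled, as the comments recommend, and that pipeline and config refer to the objects used elsewhere in this commit.

from backend.lora import get_active_lora_weights, update_lora_weights

active = get_active_lora_weights()            # [(adapter_name, weight), ...]
updated = [(name, 0.8) for name, _ in active]  # illustrative new weight
update_lora_weights(pipeline, config.lcm_diffusion_setting, updated)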
backend/models/device.py ADDED
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class DeviceInfo(BaseModel):
    device_type: str
    device_name: str
    os: str
    platform: str
    processor: str
backend/models/gen_images.py ADDED
@@ -0,0 +1,16 @@
from pydantic import BaseModel
from enum import Enum, auto
from paths import FastStableDiffusionPaths


class ImageFormat(str, Enum):
    """Image format"""

    JPEG = "jpeg"
    PNG = "png"


class GeneratedImages(BaseModel):
    path: str = FastStableDiffusionPaths.get_results_path()
    format: str = ImageFormat.PNG.value.upper()
    save_image: bool = True
backend/models/lcmdiffusion_setting.py ADDED
@@ -0,0 +1,64 @@
+ from enum import Enum
+ from PIL import Image
+ from typing import Any, Optional, Union
+
+ from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
+ from paths import FastStableDiffusionPaths
+ from pydantic import BaseModel
+
+
+ class LCMLora(BaseModel):
+     base_model_id: str = "Lykon/dreamshaper-8"
+     lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
+
+
+ class DiffusionTask(str, Enum):
+     """Diffusion task types"""
+
+     text_to_image = "text_to_image"
+     image_to_image = "image_to_image"
+
+
+ class Lora(BaseModel):
+     models_dir: str = FastStableDiffusionPaths.get_lora_models_path()
+     path: Optional[Any] = None
+     weight: Optional[float] = 0.5
+     fuse: bool = True
+     enabled: bool = False
+
+
+ class ControlNetSetting(BaseModel):
+     adapter_path: Optional[str] = None  # ControlNet adapter path
+     conditioning_scale: float = 0.5
+     enabled: bool = False
+     _control_image: Image = None  # Control image, PIL image
+
+
+ class LCMDiffusionSetting(BaseModel):
+     lcm_model_id: str = LCM_DEFAULT_MODEL
+     openvino_lcm_model_id: str = LCM_DEFAULT_MODEL_OPENVINO
+     use_offline_model: bool = False
+     use_lcm_lora: bool = False
+     lcm_lora: Optional[LCMLora] = LCMLora()
+     use_tiny_auto_encoder: bool = False
+     use_openvino: bool = False
+     prompt: str = ""
+     negative_prompt: str = ""
+     init_image: Any = None
+     strength: Optional[float] = 0.6
+     image_height: Optional[int] = 512
+     image_width: Optional[int] = 512
+     inference_steps: Optional[int] = 1
+     guidance_scale: Optional[float] = 1
+     number_of_images: Optional[int] = 1
+     seed: Optional[int] = 123123
+     use_seed: bool = False
+     use_safety_checker: bool = False
+     diffusion_task: str = DiffusionTask.text_to_image.value
+     lora: Optional[Lora] = Lora()
+     controlnet: Optional[Union[ControlNetSetting, list[ControlNetSetting]]] = None
+     dirs: dict = {
+         "controlnet": FastStableDiffusionPaths.get_controlnet_models_path(),
+         "lora": FastStableDiffusionPaths.get_lora_models_path(),
+     }
+     rebuild_pipeline: bool = False
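For reference, a settings object for a plain text-to-image run can be built directly; the prompt and size values below are examples only:

from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting

setting = LCMDiffusionSetting(
    prompt="a cute cat, 8k",                 # example values
    image_width=512,
    image_height=512,
    inference_steps=1,
    guidance_scale=1.0,
    diffusion_task=DiffusionTask.text_to_image.value,
)
print(setting.lcm_model_id)                  # defaults to LCM_DEFAULT_MODEL
print(setting.model_dump()["lora"])          # nested models serialize via pydantic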
backend/models/upscale.py ADDED
@@ -0,0 +1,9 @@
+ from enum import Enum
+
+
+ class UpscaleMode(str, Enum):
+     """Upscale modes"""
+
+     normal = "normal"
+     sd_upscale = "sd_upscale"
+     aura_sr = "aura_sr"
backend/openvino/custom_ov_model_vae_decoder.py ADDED
@@ -0,0 +1,21 @@
+ from backend.device import is_openvino_device
+
+ if is_openvino_device():
+     from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder
+
+
+     class CustomOVModelVaeDecoder(OVModelVaeDecoder):
+         def __init__(
+             self,
+             model,
+             parent_model,
+             ov_config=None,
+             model_dir=None,
+         ):
+             super(OVModelVaeDecoder, self).__init__(
+                 model,
+                 parent_model,
+                 ov_config,
+                 "vae_decoder",
+                 model_dir,
+             )
backend/openvino/pipelines.py ADDED
@@ -0,0 +1,75 @@
1
+ from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO
2
+ from backend.tiny_decoder import get_tiny_decoder_vae_model
3
+ from typing import Any
4
+ from backend.device import is_openvino_device
5
+ from paths import get_base_folder_name
6
+
7
+ if is_openvino_device():
8
+ from huggingface_hub import snapshot_download
9
+ from optimum.intel.openvino.modeling_diffusion import OVBaseModel
10
+
11
+ from optimum.intel.openvino.modeling_diffusion import (
12
+ OVStableDiffusionPipeline,
13
+ OVStableDiffusionImg2ImgPipeline,
14
+ OVStableDiffusionXLPipeline,
15
+ OVStableDiffusionXLImg2ImgPipeline,
16
+ )
17
+ from backend.openvino.custom_ov_model_vae_decoder import CustomOVModelVaeDecoder
18
+
19
+
20
+ def ov_load_taesd(
21
+ pipeline: Any,
22
+ use_local_model: bool = False,
23
+ ):
24
+ taesd_dir = snapshot_download(
25
+ repo_id=get_tiny_decoder_vae_model(pipeline.__class__.__name__),
26
+ local_files_only=use_local_model,
27
+ )
28
+ pipeline.vae_decoder = CustomOVModelVaeDecoder(
29
+ model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
30
+ parent_model=pipeline,
31
+ model_dir=taesd_dir,
32
+ )
33
+
34
+
35
+ def get_ov_text_to_image_pipeline(
36
+ model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
37
+ use_local_model: bool = False,
38
+ ) -> Any:
39
+ if "xl" in get_base_folder_name(model_id).lower():
40
+ pipeline = OVStableDiffusionXLPipeline.from_pretrained(
41
+ model_id,
42
+ local_files_only=use_local_model,
43
+ ov_config={"CACHE_DIR": ""},
44
+ device=DEVICE.upper(),
45
+ )
46
+ else:
47
+ pipeline = OVStableDiffusionPipeline.from_pretrained(
48
+ model_id,
49
+ local_files_only=use_local_model,
50
+ ov_config={"CACHE_DIR": ""},
51
+ device=DEVICE.upper(),
52
+ )
53
+
54
+ return pipeline
55
+
56
+
57
+ def get_ov_image_to_image_pipeline(
58
+ model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
59
+ use_local_model: bool = False,
60
+ ) -> Any:
61
+ if "xl" in get_base_folder_name(model_id).lower():
62
+ pipeline = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(
63
+ model_id,
64
+ local_files_only=use_local_model,
65
+ ov_config={"CACHE_DIR": ""},
66
+ device=DEVICE.upper(),
67
+ )
68
+ else:
69
+ pipeline = OVStableDiffusionImg2ImgPipeline.from_pretrained(
70
+ model_id,
71
+ local_files_only=use_local_model,
72
+ ov_config={"CACHE_DIR": ""},
73
+ device=DEVICE.upper(),
74
+ )
75
+ return pipeline
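A rough usage sketch for the OpenVINO helpers above; it assumes an OpenVINO-capable setup (is_openvino_device() is True) and downloads the default model on first run:

from backend.openvino.pipelines import (
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
)

pipeline = get_ov_text_to_image_pipeline()   # LCM_DEFAULT_MODEL_OPENVINO by default
ov_load_taesd(pipeline)                      # optional: swap in the tiny VAE decoder
result = pipeline(
    prompt="a scenic mountain lake",         # example prompt
    num_inference_steps=1,
    guidance_scale=1.0,
)
result.images[0].save("openvino_out.png")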
backend/pipelines/lcm.py ADDED
@@ -0,0 +1,100 @@
1
+ from constants import LCM_DEFAULT_MODEL
2
+ from diffusers import (
3
+ DiffusionPipeline,
4
+ AutoencoderTiny,
5
+ UNet2DConditionModel,
6
+ LCMScheduler,
7
+ )
8
+ import torch
9
+ from backend.tiny_decoder import get_tiny_decoder_vae_model
10
+ from typing import Any
11
+ from diffusers import (
12
+ LCMScheduler,
13
+ StableDiffusionImg2ImgPipeline,
14
+ StableDiffusionXLImg2ImgPipeline,
15
+ AutoPipelineForText2Image,
16
+ AutoPipelineForImage2Image,
17
+ StableDiffusionControlNetPipeline,
18
+ )
19
+
20
+
21
+ def _get_lcm_pipeline_from_base_model(
22
+ lcm_model_id: str,
23
+ base_model_id: str,
24
+ use_local_model: bool,
25
+ ):
26
+ pipeline = None
27
+ unet = UNet2DConditionModel.from_pretrained(
28
+ lcm_model_id,
29
+ torch_dtype=torch.float32,
30
+ local_files_only=use_local_model,
31
+ resume_download=True,
32
+ )
33
+ pipeline = DiffusionPipeline.from_pretrained(
34
+ base_model_id,
35
+ unet=unet,
36
+ torch_dtype=torch.float32,
37
+ local_files_only=use_local_model,
38
+ resume_download=True,
39
+ )
40
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
41
+ return pipeline
42
+
43
+
44
+ def load_taesd(
45
+ pipeline: Any,
46
+ use_local_model: bool = False,
47
+ torch_data_type: torch.dtype = torch.float32,
48
+ ):
49
+ vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
50
+ pipeline.vae = AutoencoderTiny.from_pretrained(
51
+ vae_model,
52
+ torch_dtype=torch_data_type,
53
+ local_files_only=use_local_model,
54
+ )
55
+
56
+
57
+ def get_lcm_model_pipeline(
58
+ model_id: str = LCM_DEFAULT_MODEL,
59
+ use_local_model: bool = False,
60
+ pipeline_args={},
61
+ ):
62
+ pipeline = None
63
+ if model_id == "latent-consistency/lcm-sdxl":
64
+ pipeline = _get_lcm_pipeline_from_base_model(
65
+ model_id,
66
+ "stabilityai/stable-diffusion-xl-base-1.0",
67
+ use_local_model,
68
+ )
69
+
70
+ elif model_id == "latent-consistency/lcm-ssd-1b":
71
+ pipeline = _get_lcm_pipeline_from_base_model(
72
+ model_id,
73
+ "segmind/SSD-1B",
74
+ use_local_model,
75
+ )
76
+ else:
77
+ # pipeline = DiffusionPipeline.from_pretrained(
78
+ pipeline = AutoPipelineForText2Image.from_pretrained(
79
+ model_id,
80
+ local_files_only=use_local_model,
81
+ **pipeline_args,
82
+ )
83
+
84
+ return pipeline
85
+
86
+
87
+ def get_image_to_image_pipeline(pipeline: Any) -> Any:
88
+ components = pipeline.components
89
+ pipeline_class = pipeline.__class__.__name__
90
+ if (
91
+ pipeline_class == "LatentConsistencyModelPipeline"
92
+ or pipeline_class == "StableDiffusionPipeline"
93
+ ):
94
+ return StableDiffusionImg2ImgPipeline(**components)
95
+ elif pipeline_class == "StableDiffusionControlNetPipeline":
96
+ return AutoPipelineForImage2Image.from_pipe(pipeline)
97
+ elif pipeline_class == "StableDiffusionXLPipeline":
98
+ return StableDiffusionXLImg2ImgPipeline(**components)
99
+ else:
100
+ raise Exception(f"Unknown pipeline {pipeline_class}")
backend/pipelines/lcm_lora.py ADDED
@@ -0,0 +1,82 @@
1
+ import pathlib
2
+ from os import path
3
+
4
+ import torch
5
+ from diffusers import (
6
+ AutoPipelineForText2Image,
7
+ LCMScheduler,
8
+ StableDiffusionPipeline,
9
+ )
10
+
11
+
12
+ def load_lcm_weights(
13
+ pipeline,
14
+ use_local_model,
15
+ lcm_lora_id,
16
+ ):
17
+ kwargs = {
18
+ "local_files_only": use_local_model,
19
+ "weight_name": "pytorch_lora_weights.safetensors",
20
+ }
21
+ pipeline.load_lora_weights(
22
+ lcm_lora_id,
23
+ **kwargs,
24
+ adapter_name="lcm",
25
+ )
26
+
27
+
28
+ def get_lcm_lora_pipeline(
29
+ base_model_id: str,
30
+ lcm_lora_id: str,
31
+ use_local_model: bool,
32
+ torch_data_type: torch.dtype,
33
+ pipeline_args={},
34
+ ):
35
+ if pathlib.Path(base_model_id).suffix == ".safetensors":
36
+ # SD 1.5 models only
37
+ # When loading a .safetensors model, the pipeline has to be created
38
+ # with StableDiffusionPipeline() since it's the only class that
39
+ # defines the method from_single_file(); afterwards a new pipeline
40
+ # is created using AutoPipelineForText2Image() for ControlNet
41
+ # support, in case ControlNet is enabled
42
+ if not path.exists(base_model_id):
43
+ raise FileNotFoundError(
44
+ f"Model file not found,Please check your model path: {base_model_id}"
45
+ )
46
+ print("Using single file Safetensors model (Supported models - SD 1.5 models)")
47
+
48
+ dummy_pipeline = StableDiffusionPipeline.from_single_file(
49
+ base_model_id,
50
+ torch_dtype=torch_data_type,
51
+ safety_checker=None,
52
+ load_safety_checker=False,
53
+ local_files_only=use_local_model,
54
+ use_safetensors=True,
55
+ )
56
+ pipeline = AutoPipelineForText2Image.from_pipe(
57
+ dummy_pipeline,
58
+ **pipeline_args,
59
+ )
60
+ del dummy_pipeline
61
+ else:
62
+ pipeline = AutoPipelineForText2Image.from_pretrained(
63
+ base_model_id,
64
+ torch_dtype=torch_data_type,
65
+ local_files_only=use_local_model,
66
+ **pipeline_args,
67
+ )
68
+
69
+ load_lcm_weights(
70
+ pipeline,
71
+ use_local_model,
72
+ lcm_lora_id,
73
+ )
74
+ # Always fuse LCM-LoRA
75
+ pipeline.fuse_lora()
76
+
77
+ if "lcm" in lcm_lora_id.lower() or "hypersd" in lcm_lora_id.lower():
78
+ print("LCM LoRA model detected so using recommended LCMScheduler")
79
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
80
+
81
+ # pipeline.unet.to(memory_format=torch.channels_last)
82
+ return pipeline
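For reference, building an LCM-LoRA pipeline from the default model IDs could look like this sketch; the scheduler swap and LoRA fusing happen inside get_lcm_lora_pipeline:

import torch

from backend.models.lcmdiffusion_setting import LCMLora
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline

lcm_lora = LCMLora()      # Lykon/dreamshaper-8 + latent-consistency/lcm-lora-sdv1-5
pipeline = get_lcm_lora_pipeline(
    lcm_lora.base_model_id,
    lcm_lora.lcm_lora_id,
    use_local_model=False,
    torch_data_type=torch.float32,
)
image = pipeline(
    prompt="a watercolor painting of a fox",   # example prompt
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_lora.png")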
backend/tiny_decoder.py ADDED
@@ -0,0 +1,32 @@
+ from constants import (
+     TAESD_MODEL,
+     TAESDXL_MODEL,
+     TAESD_MODEL_OPENVINO,
+     TAESDXL_MODEL_OPENVINO,
+ )
+
+
+ def get_tiny_decoder_vae_model(pipeline_class) -> str:
+     print(f"Pipeline class : {pipeline_class}")
+     if (
+         pipeline_class == "LatentConsistencyModelPipeline"
+         or pipeline_class == "StableDiffusionPipeline"
+         or pipeline_class == "StableDiffusionImg2ImgPipeline"
+         or pipeline_class == "StableDiffusionControlNetPipeline"
+         or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline"
+     ):
+         return TAESD_MODEL
+     elif (
+         pipeline_class == "StableDiffusionXLPipeline"
+         or pipeline_class == "StableDiffusionXLImg2ImgPipeline"
+     ):
+         return TAESDXL_MODEL
+     elif (
+         pipeline_class == "OVStableDiffusionPipeline"
+         or pipeline_class == "OVStableDiffusionImg2ImgPipeline"
+     ):
+         return TAESD_MODEL_OPENVINO
+     elif pipeline_class == "OVStableDiffusionXLPipeline":
+         return TAESDXL_MODEL_OPENVINO
+     else:
+         raise Exception("No valid pipeline class found!")
backend/upscale/aura_sr.py ADDED
@@ -0,0 +1,834 @@
1
+ # AuraSR: GAN-based Super-Resolution for real-world images, a reproduction of the GigaGAN* paper. Implementation is
2
+ # based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
3
+ #
4
+ # https://mingukkang.github.io/GigaGAN/
5
+ from math import log2, ceil
6
+ from functools import partial
7
+ from typing import Any, Optional, List, Iterable
8
+
9
+ import torch
10
+ from torchvision import transforms
11
+ from PIL import Image
12
+ from torch import nn, einsum, Tensor
13
+ import torch.nn.functional as F
14
+
15
+ from einops import rearrange, repeat, reduce
16
+ from einops.layers.torch import Rearrange
17
+
18
+
19
+ def get_same_padding(size, kernel, dilation, stride):
20
+ return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
21
+
22
+
23
+ class AdaptiveConv2DMod(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim,
27
+ dim_out,
28
+ kernel,
29
+ *,
30
+ demod=True,
31
+ stride=1,
32
+ dilation=1,
33
+ eps=1e-8,
34
+ num_conv_kernels=1, # set this to be greater than 1 for adaptive
35
+ ):
36
+ super().__init__()
37
+ self.eps = eps
38
+
39
+ self.dim_out = dim_out
40
+
41
+ self.kernel = kernel
42
+ self.stride = stride
43
+ self.dilation = dilation
44
+ self.adaptive = num_conv_kernels > 1
45
+
46
+ self.weights = nn.Parameter(
47
+ torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
48
+ )
49
+
50
+ self.demod = demod
51
+
52
+ nn.init.kaiming_normal_(
53
+ self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
54
+ )
55
+
56
+ def forward(
57
+ self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
58
+ ):
59
+ """
60
+ notation
61
+
62
+ b - batch
63
+ n - convs
64
+ o - output
65
+ i - input
66
+ k - kernel
67
+ """
68
+
69
+ b, h = fmap.shape[0], fmap.shape[-2]
70
+
71
+ # account for feature map that has been expanded by the scale in the first dimension
72
+ # due to multiscale inputs and outputs
73
+
74
+ if mod.shape[0] != b:
75
+ mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
76
+
77
+ if exists(kernel_mod):
78
+ kernel_mod_has_el = kernel_mod.numel() > 0
79
+
80
+ assert self.adaptive or not kernel_mod_has_el
81
+
82
+ if kernel_mod_has_el and kernel_mod.shape[0] != b:
83
+ kernel_mod = repeat(
84
+ kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
85
+ )
86
+
87
+ # prepare weights for modulation
88
+
89
+ weights = self.weights
90
+
91
+ if self.adaptive:
92
+ weights = repeat(weights, "... -> b ...", b=b)
93
+
94
+ # determine an adaptive weight and 'select' the kernel to use with softmax
95
+
96
+ assert exists(kernel_mod) and kernel_mod.numel() > 0
97
+
98
+ kernel_attn = kernel_mod.softmax(dim=-1)
99
+ kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")
100
+
101
+ weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
102
+
103
+ # do the modulation, demodulation, as done in stylegan2
104
+
105
+ mod = rearrange(mod, "b i -> b 1 i 1 1")
106
+
107
+ weights = weights * (mod + 1)
108
+
109
+ if self.demod:
110
+ inv_norm = (
111
+ reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
112
+ .clamp(min=self.eps)
113
+ .rsqrt()
114
+ )
115
+ weights = weights * inv_norm
116
+
117
+ fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
118
+
119
+ weights = rearrange(weights, "b o ... -> (b o) ...")
120
+
121
+ padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
122
+ fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
123
+
124
+ return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
125
+
126
+
127
+ class Attend(nn.Module):
128
+ def __init__(self, dropout=0.0, flash=False):
129
+ super().__init__()
130
+ self.dropout = dropout
131
+ self.attn_dropout = nn.Dropout(dropout)
132
+ self.scale = nn.Parameter(torch.randn(1))
133
+ self.flash = flash
134
+
135
+ def flash_attn(self, q, k, v):
136
+ q, k, v = map(lambda t: t.contiguous(), (q, k, v))
137
+ out = F.scaled_dot_product_attention(
138
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
139
+ )
140
+ return out
141
+
142
+ def forward(self, q, k, v):
143
+ if self.flash:
144
+ return self.flash_attn(q, k, v)
145
+
146
+ scale = q.shape[-1] ** -0.5
147
+
148
+ # similarity
149
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
150
+
151
+ # attention
152
+ attn = sim.softmax(dim=-1)
153
+ attn = self.attn_dropout(attn)
154
+
155
+ # aggregate values
156
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
157
+
158
+ return out
159
+
160
+
161
+ def exists(x):
162
+ return x is not None
163
+
164
+
165
+ def default(val, d):
166
+ if exists(val):
167
+ return val
168
+ return d() if callable(d) else d
169
+
170
+
171
+ def cast_tuple(t, length=1):
172
+ if isinstance(t, tuple):
173
+ return t
174
+ return (t,) * length
175
+
176
+
177
+ def identity(t, *args, **kwargs):
178
+ return t
179
+
180
+
181
+ def is_power_of_two(n):
182
+ return log2(n).is_integer()
183
+
184
+
185
+ def null_iterator():
186
+ while True:
187
+ yield None
188
+
189
+ def Downsample(dim, dim_out=None):
190
+ return nn.Sequential(
191
+ Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
192
+ nn.Conv2d(dim * 4, default(dim_out, dim), 1),
193
+ )
194
+
195
+
196
+ class RMSNorm(nn.Module):
197
+ def __init__(self, dim):
198
+ super().__init__()
199
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
200
+ self.eps = 1e-4
201
+
202
+ def forward(self, x):
203
+ return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
204
+
205
+
206
+ # building block modules
207
+
208
+
209
+ class Block(nn.Module):
210
+ def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
211
+ super().__init__()
212
+ self.proj = AdaptiveConv2DMod(
213
+ dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
214
+ )
215
+ self.kernel = 3
216
+ self.dilation = 1
217
+ self.stride = 1
218
+
219
+ self.act = nn.SiLU()
220
+
221
+ def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
222
+ conv_mods_iter = default(conv_mods_iter, null_iterator())
223
+
224
+ x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))
225
+
226
+ x = self.act(x)
227
+ return x
228
+
229
+
230
+ class ResnetBlock(nn.Module):
231
+ def __init__(
232
+ self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
233
+ ):
234
+ super().__init__()
235
+ style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
236
+
237
+ self.block1 = Block(
238
+ dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
239
+ )
240
+ self.block2 = Block(
241
+ dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
242
+ )
243
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
244
+
245
+ def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
246
+ h = self.block1(x, conv_mods_iter=conv_mods_iter)
247
+ h = self.block2(h, conv_mods_iter=conv_mods_iter)
248
+
249
+ return h + self.res_conv(x)
250
+
251
+
252
+ class LinearAttention(nn.Module):
253
+ def __init__(self, dim, heads=4, dim_head=32):
254
+ super().__init__()
255
+ self.scale = dim_head**-0.5
256
+ self.heads = heads
257
+ hidden_dim = dim_head * heads
258
+
259
+ self.norm = RMSNorm(dim)
260
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
261
+
262
+ self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))
263
+
264
+ def forward(self, x):
265
+ b, c, h, w = x.shape
266
+
267
+ x = self.norm(x)
268
+
269
+ qkv = self.to_qkv(x).chunk(3, dim=1)
270
+ q, k, v = map(
271
+ lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
272
+ )
273
+
274
+ q = q.softmax(dim=-2)
275
+ k = k.softmax(dim=-1)
276
+
277
+ q = q * self.scale
278
+
279
+ context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
280
+
281
+ out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
282
+ out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
283
+ return self.to_out(out)
284
+
285
+
286
+ class Attention(nn.Module):
287
+ def __init__(self, dim, heads=4, dim_head=32, flash=False):
288
+ super().__init__()
289
+ self.heads = heads
290
+ hidden_dim = dim_head * heads
291
+
292
+ self.norm = RMSNorm(dim)
293
+
294
+ self.attend = Attend(flash=flash)
295
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
296
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
297
+
298
+ def forward(self, x):
299
+ b, c, h, w = x.shape
300
+ x = self.norm(x)
301
+ qkv = self.to_qkv(x).chunk(3, dim=1)
302
+
303
+ q, k, v = map(
304
+ lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
305
+ )
306
+
307
+ out = self.attend(q, k, v)
308
+ out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
309
+
310
+ return self.to_out(out)
311
+
312
+
313
+ # feedforward
314
+ def FeedForward(dim, mult=4):
315
+ return nn.Sequential(
316
+ RMSNorm(dim),
317
+ nn.Conv2d(dim, dim * mult, 1),
318
+ nn.GELU(),
319
+ nn.Conv2d(dim * mult, dim, 1),
320
+ )
321
+
322
+
323
+ # transformers
324
+ class Transformer(nn.Module):
325
+ def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
326
+ super().__init__()
327
+ self.layers = nn.ModuleList([])
328
+
329
+ for _ in range(depth):
330
+ self.layers.append(
331
+ nn.ModuleList(
332
+ [
333
+ Attention(
334
+ dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
335
+ ),
336
+ FeedForward(dim=dim, mult=ff_mult),
337
+ ]
338
+ )
339
+ )
340
+
341
+ def forward(self, x):
342
+ for attn, ff in self.layers:
343
+ x = attn(x) + x
344
+ x = ff(x) + x
345
+
346
+ return x
347
+
348
+
349
+ class LinearTransformer(nn.Module):
350
+ def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
351
+ super().__init__()
352
+ self.layers = nn.ModuleList([])
353
+
354
+ for _ in range(depth):
355
+ self.layers.append(
356
+ nn.ModuleList(
357
+ [
358
+ LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
359
+ FeedForward(dim=dim, mult=ff_mult),
360
+ ]
361
+ )
362
+ )
363
+
364
+ def forward(self, x):
365
+ for attn, ff in self.layers:
366
+ x = attn(x) + x
367
+ x = ff(x) + x
368
+
369
+ return x
370
+
371
+
372
+ class NearestNeighborhoodUpsample(nn.Module):
373
+ def __init__(self, dim, dim_out=None):
374
+ super().__init__()
375
+ dim_out = default(dim_out, dim)
376
+ self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
377
+
378
+ def forward(self, x):
379
+
380
+ if x.shape[0] >= 64:
381
+ x = x.contiguous()
382
+
383
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
384
+ x = self.conv(x)
385
+
386
+ return x
387
+
388
+ class EqualLinear(nn.Module):
389
+ def __init__(self, dim, dim_out, lr_mul=1, bias=True):
390
+ super().__init__()
391
+ self.weight = nn.Parameter(torch.randn(dim_out, dim))
392
+ if bias:
393
+ self.bias = nn.Parameter(torch.zeros(dim_out))
394
+
395
+ self.lr_mul = lr_mul
396
+
397
+ def forward(self, input):
398
+ return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
399
+
400
+
401
+ class StyleGanNetwork(nn.Module):
402
+ def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
403
+ super().__init__()
404
+ self.dim_in = dim_in
405
+ self.dim_out = dim_out
406
+ self.dim_text_latent = dim_text_latent
407
+
408
+ layers = []
409
+ for i in range(depth):
410
+ is_first = i == 0
411
+
412
+ if is_first:
413
+ dim_in_layer = dim_in + dim_text_latent
414
+ else:
415
+ dim_in_layer = dim_out
416
+
417
+ dim_out_layer = dim_out
418
+
419
+ layers.extend(
420
+ [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
421
+ )
422
+
423
+ self.net = nn.Sequential(*layers)
424
+
425
+ def forward(self, x, text_latent=None):
426
+ x = F.normalize(x, dim=1)
427
+ if self.dim_text_latent > 0:
428
+ assert exists(text_latent)
429
+ x = torch.cat((x, text_latent), dim=-1)
430
+ return self.net(x)
431
+
432
+
433
+ class UnetUpsampler(torch.nn.Module):
434
+
435
+ def __init__(
436
+ self,
437
+ dim: int,
438
+ *,
439
+ image_size: int,
440
+ input_image_size: int,
441
+ init_dim: Optional[int] = None,
442
+ out_dim: Optional[int] = None,
443
+ style_network: Optional[dict] = None,
444
+ up_dim_mults: tuple = (1, 2, 4, 8, 16),
445
+ down_dim_mults: tuple = (4, 8, 16),
446
+ channels: int = 3,
447
+ resnet_block_groups: int = 8,
448
+ full_attn: tuple = (False, False, False, True, True),
449
+ flash_attn: bool = True,
450
+ self_attn_dim_head: int = 64,
451
+ self_attn_heads: int = 8,
452
+ attn_depths: tuple = (2, 2, 2, 2, 4),
453
+ mid_attn_depth: int = 4,
454
+ num_conv_kernels: int = 4,
455
+ resize_mode: str = "bilinear",
456
+ unconditional: bool = True,
457
+ skip_connect_scale: Optional[float] = None,
458
+ ):
459
+ super().__init__()
460
+ self.style_network = style_network = StyleGanNetwork(**style_network)
461
+ self.unconditional = unconditional
462
+ assert not (
463
+ unconditional
464
+ and exists(style_network)
465
+ and style_network.dim_text_latent > 0
466
+ )
467
+
468
+ assert is_power_of_two(image_size) and is_power_of_two(
469
+ input_image_size
470
+ ), "both output image size and input image size must be power of 2"
471
+ assert (
472
+ input_image_size < image_size
473
+ ), "input image size must be smaller than the output image size, thus upsampling"
474
+
475
+ self.image_size = image_size
476
+ self.input_image_size = input_image_size
477
+
478
+ style_embed_split_dims = []
479
+
480
+ self.channels = channels
481
+ input_channels = channels
482
+
483
+ init_dim = default(init_dim, dim)
484
+
485
+ up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
486
+ init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
487
+ down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
488
+ self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)
489
+
490
+ up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
491
+ down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
492
+
493
+ block_klass = partial(
494
+ ResnetBlock,
495
+ groups=resnet_block_groups,
496
+ num_conv_kernels=num_conv_kernels,
497
+ style_dims=style_embed_split_dims,
498
+ )
499
+
500
+ FullAttention = partial(Transformer, flash_attn=flash_attn)
501
+ *_, mid_dim = up_dims
502
+
503
+ self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
504
+
505
+ self.downs = nn.ModuleList([])
506
+ self.ups = nn.ModuleList([])
507
+
508
+ block_count = 6
509
+
510
+ for ind, (
511
+ (dim_in, dim_out),
512
+ layer_full_attn,
513
+ layer_attn_depth,
514
+ ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
515
+ attn_klass = FullAttention if layer_full_attn else LinearTransformer
516
+
517
+ blocks = []
518
+ for i in range(block_count):
519
+ blocks.append(block_klass(dim_in, dim_in))
520
+
521
+ self.downs.append(
522
+ nn.ModuleList(
523
+ [
524
+ nn.ModuleList(blocks),
525
+ nn.ModuleList(
526
+ [
527
+ (
528
+ attn_klass(
529
+ dim_in,
530
+ dim_head=self_attn_dim_head,
531
+ heads=self_attn_heads,
532
+ depth=layer_attn_depth,
533
+ )
534
+ if layer_full_attn
535
+ else None
536
+ ),
537
+ nn.Conv2d(
538
+ dim_in, dim_out, kernel_size=3, stride=2, padding=1
539
+ ),
540
+ ]
541
+ ),
542
+ ]
543
+ )
544
+ )
545
+
546
+ self.mid_block1 = block_klass(mid_dim, mid_dim)
547
+ self.mid_attn = FullAttention(
548
+ mid_dim,
549
+ dim_head=self_attn_dim_head,
550
+ heads=self_attn_heads,
551
+ depth=mid_attn_depth,
552
+ )
553
+ self.mid_block2 = block_klass(mid_dim, mid_dim)
554
+
555
+ *_, last_dim = up_dims
556
+
557
+ for ind, (
558
+ (dim_in, dim_out),
559
+ layer_full_attn,
560
+ layer_attn_depth,
561
+ ) in enumerate(
562
+ zip(
563
+ reversed(up_in_out),
564
+ reversed(full_attn),
565
+ reversed(attn_depths),
566
+ )
567
+ ):
568
+ attn_klass = FullAttention if layer_full_attn else LinearTransformer
569
+
570
+ blocks = []
571
+ input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
572
+ for i in range(block_count):
573
+ blocks.append(block_klass(input_dim, dim_in))
574
+
575
+ self.ups.append(
576
+ nn.ModuleList(
577
+ [
578
+ nn.ModuleList(blocks),
579
+ nn.ModuleList(
580
+ [
581
+ NearestNeighborhoodUpsample(
582
+ last_dim if ind == 0 else dim_out,
583
+ dim_in,
584
+ ),
585
+ (
586
+ attn_klass(
587
+ dim_in,
588
+ dim_head=self_attn_dim_head,
589
+ heads=self_attn_heads,
590
+ depth=layer_attn_depth,
591
+ )
592
+ if layer_full_attn
593
+ else None
594
+ ),
595
+ ]
596
+ ),
597
+ ]
598
+ )
599
+ )
600
+
601
+ self.out_dim = default(out_dim, channels)
602
+ self.final_res_block = block_klass(dim, dim)
603
+ self.final_to_rgb = nn.Conv2d(dim, channels, 1)
604
+ self.resize_mode = resize_mode
605
+ self.style_to_conv_modulations = nn.Linear(
606
+ style_network.dim_out, sum(style_embed_split_dims)
607
+ )
608
+ self.style_embed_split_dims = style_embed_split_dims
609
+
610
+ @property
611
+ def allowable_rgb_resolutions(self):
612
+ input_res_base = int(log2(self.input_image_size))
613
+ output_res_base = int(log2(self.image_size))
614
+ allowed_rgb_res_base = list(range(input_res_base, output_res_base))
615
+ return [*map(lambda p: 2**p, allowed_rgb_res_base)]
616
+
617
+ @property
618
+ def device(self):
619
+ return next(self.parameters()).device
620
+
621
+ @property
622
+ def total_params(self):
623
+ return sum([p.numel() for p in self.parameters()])
624
+
625
+ def resize_image_to(self, x, size):
626
+ return F.interpolate(x, (size, size), mode=self.resize_mode)
627
+
628
+ def forward(
629
+ self,
630
+ lowres_image: torch.Tensor,
631
+ styles: Optional[torch.Tensor] = None,
632
+ noise: Optional[torch.Tensor] = None,
633
+ global_text_tokens: Optional[torch.Tensor] = None,
634
+ return_all_rgbs: bool = False,
635
+ ):
636
+ x = lowres_image
637
+
638
+ noise_scale = 0.001 # Adjust the scale of the noise as needed
639
+ noise_aug = torch.randn_like(x) * noise_scale
640
+ x = x + noise_aug
641
+ x = x.clamp(0, 1)
642
+
643
+ shape = x.shape
644
+ batch_size = shape[0]
645
+
646
+ assert shape[-2:] == ((self.input_image_size,) * 2)
647
+
648
+ # styles
649
+ if not exists(styles):
650
+ assert exists(self.style_network)
651
+
652
+ noise = default(
653
+ noise,
654
+ torch.randn(
655
+ (batch_size, self.style_network.dim_in), device=self.device
656
+ ),
657
+ )
658
+ styles = self.style_network(noise, global_text_tokens)
659
+
660
+ # project styles to conv modulations
661
+ conv_mods = self.style_to_conv_modulations(styles)
662
+ conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
663
+ conv_mods = iter(conv_mods)
664
+
665
+ x = self.init_conv(x)
666
+
667
+ h = []
668
+ for blocks, (attn, downsample) in self.downs:
669
+ for block in blocks:
670
+ x = block(x, conv_mods_iter=conv_mods)
671
+ h.append(x)
672
+
673
+ if attn is not None:
674
+ x = attn(x)
675
+
676
+ x = downsample(x)
677
+
678
+ x = self.mid_block1(x, conv_mods_iter=conv_mods)
679
+ x = self.mid_attn(x)
680
+ x = self.mid_block2(x, conv_mods_iter=conv_mods)
681
+
682
+ for (
683
+ blocks,
684
+ (
685
+ upsample,
686
+ attn,
687
+ ),
688
+ ) in self.ups:
689
+ x = upsample(x)
690
+ for block in blocks:
691
+ if h != []:
692
+ res = h.pop()
693
+ res = res * self.skip_connect_scale
694
+ x = torch.cat((x, res), dim=1)
695
+
696
+ x = block(x, conv_mods_iter=conv_mods)
697
+
698
+ if attn is not None:
699
+ x = attn(x)
700
+
701
+ x = self.final_res_block(x, conv_mods_iter=conv_mods)
702
+ rgb = self.final_to_rgb(x)
703
+
704
+ if not return_all_rgbs:
705
+ return rgb
706
+
707
+ return rgb, []
708
+
709
+
710
+ def tile_image(image, chunk_size=64):
711
+ c, h, w = image.shape
712
+ h_chunks = ceil(h / chunk_size)
713
+ w_chunks = ceil(w / chunk_size)
714
+ tiles = []
715
+ for i in range(h_chunks):
716
+ for j in range(w_chunks):
717
+ tile = image[:, i * chunk_size:(i + 1) * chunk_size, j * chunk_size:(j + 1) * chunk_size]
718
+ tiles.append(tile)
719
+ return tiles, h_chunks, w_chunks
720
+
721
+
722
+ def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
723
+ # Determine the shape of the output tensor
724
+ c = tiles[0].shape[0]
725
+ h = h_chunks * chunk_size
726
+ w = w_chunks * chunk_size
727
+
728
+ # Create an empty tensor to hold the merged image
729
+ merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)
730
+
731
+ # Iterate over the tiles and place them in the correct position
732
+ for idx, tile in enumerate(tiles):
733
+ i = idx // w_chunks
734
+ j = idx % w_chunks
735
+
736
+ h_start = i * chunk_size
737
+ w_start = j * chunk_size
738
+
739
+ tile_h, tile_w = tile.shape[1:]
740
+ merged[:, h_start:h_start+tile_h, w_start:w_start+tile_w] = tile
741
+
742
+ return merged
743
+
744
+
745
+ class AuraSR:
746
+ def __init__(self, config: dict[str, Any], device: str = "cuda"):
747
+ self.upsampler = UnetUpsampler(**config).to(device)
748
+ self.input_image_size = config["input_image_size"]
749
+
750
+ @classmethod
751
+ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_safetensors: bool = True):
752
+ import json
753
+ import torch
754
+ from pathlib import Path
755
+ from huggingface_hub import snapshot_download
756
+
757
+ # Check if model_id is a local file
758
+ if Path(model_id).is_file():
759
+ local_file = Path(model_id)
760
+ if local_file.suffix == '.safetensors':
761
+ use_safetensors = True
762
+ elif local_file.suffix == '.ckpt':
763
+ use_safetensors = False
764
+ else:
765
+ raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")
766
+
767
+ # For local files, we need to provide the config separately
768
+ config_path = local_file.with_name('config.json')
769
+ if not config_path.exists():
770
+ raise FileNotFoundError(
771
+ f"Config file not found: {config_path}. "
772
+ f"When loading from a local file, ensure that 'config.json' "
773
+ f"is present in the same directory as '{local_file.name}'. "
774
+ f"If you're trying to load a model from Hugging Face, "
775
+ f"please provide the model ID instead of a file path."
776
+ )
777
+
778
+ config = json.loads(config_path.read_text())
779
+ hf_model_path = local_file.parent
780
+ else:
781
+ hf_model_path = Path(snapshot_download(model_id))
782
+ config = json.loads((hf_model_path / "config.json").read_text())
783
+
784
+ model = cls(config,device)
785
+
786
+ if use_safetensors:
787
+ try:
788
+ from safetensors.torch import load_file
789
+ checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id)
790
+ except ImportError:
791
+ raise ImportError(
792
+ "The safetensors library is not installed. "
793
+ "Please install it with `pip install safetensors` "
794
+ "or use `use_safetensors=False` to load the model with PyTorch."
795
+ )
796
+ else:
797
+ checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id)
798
+
799
+ model.upsampler.load_state_dict(checkpoint, strict=True)
800
+ return model
801
+
802
+ @torch.no_grad()
803
+ def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
804
+ tensor_transform = transforms.ToTensor()
805
+ device = self.upsampler.device
806
+
807
+ image_tensor = tensor_transform(image).unsqueeze(0)
808
+ _, _, h, w = image_tensor.shape
809
+ pad_h = (self.input_image_size - h % self.input_image_size) % self.input_image_size
810
+ pad_w = (self.input_image_size - w % self.input_image_size) % self.input_image_size
811
+
812
+ # Pad the image
813
+ image_tensor = torch.nn.functional.pad(image_tensor, (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
814
+ tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)
815
+
816
+ # Batch processing of tiles
817
+ num_tiles = len(tiles)
818
+ batches = [tiles[i:i + max_batch_size] for i in range(0, num_tiles, max_batch_size)]
819
+ reconstructed_tiles = []
820
+
821
+ for batch in batches:
822
+ model_input = torch.stack(batch).to(device)
823
+ generator_output = self.upsampler(
824
+ lowres_image=model_input,
825
+ noise=torch.randn(model_input.shape[0], 128, device=device)
826
+ )
827
+ reconstructed_tiles.extend(list(generator_output.clamp_(0, 1).detach().cpu()))
828
+
829
+ merged_tensor = merge_tiles(reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4)
830
+ unpadded = merged_tensor[:, :h * 4, :w * 4]
831
+
832
+ to_pil = transforms.ToPILImage()
833
+ return to_pil(unpadded)
834
+
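The tiling helpers above round-trip losslessly when the image dimensions are multiples of the chunk size (upscale_4x reflection-pads the input first to guarantee this); a small self-contained check, as a sketch:

import torch

from backend.upscale.aura_sr import merge_tiles, tile_image

image = torch.rand(3, 128, 192)                        # C x H x W, both multiples of 64
tiles, h_chunks, w_chunks = tile_image(image, chunk_size=64)
restored = merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64)
assert torch.equal(restored, image)                    # lossless for aligned sizes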
backend/upscale/aura_sr_upscale.py ADDED
@@ -0,0 +1,9 @@
+ from backend.upscale.aura_sr import AuraSR
+ from PIL import Image
+
+
+ def upscale_aura_sr(image_path: str):
+
+     aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device="cpu")
+     image_in = Image.open(image_path)  # .resize((256, 256))
+     return aura_sr.upscale_4x(image_in)
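Called on its own, the wrapper simply returns a 4x PIL image; the file names below are placeholders:

from backend.upscale.aura_sr_upscale import upscale_aura_sr

result = upscale_aura_sr("input.png")    # placeholder path
result.save("input_4x.png")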
backend/upscale/edsr_upscale_onnx.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+ import onnxruntime
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+
+ def upscale_edsr_2x(image_path: str):
+     input_image = Image.open(image_path).convert("RGB")
+     input_image = np.array(input_image).astype("float32")
+     input_image = np.transpose(input_image, (2, 0, 1))
+     img_arr = np.expand_dims(input_image, axis=0)
+
+     if np.max(img_arr) > 256:  # 16-bit image
+         max_range = 65535
+     else:
+         max_range = 255.0
+     img = img_arr / max_range
+
+     model_path = hf_hub_download(
+         repo_id="rupeshs/edsr-onnx",
+         filename="edsr_onnxsim_2x.onnx",
+     )
+     sess = onnxruntime.InferenceSession(model_path)
+
+     input_name = sess.get_inputs()[0].name
+     output_name = sess.get_outputs()[0].name
+     output = sess.run(
+         [output_name],
+         {input_name: img},
+     )[0]
+
+     result = output.squeeze()
+     result = result.clip(0, 1)
+     image_array = np.transpose(result, (1, 2, 0))
+     image_array = np.uint8(image_array * 255)
+     upscaled_image = Image.fromarray(image_array)
+     return upscaled_image
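A quick CPU-only usage sketch; the ONNX model is fetched from the Hugging Face Hub on the first call and the file name is a placeholder:

from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x

upscaled = upscale_edsr_2x("photo.png")   # returns a PIL image at twice the resolution
upscaled.save("photo_2x.png")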
backend/upscale/tiled_upscale.py ADDED
@@ -0,0 +1,238 @@
1
+ import time
2
+ import math
3
+ import logging
4
+ from PIL import Image, ImageDraw, ImageFilter
5
+ from backend.models.lcmdiffusion_setting import DiffusionTask
6
+ from context import Context
7
+ from constants import DEVICE
8
+
9
+
10
+ def generate_upscaled_image(
11
+ config,
12
+ input_path=None,
13
+ strength=0.3,
14
+ scale_factor=2.0,
15
+ tile_overlap=16,
16
+ upscale_settings=None,
17
+ context: Context = None,
18
+ output_path=None,
19
+ image_format="PNG",
20
+ ):
21
+ if config == None or (
22
+ (input_path == None or input_path == "") and upscale_settings == None
23
+ ):
24
+ logging.error("Wrong arguments in tiled upscale function call!")
25
+ return
26
+
27
+ # Use the upscale_settings dict if provided; otherwise, build the
28
+ # upscale_settings dict using the function arguments and default values
29
+ if upscale_settings == None:
30
+ upscale_settings = {
31
+ "source_file": input_path,
32
+ "target_file": None,
33
+ "output_format": image_format,
34
+ "strength": strength,
35
+ "scale_factor": scale_factor,
36
+ "prompt": config.lcm_diffusion_setting.prompt,
37
+ "tile_overlap": tile_overlap,
38
+ "tile_size": 256,
39
+ "tiles": [],
40
+ }
41
+ source_image = Image.open(input_path) # PIL image
42
+ else:
43
+ source_image = Image.open(upscale_settings["source_file"])
44
+
45
+ upscale_settings["source_image"] = source_image
46
+
47
+ if upscale_settings["target_file"]:
48
+ result = Image.open(upscale_settings["target_file"])
49
+ else:
50
+ result = Image.new(
51
+ mode="RGBA",
52
+ size=(
53
+ source_image.size[0] * int(upscale_settings["scale_factor"]),
54
+ source_image.size[1] * int(upscale_settings["scale_factor"]),
55
+ ),
56
+ color=(0, 0, 0, 0),
57
+ )
58
+ upscale_settings["target_image"] = result
59
+
60
+ # If the custom tile definition array 'tiles' is empty, proceed with the
61
+ # default tiled upscale task by defining all the possible image tiles; note
62
+ # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
63
+ # image width and height are no longer constrained to multiples of 256 but
64
+ # are instead multiples of the actual tile size
65
+ if len(upscale_settings["tiles"]) == 0:
66
+ tile_size = upscale_settings["tile_size"]
67
+ scale_factor = upscale_settings["scale_factor"]
68
+ tile_overlap = upscale_settings["tile_overlap"]
69
+ total_cols = math.ceil(
70
+ source_image.size[0] / tile_size
71
+ ) # Image width / tile size
72
+ total_rows = math.ceil(
73
+ source_image.size[1] / tile_size
74
+ ) # Image height / tile size
75
+ for y in range(0, total_rows):
76
+ y_offset = tile_overlap if y > 0 else 0 # Tile mask offset
77
+ for x in range(0, total_cols):
78
+ x_offset = tile_overlap if x > 0 else 0 # Tile mask offset
79
+ x1 = x * tile_size
80
+ y1 = y * tile_size
81
+ w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
82
+ h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
83
+ mask_box = ( # Default tile mask box definition
84
+ x_offset,
85
+ y_offset,
86
+ int(w * scale_factor),
87
+ int(h * scale_factor),
88
+ )
89
+ upscale_settings["tiles"].append(
90
+ {
91
+ "x": x1,
92
+ "y": y1,
93
+ "w": w,
94
+ "h": h,
95
+ "mask_box": mask_box,
96
+ "prompt": upscale_settings["prompt"], # Use top level prompt if available
97
+ "scale_factor": scale_factor,
98
+ }
99
+ )
100
+
101
+ # Generate the output image tiles
102
+ for i in range(0, len(upscale_settings["tiles"])):
103
+ generate_upscaled_tile(
104
+ config,
105
+ i,
106
+ upscale_settings,
107
+ context=context,
108
+ )
109
+
110
+ # Save completed upscaled image
111
+ if upscale_settings["output_format"].upper() == "JPEG":
112
+ result_rgb = result.convert("RGB")
113
+ result.close()
114
+ result = result_rgb
115
+ result.save(output_path)
116
+ result.close()
117
+ source_image.close()
118
+ return
119
+
120
+
121
+ def get_current_tile(
122
+ config,
123
+ context,
124
+ strength,
125
+ ):
126
+ config.lcm_diffusion_setting.strength = strength
127
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
128
+ if (
129
+ config.lcm_diffusion_setting.use_tiny_auto_encoder
130
+ and config.lcm_diffusion_setting.use_openvino
131
+ ):
132
+ config.lcm_diffusion_setting.use_tiny_auto_encoder = False
133
+ current_tile = context.generate_text_to_image(
134
+ settings=config,
135
+ reshape=True,
136
+ device=DEVICE,
137
+ save_images=False,
138
+ save_config=False,
139
+ )[0]
140
+ return current_tile
141
+
142
+
143
+ # Generates a single tile from the source image as defined in the
144
+ # upscale_settings["tiles"] array with the corresponding index and pastes the
145
+ # generated tile into the target image using the corresponding mask and scale
146
+ # factor; note that scale factor for the target image and the individual tiles
147
+ # can be different, this function will adjust scale factors as needed
148
+ def generate_upscaled_tile(
149
+ config,
150
+ index,
151
+ upscale_settings,
152
+ context: Context = None,
153
+ ):
154
+ if config == None or upscale_settings == None:
155
+ logging.error("Wrong arguments in tile creation function call!")
156
+ return
157
+
158
+ x = upscale_settings["tiles"][index]["x"]
159
+ y = upscale_settings["tiles"][index]["y"]
160
+ w = upscale_settings["tiles"][index]["w"]
161
+ h = upscale_settings["tiles"][index]["h"]
162
+ tile_prompt = upscale_settings["tiles"][index]["prompt"]
163
+ scale_factor = upscale_settings["scale_factor"]
164
+ tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
165
+ target_width = int(w * tile_scale_factor)
166
+ target_height = int(h * tile_scale_factor)
167
+ strength = upscale_settings["strength"]
168
+ source_image = upscale_settings["source_image"]
169
+ target_image = upscale_settings["target_image"]
170
+ mask_image = generate_tile_mask(config, index, upscale_settings)
171
+
172
+ config.lcm_diffusion_setting.number_of_images = 1
173
+ config.lcm_diffusion_setting.prompt = tile_prompt
174
+ config.lcm_diffusion_setting.image_width = target_width
175
+ config.lcm_diffusion_setting.image_height = target_height
176
+ config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))
177
+
178
+ current_tile = None
179
+ print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
180
+ if tile_prompt == None or tile_prompt == "":
181
+ config.lcm_diffusion_setting.prompt = ""
182
+ config.lcm_diffusion_setting.negative_prompt = ""
183
+ current_tile = get_current_tile(config, context, strength)
184
+ else:
185
+ # Attempt to use img2img with low denoising strength to
186
+ # generate the tiles with the extra aid of a prompt
187
+ # context = get_context(InterfaceType.CLI)
188
+ current_tile = get_current_tile(config, context, strength)
189
+
190
+ if math.isclose(scale_factor, tile_scale_factor):
191
+ target_image.paste(
192
+ current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
193
+ )
194
+ else:
195
+ target_image.paste(
196
+ current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
197
+ (int(x * scale_factor), int(y * scale_factor)),
198
+ mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
199
+ )
200
+ mask_image.close()
201
+ current_tile.close()
202
+ config.lcm_diffusion_setting.init_image.close()
203
+
204
+
205
+ # Generate tile mask using the box definition in the upscale_settings["tiles"]
206
+ # array with the corresponding index; note that tile masks for the default
207
+ # tiled upscale task can be reused but that would complicate the code, so
208
+ # new tile masks are instead created for each tile
209
+ def generate_tile_mask(
210
+ config,
211
+ index,
212
+ upscale_settings,
213
+ ):
214
+ scale_factor = upscale_settings["scale_factor"]
215
+ tile_overlap = upscale_settings["tile_overlap"]
216
+ tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
217
+ w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
218
+ h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
219
+ # The Stable Diffusion pipeline automatically adjusts the output size
220
+ # to multiples of 8 pixels; the mask must be created with the same
221
+ # size as the output tile
222
+ w = w - (w % 8)
223
+ h = h - (h % 8)
224
+ mask_box = upscale_settings["tiles"][index]["mask_box"]
225
+ if mask_box == None:
226
+ # Build a default solid mask with soft/transparent edges
227
+ mask_box = (
228
+ tile_overlap,
229
+ tile_overlap,
230
+ w - tile_overlap,
231
+ h - tile_overlap,
232
+ )
233
+ mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
234
+ mask_draw = ImageDraw.Draw(mask_image)
235
+ mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
236
+ mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
237
+ mask_image.close()
238
+ return mask_blur
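Besides the default grid, generate_upscaled_image also accepts a hand-written upscale_settings dict; a sketch with a single custom tile follows (the paths, prompt and tile geometry are illustrative, and the settings/context objects are obtained the same way upscaler.py does):

from backend.upscale.tiled_upscale import generate_upscaled_image
from context import Context
from models.interface_types import InterfaceType
from state import get_settings

config = get_settings().settings          # Settings object with lcm_diffusion_setting
context = Context(InterfaceType.CLI)
upscale_settings = {
    "source_file": "low_res.png",         # illustrative path
    "target_file": None,
    "output_format": "PNG",
    "strength": 0.3,
    "scale_factor": 2.0,
    "prompt": "",
    "tile_overlap": 16,
    "tile_size": 256,
    "tiles": [
        {
            "x": 0, "y": 0, "w": 256, "h": 256,
            "mask_box": None,                     # None -> default soft-edged mask
            "prompt": "detailed stone texture",   # per-tile prompt (illustrative)
            "scale_factor": 2.0,
        },
    ],
}
generate_upscaled_image(
    config,
    input_path=upscale_settings["source_file"],
    upscale_settings=upscale_settings,
    context=context,
    output_path="low_res_2x.png",
)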
backend/upscale/upscaler.py ADDED
@@ -0,0 +1,52 @@
+ from backend.models.lcmdiffusion_setting import DiffusionTask
+ from backend.models.upscale import UpscaleMode
+ from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
+ from backend.upscale.aura_sr_upscale import upscale_aura_sr
+ from backend.upscale.tiled_upscale import generate_upscaled_image
+ from context import Context
+ from PIL import Image
+ from state import get_settings
+
+
+ config = get_settings()
+
+
+ def upscale_image(
+     context: Context,
+     src_image_path: str,
+     dst_image_path: str,
+     scale_factor: int = 2,
+     upscale_mode: UpscaleMode = UpscaleMode.normal.value,
+ ):
+     if upscale_mode == UpscaleMode.normal.value:
+
+         upscaled_img = upscale_edsr_2x(src_image_path)
+         upscaled_img.save(dst_image_path)
+         print(f"Upscaled image saved {dst_image_path}")
+     elif upscale_mode == UpscaleMode.aura_sr.value:
+         upscaled_img = upscale_aura_sr(src_image_path)
+         upscaled_img.save(dst_image_path)
+         print(f"Upscaled image saved {dst_image_path}")
+     else:
+         config.settings.lcm_diffusion_setting.strength = (
+             0.3 if config.settings.lcm_diffusion_setting.use_openvino else 0.1
+         )
+         config.settings.lcm_diffusion_setting.diffusion_task = (
+             DiffusionTask.image_to_image.value
+         )
+
+         generate_upscaled_image(
+             config.settings,
+             src_image_path,
+             config.settings.lcm_diffusion_setting.strength,
+             upscale_settings=None,
+             context=context,
+             tile_overlap=(
+                 32 if config.settings.lcm_diffusion_setting.use_openvino else 16
+             ),
+             output_path=dst_image_path,
+             image_format=config.settings.generated_images.format,
+         )
+         print(f"Upscaled image saved {dst_image_path}")
+
+     return [Image.open(dst_image_path)]
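Wired together, a 2x EDSR upscale from elsewhere in the app could look like the following sketch (the Context instance and file names are assumptions):

from backend.models.upscale import UpscaleMode
from backend.upscale.upscaler import upscale_image
from context import Context
from models.interface_types import InterfaceType

context = Context(InterfaceType.CLI)
images = upscale_image(
    context,
    "results/sample.png",                    # placeholder source
    "results/sample_2x.png",                 # placeholder destination
    upscale_mode=UpscaleMode.normal.value,   # EDSR 2x; aura_sr and sd_upscale are also accepted
)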
constants.py ADDED
@@ -0,0 +1,20 @@
+ from os import environ
+
+ APP_VERSION = "v1.0.0 beta 33"
+ LCM_DEFAULT_MODEL = "stabilityai/sd-turbo"
+ LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino"
+ APP_NAME = "FastSD CPU"
+ APP_SETTINGS_FILE = "settings.yaml"
+ RESULTS_DIRECTORY = "results"
+ CONFIG_DIRECTORY = "configs"
+ DEVICE = environ.get("DEVICE", "cpu")
+ SD_MODELS_FILE = "stable-diffusion-models.txt"
+ LCM_LORA_MODELS_FILE = "lcm-lora-models.txt"
+ OPENVINO_LCM_MODELS_FILE = "openvino-lcm-models.txt"
+ TAESD_MODEL = "madebyollin/taesd"
+ TAESDXL_MODEL = "madebyollin/taesdxl"
+ TAESD_MODEL_OPENVINO = "deinferno/taesd-openvino"
+ LCM_MODELS_FILE = "lcm-models.txt"
+ TAESDXL_MODEL_OPENVINO = "rupeshs/taesdxl-openvino"
+ LORA_DIRECTORY = "lora_models"
+ CONTROLNET_DIRECTORY = "controlnet_models"
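Since DEVICE is read from the environment at import time, it has to be overridden before any app module is imported; a small sketch (the value is only an example):

import os

os.environ["DEVICE"] = "cpu"   # example value; must be set before importing constants
import constants

print(constants.APP_NAME, constants.APP_VERSION, constants.DEVICE)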
context.py ADDED
@@ -0,0 +1,77 @@
1
+ from typing import Any
2
+ from app_settings import Settings
3
+ from models.interface_types import InterfaceType
4
+ from backend.models.lcmdiffusion_setting import DiffusionTask
5
+ from backend.lcm_text_to_image import LCMTextToImage
6
+ from time import perf_counter
7
+ from backend.image_saver import ImageSaver
8
+ from pprint import pprint
9
+
10
+
11
+ class Context:
12
+ def __init__(
13
+ self,
14
+ interface_type: InterfaceType,
15
+ device="cpu",
16
+ ):
17
+ self.interface_type = interface_type.value
18
+ self.lcm_text_to_image = LCMTextToImage(device)
19
+ self._latency = 0
20
+
21
+ @property
22
+ def latency(self):
23
+ return self._latency
24
+
25
+ def generate_text_to_image(
26
+ self,
27
+ settings: Settings,
28
+ reshape: bool = False,
29
+ device: str = "cpu",
30
+ save_images=True,
31
+ save_config=True,
32
+ ) -> Any:
33
+ if (
34
+ settings.lcm_diffusion_setting.use_tiny_auto_encoder
35
+ and settings.lcm_diffusion_setting.use_openvino
36
+ ):
37
+ print(
38
+ "WARNING: Tiny AutoEncoder is not supported in Image to image mode (OpenVINO)"
39
+ )
40
+ tick = perf_counter()
41
+ from state import get_settings
42
+
43
+ if (
44
+ settings.lcm_diffusion_setting.diffusion_task
45
+ == DiffusionTask.text_to_image.value
46
+ ):
47
+ settings.lcm_diffusion_setting.init_image = None
48
+
49
+ if save_config:
50
+ get_settings().save()
51
+
52
+ pprint(settings.lcm_diffusion_setting.model_dump())
53
+ if not settings.lcm_diffusion_setting.lcm_lora:
54
+ return None
55
+ self.lcm_text_to_image.init(
56
+ device,
57
+ settings.lcm_diffusion_setting,
58
+ )
59
+ images = self.lcm_text_to_image.generate(
60
+ settings.lcm_diffusion_setting,
61
+ reshape,
62
+ )
63
+ elapsed = perf_counter() - tick
64
+
65
+ if save_images and settings.generated_images.save_image:
66
+ ImageSaver.save_images(
67
+ settings.generated_images.path,
68
+ images=images,
69
+ lcm_diffusion_setting=settings.lcm_diffusion_setting,
70
+ format=settings.generated_images.format,
71
+ )
72
+ self._latency = elapsed
73
+ print(f"Latency : {elapsed:.2f} seconds")
74
+ if settings.lcm_diffusion_setting.controlnet:
75
+ if settings.lcm_diffusion_setting.controlnet.enabled:
76
+ images.append(settings.lcm_diffusion_setting.controlnet._control_image)
77
+ return images
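End to end, the frontends use Context roughly as in this sketch; get_settings() supplies the persisted Settings object and the prompt is an example:

from constants import DEVICE
from context import Context
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app_settings.settings.lcm_diffusion_setting.prompt = "a cozy cabin in the woods"  # example
context = Context(InterfaceType.CLI)
images = context.generate_text_to_image(
    settings=app_settings.settings,
    device=DEVICE,
)
print(f"Generated {len(images)} image(s) in {context.latency:.2f} seconds")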
frontend/cli_interactive.py ADDED
@@ -0,0 +1,655 @@
1
+ from os import path
2
+ from PIL import Image
3
+ from typing import Any
4
+
5
+ from constants import DEVICE
6
+ from paths import FastStableDiffusionPaths
7
+ from backend.upscale.upscaler import upscale_image
8
+ from backend.controlnet import controlnet_settings_from_dict
9
+ from backend.upscale.tiled_upscale import generate_upscaled_image
10
+ from frontend.webui.image_variations_ui import generate_image_variations
11
+ from backend.lora import (
12
+ get_active_lora_weights,
13
+ update_lora_weights,
14
+ load_lora_weight,
15
+ )
16
+ from backend.models.lcmdiffusion_setting import (
17
+ DiffusionTask,
18
+ LCMDiffusionSetting,
19
+ ControlNetSetting,
20
+ )
21
+
22
+
23
+ _batch_count = 1
24
+ _edit_lora_settings = False
25
+
26
+
27
+ def user_value(
28
+ value_type: type,
29
+ message: str,
30
+ default_value: Any,
31
+ ) -> Any:
32
+ try:
33
+ value = value_type(input(message))
34
+ except:
35
+ value = default_value
36
+ return value
37
+
38
+
39
+ def interactive_mode(
40
+ config,
41
+ context,
42
+ ):
43
+ print("=============================================")
44
+ print("Welcome to FastSD CPU Interactive CLI")
45
+ print("=============================================")
46
+ while True:
47
+ print("> 1. Text to Image")
48
+ print("> 2. Image to Image")
49
+ print("> 3. Image Variations")
50
+ print("> 4. EDSR Upscale")
51
+ print("> 5. SD Upscale")
52
+ print("> 6. Edit default generation settings")
53
+ print("> 7. Edit LoRA settings")
54
+ print("> 8. Edit ControlNet settings")
55
+ print("> 9. Edit negative prompt")
56
+ print("> 10. Quit")
57
+ option = user_value(
58
+ int,
59
+ "Enter a Diffusion Task number (1): ",
60
+ 1,
61
+ )
62
+ if option not in range(1, 11):
63
+ print("Wrong Diffusion Task number!")
64
+ exit()
65
+
66
+ if option == 1:
67
+ interactive_txt2img(
68
+ config,
69
+ context,
70
+ )
71
+ elif option == 2:
72
+ interactive_img2img(
73
+ config,
74
+ context,
75
+ )
76
+ elif option == 3:
77
+ interactive_variations(
78
+ config,
79
+ context,
80
+ )
81
+ elif option == 4:
82
+ interactive_edsr(
83
+ config,
84
+ context,
85
+ )
86
+ elif option == 5:
87
+ interactive_sdupscale(
88
+ config,
89
+ context,
90
+ )
91
+ elif option == 6:
92
+ interactive_settings(
93
+ config,
94
+ context,
95
+ )
96
+ elif option == 7:
97
+ interactive_lora(
98
+ config,
99
+ context,
100
+ True,
101
+ )
102
+ elif option == 8:
103
+ interactive_controlnet(
104
+ config,
105
+ context,
106
+ True,
107
+ )
108
+ elif option == 9:
109
+ interactive_negative(
110
+ config,
111
+ context,
112
+ )
113
+ elif option == 10:
114
+ exit()
115
+
116
+
117
+ def interactive_negative(
118
+ config,
119
+ context,
120
+ ):
121
+ settings = config.lcm_diffusion_setting
122
+ print(f"Current negative prompt: '{settings.negative_prompt}'")
123
+ user_input = input("Write a negative prompt (set guidance > 1.0): ")
124
+ if user_input == "":
125
+ return
126
+ else:
127
+ settings.negative_prompt = user_input
128
+
129
+
130
+ def interactive_controlnet(
131
+ config,
132
+ context,
133
+ menu_flag=False,
134
+ ):
135
+ """
136
+ @param menu_flag: Indicates whether this function was called from the main
137
+ interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
138
+ """
139
+ settings = config.lcm_diffusion_setting
140
+ if not settings.controlnet:
141
+ settings.controlnet = ControlNetSetting()
142
+
143
+ current_enabled = settings.controlnet.enabled
144
+ current_adapter_path = settings.controlnet.adapter_path
145
+ current_conditioning_scale = settings.controlnet.conditioning_scale
146
+ current_control_image = settings.controlnet._control_image
147
+
148
+ option = input("Enable ControlNet? (y/N): ")
149
+ settings.controlnet.enabled = True if option.upper() == "Y" else False
150
+ if settings.controlnet.enabled:
151
+ option = input(
152
+ f"Enter ControlNet adapter path ({settings.controlnet.adapter_path}): "
153
+ )
154
+ if option != "":
155
+ settings.controlnet.adapter_path = option
156
+ settings.controlnet.conditioning_scale = user_value(
157
+ float,
158
+ f"Enter ControlNet conditioning scale ({settings.controlnet.conditioning_scale}): ",
159
+ settings.controlnet.conditioning_scale,
160
+ )
161
+ option = input(
162
+ f"Enter ControlNet control image path (Leave empty to reuse current): "
163
+ )
164
+ if option != "":
165
+ try:
166
+ new_image = Image.open(option)
167
+ settings.controlnet._control_image = new_image
168
+ except (AttributeError, FileNotFoundError):
169
+ settings.controlnet._control_image = None
170
+ if (
171
+ not settings.controlnet.adapter_path
172
+ or not path.exists(settings.controlnet.adapter_path)
173
+ or not settings.controlnet._control_image
174
+ ):
175
+ print("Invalid ControlNet settings! Disabling ControlNet")
176
+ settings.controlnet.enabled = False
177
+
178
+ if (
179
+ settings.controlnet.enabled != current_enabled
180
+ or settings.controlnet.adapter_path != current_adapter_path
181
+ ):
182
+ settings.rebuild_pipeline = True
183
+
184
+
185
+ def interactive_lora(
186
+ config,
187
+ context,
188
+ menu_flag=False,
189
+ ):
190
+ """
191
+ @param menu_flag: Indicates whether this function was called from the main
192
+ interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
193
+ """
194
+ if context is None or context.lcm_text_to_image.pipeline is None:
195
+ print("Diffusion pipeline not initialized, please run a generation task first!")
196
+ return
197
+
198
+ print("> 1. Change LoRA weights")
199
+ print("> 2. Load new LoRA model")
200
+ option = user_value(
201
+ int,
202
+ "Enter a LoRA option (1): ",
203
+ 1,
204
+ )
205
+ if option not in range(1, 3):
206
+ print("Wrong LoRA option!")
207
+ return
208
+
209
+ if option == 1:
210
+ update_weights = []
211
+ active_weights = get_active_lora_weights()
212
+ for lora in active_weights:
213
+ weight = user_value(
214
+ float,
215
+ f"Enter a new LoRA weight for {lora[0]} ({lora[1]}): ",
216
+ lora[1],
217
+ )
218
+ update_weights.append(
219
+ (
220
+ lora[0],
221
+ weight,
222
+ )
223
+ )
224
+ if len(update_weights) > 0:
225
+ update_lora_weights(
226
+ context.lcm_text_to_image.pipeline,
227
+ config.lcm_diffusion_setting,
228
+ update_weights,
229
+ )
230
+ elif option == 2:
231
+ # Load a new LoRA
232
+ settings = config.lcm_diffusion_setting
233
+ settings.lora.fuse = False
234
+ settings.lora.enabled = False
235
+ settings.lora.path = input("Enter LoRA model path: ")
236
+ settings.lora.weight = user_value(
237
+ float,
238
+ "Enter a LoRA weight (0.5): ",
239
+ 0.5,
240
+ )
241
+ if not path.exists(settings.lora.path):
242
+ print("Invalid LoRA model path!")
243
+ return
244
+ settings.lora.enabled = True
245
+ load_lora_weight(context.lcm_text_to_image.pipeline, settings)
246
+
247
+ if menu_flag:
248
+ global _edit_lora_settings
249
+ _edit_lora_settings = False
250
+ option = input("Edit LoRA settings after every generation? (y/N): ")
251
+ if option.upper() == "Y":
252
+ _edit_lora_settings = True
253
+
254
+
255
+ def interactive_settings(
256
+ config,
257
+ context,
258
+ ):
259
+ global _batch_count
260
+ settings = config.lcm_diffusion_setting
261
+ print("Enter generation settings (leave empty to use current value)")
262
+ print("> 1. Use LCM")
263
+ print("> 2. Use LCM-Lora")
264
+ print("> 3. Use OpenVINO")
265
+ option = user_value(
266
+ int,
267
+ "Select inference model option (1): ",
268
+ 1,
269
+ )
270
+ if option not in range(1, 4):
271
+ print("Wrong inference model option! Falling back to defaults")
272
+ return
273
+
274
+ settings.use_lcm_lora = False
275
+ settings.use_openvino = False
276
+ if option == 1:
277
+ lcm_model_id = input(f"Enter LCM model ID ({settings.lcm_model_id}): ")
278
+ if lcm_model_id != "":
279
+ settings.lcm_model_id = lcm_model_id
280
+ elif option == 2:
281
+ settings.use_lcm_lora = True
282
+ lcm_lora_id = input(
283
+ f"Enter LCM-Lora model ID ({settings.lcm_lora.lcm_lora_id}): "
284
+ )
285
+ if lcm_lora_id != "":
286
+ settings.lcm_lora.lcm_lora_id = lcm_lora_id
287
+ base_model_id = input(
288
+ f"Enter Base model ID ({settings.lcm_lora.base_model_id}): "
289
+ )
290
+ if base_model_id != "":
291
+ settings.lcm_lora.base_model_id = base_model_id
292
+ elif option == 3:
293
+ settings.use_openvino = True
294
+ openvino_lcm_model_id = input(
295
+ f"Enter OpenVINO model ID ({settings.openvino_lcm_model_id}): "
296
+ )
297
+ if openvino_lcm_model_id != "":
298
+ settings.openvino_lcm_model_id = openvino_lcm_model_id
299
+
300
+ settings.use_offline_model = True
301
+ settings.use_tiny_auto_encoder = True
302
+ option = input("Work offline? (Y/n): ")
303
+ if option.upper() == "N":
304
+ settings.use_offline_model = False
305
+ option = input("Use Tiny Auto Encoder? (Y/n): ")
306
+ if option.upper() == "N":
307
+ settings.use_tiny_auto_encoder = False
308
+
309
+ settings.image_width = user_value(
310
+ int,
311
+ f"Image width ({settings.image_width}): ",
312
+ settings.image_width,
313
+ )
314
+ settings.image_height = user_value(
315
+ int,
316
+ f"Image height ({settings.image_height}): ",
317
+ settings.image_height,
318
+ )
319
+ settings.inference_steps = user_value(
320
+ int,
321
+ f"Inference steps ({settings.inference_steps}): ",
322
+ settings.inference_steps,
323
+ )
324
+ settings.guidance_scale = user_value(
325
+ float,
326
+ f"Guidance scale ({settings.guidance_scale}): ",
327
+ settings.guidance_scale,
328
+ )
329
+ settings.number_of_images = user_value(
330
+ int,
331
+ f"Number of images per batch ({settings.number_of_images}): ",
332
+ settings.number_of_images,
333
+ )
334
+ _batch_count = user_value(
335
+ int,
336
+ f"Batch count ({_batch_count}): ",
337
+ _batch_count,
338
+ )
339
+ # output_format = user_value(int, f"Output format (PNG)", 1)
340
+ print(config.lcm_diffusion_setting)
341
+
342
+
343
+ def interactive_txt2img(
344
+ config,
345
+ context,
346
+ ):
347
+ global _batch_count
348
+ config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
349
+ user_input = input("Write a prompt (write 'exit' to quit): ")
350
+ while True:
351
+ if user_input == "exit":
352
+ return
353
+ elif user_input == "":
354
+ user_input = config.lcm_diffusion_setting.prompt
355
+ config.lcm_diffusion_setting.prompt = user_input
356
+ for i in range(0, _batch_count):
357
+ context.generate_text_to_image(
358
+ settings=config,
359
+ device=DEVICE,
360
+ )
361
+ if _edit_lora_settings:
362
+ interactive_lora(
363
+ config,
364
+ context,
365
+ )
366
+ user_input = input("Write a prompt: ")
367
+
368
+
369
+ def interactive_img2img(
370
+ config,
371
+ context,
372
+ ):
373
+ global _batch_count
374
+ settings = config.lcm_diffusion_setting
375
+ settings.diffusion_task = DiffusionTask.image_to_image.value
376
+ steps = settings.inference_steps
377
+ source_path = input("Image path: ")
378
+ if source_path == "":
379
+ print("Error: You need to provide a file in img2img mode")
380
+ return
381
+ settings.strength = user_value(
382
+ float,
383
+ f"img2img strength ({settings.strength}): ",
384
+ settings.strength,
385
+ )
386
+ settings.inference_steps = int(steps / settings.strength + 1)
387
+ user_input = input("Write a prompt (write 'exit' to quit): ")
388
+ while True:
389
+ if user_input == "exit":
390
+ settings.inference_steps = steps
391
+ return
392
+ settings.init_image = Image.open(source_path)
393
+ settings.prompt = user_input
394
+ for i in range(0, _batch_count):
395
+ context.generate_text_to_image(
396
+ settings=config,
397
+ device=DEVICE,
398
+ )
399
+ new_path = input(f"Image path ({source_path}): ")
400
+ if new_path != "":
401
+ source_path = new_path
402
+ settings.strength = user_value(
403
+ float,
404
+ f"img2img strength ({settings.strength}): ",
405
+ settings.strength,
406
+ )
407
+ if _edit_lora_settings:
408
+ interactive_lora(
409
+ config,
410
+ context,
411
+ )
412
+ settings.inference_steps = int(steps / settings.strength + 1)
413
+ user_input = input("Write a prompt: ")
414
+
415
+
416
+ def interactive_variations(
417
+ config,
418
+ context,
419
+ ):
420
+ global _batch_count
421
+ settings = config.lcm_diffusion_setting
422
+ settings.diffusion_task = DiffusionTask.image_to_image.value
423
+ steps = settings.inference_steps
424
+ source_path = input("Image path: ")
425
+ if source_path == "":
426
+ print("Error: You need to provide a file in Image variations mode")
427
+ return
428
+ settings.strength = user_value(
429
+ float,
430
+ f"Image variations strength ({settings.strength}): ",
431
+ settings.strength,
432
+ )
433
+ settings.inference_steps = int(steps / settings.strength + 1)
434
+ while True:
435
+ settings.init_image = Image.open(source_path)
436
+ settings.prompt = ""
437
+ for i in range(0, _batch_count):
438
+ generate_image_variations(
439
+ settings.init_image,
440
+ settings.strength,
441
+ )
442
+ if _edit_lora_settings:
443
+ interactive_lora(
444
+ config,
445
+ context,
446
+ )
447
+ user_input = input("Continue in Image variations mode? (Y/n): ")
448
+ if user_input.upper() == "N":
449
+ settings.inference_steps = steps
450
+ return
451
+ new_path = input(f"Image path ({source_path}): ")
452
+ if new_path != "":
453
+ source_path = new_path
454
+ settings.strength = user_value(
455
+ float,
456
+ f"Image variations strength ({settings.strength}): ",
457
+ settings.strength,
458
+ )
459
+ settings.inference_steps = int(steps / settings.strength + 1)
460
+
461
+
462
+ def interactive_edsr(
463
+ config,
464
+ context,
465
+ ):
466
+ source_path = input("Image path: ")
467
+ if source_path == "":
468
+ print("Error: You need to provide a file in EDSR mode")
469
+ return
470
+ while True:
471
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
472
+ source_path,
473
+ 2,
474
+ config.generated_images.format,
475
+ )
476
+ result = upscale_image(
477
+ context,
478
+ source_path,
479
+ output_path,
480
+ 2,
481
+ )
482
+ user_input = input("Continue in EDSR upscale mode? (Y/n): ")
483
+ if user_input.upper() == "N":
484
+ return
485
+ new_path = input(f"Image path ({source_path}): ")
486
+ if new_path != "":
487
+ source_path = new_path
488
+
489
+
490
+ def interactive_sdupscale_settings(config):
491
+ steps = config.lcm_diffusion_setting.inference_steps
492
+ custom_settings = {}
493
+ print("> 1. Upscale whole image")
494
+ print("> 2. Define custom tiles (advanced)")
495
+ option = user_value(
496
+ int,
497
+ "Select an SD Upscale option (1): ",
498
+ 1,
499
+ )
500
+ if option not in range(1, 3):
501
+ print("Wrong SD Upscale option!")
502
+ return
503
+
504
+ # custom_settings["source_file"] = args.file
505
+ custom_settings["source_file"] = ""
506
+ new_path = input(f"Input image path ({custom_settings['source_file']}): ")
507
+ if new_path != "":
508
+ custom_settings["source_file"] = new_path
509
+ if custom_settings["source_file"] == "":
510
+ print("Error: You need to provide a file in SD Upscale mode")
511
+ return
512
+ custom_settings["target_file"] = None
513
+ if option == 2:
514
+ custom_settings["target_file"] = input("Image to patch: ")
515
+ if custom_settings["target_file"] == "":
516
+ print("No target file provided, upscaling whole input image instead!")
517
+ custom_settings["target_file"] = None
518
+ option = 1
519
+ custom_settings["output_format"] = config.generated_images.format
520
+ custom_settings["strength"] = user_value(
521
+ float,
522
+ f"SD Upscale strength ({config.lcm_diffusion_setting.strength}): ",
523
+ config.lcm_diffusion_setting.strength,
524
+ )
525
+ config.lcm_diffusion_setting.inference_steps = int(
526
+ steps / custom_settings["strength"] + 1
527
+ )
528
+ if option == 1:
529
+ custom_settings["scale_factor"] = user_value(
530
+ float,
531
+ f"Scale factor (2.0): ",
532
+ 2.0,
533
+ )
534
+ custom_settings["tile_size"] = user_value(
535
+ int,
536
+ f"Split input image into tiles of the following size, in pixels (256): ",
537
+ 256,
538
+ )
539
+ custom_settings["tile_overlap"] = user_value(
540
+ int,
541
+ f"Tile overlap, in pixels (16): ",
542
+ 16,
543
+ )
544
+ elif option == 2:
545
+ custom_settings["scale_factor"] = user_value(
546
+ float,
547
+ "Input image to Image-to-patch scale_factor (2.0): ",
548
+ 2.0,
549
+ )
550
+ custom_settings["tile_size"] = 256
551
+ custom_settings["tile_overlap"] = 16
552
+ custom_settings["prompt"] = input(
553
+ "Write a prompt describing the input image (optional): "
554
+ )
555
+ custom_settings["tiles"] = []
556
+ if option == 2:
557
+ add_tile = True
558
+ while add_tile:
559
+ print("=== Define custom SD Upscale tile ===")
560
+ tile_x = user_value(
561
+ int,
562
+ "Enter tile's X position: ",
563
+ 0,
564
+ )
565
+ tile_y = user_value(
566
+ int,
567
+ "Enter tile's Y position: ",
568
+ 0,
569
+ )
570
+ tile_w = user_value(
571
+ int,
572
+ "Enter tile's width (256): ",
573
+ 256,
574
+ )
575
+ tile_h = user_value(
576
+ int,
577
+ "Enter tile's height (256): ",
578
+ 256,
579
+ )
580
+ tile_scale = user_value(
581
+ float,
582
+ "Enter tile's scale factor (2.0): ",
583
+ 2.0,
584
+ )
585
+ tile_prompt = input("Enter tile's prompt (optional): ")
586
+ custom_settings["tiles"].append(
587
+ {
588
+ "x": tile_x,
589
+ "y": tile_y,
590
+ "w": tile_w,
591
+ "h": tile_h,
592
+ "mask_box": None,
593
+ "prompt": tile_prompt,
594
+ "scale_factor": tile_scale,
595
+ }
596
+ )
597
+ tile_option = input("Do you want to define another tile? (y/N): ")
598
+ if tile_option == "" or tile_option.upper() == "N":
599
+ add_tile = False
600
+
601
+ return custom_settings
602
+
603
+
604
+ def interactive_sdupscale(
605
+ config,
606
+ context,
607
+ ):
608
+ settings = config.lcm_diffusion_setting
609
+ settings.diffusion_task = DiffusionTask.image_to_image.value
610
+ settings.init_image = ""
611
+ source_path = ""
612
+ steps = settings.inference_steps
613
+
614
+ while True:
615
+ custom_upscale_settings = None
616
+ option = input("Edit custom SD Upscale settings? (y/N): ")
617
+ if option.upper() == "Y":
618
+ config.lcm_diffusion_setting.inference_steps = steps
619
+ custom_upscale_settings = interactive_sdupscale_settings(config)
620
+ if not custom_upscale_settings:
621
+ return
622
+ source_path = custom_upscale_settings["source_file"]
623
+ else:
624
+ new_path = input(f"Image path ({source_path}): ")
625
+ if new_path != "":
626
+ source_path = new_path
627
+ if source_path == "":
628
+ print("Error: You need to provide a file in SD Upscale mode")
629
+ return
630
+ settings.strength = user_value(
631
+ float,
632
+ f"SD Upscale strength ({settings.strength}): ",
633
+ settings.strength,
634
+ )
635
+ settings.inference_steps = int(steps / settings.strength + 1)
636
+
637
+ output_path = FastStableDiffusionPaths.get_upscale_filepath(
638
+ source_path,
639
+ 2,
640
+ config.generated_images.format,
641
+ )
642
+ generate_upscaled_image(
643
+ config,
644
+ source_path,
645
+ settings.strength,
646
+ upscale_settings=custom_upscale_settings,
647
+ context=context,
648
+ tile_overlap=32 if settings.use_openvino else 16,
649
+ output_path=output_path,
650
+ image_format=config.generated_images.format,
651
+ )
652
+ user_input = input("Continue in SD Upscale mode? (Y/n): ")
653
+ if user_input.upper() == "N":
654
+ settings.inference_steps = steps
655
+ return
frontend/gui/app_window.py ADDED
@@ -0,0 +1,612 @@
1
+ from PyQt5.QtWidgets import (
2
+ QWidget,
3
+ QPushButton,
4
+ QHBoxLayout,
5
+ QVBoxLayout,
6
+ QLabel,
7
+ QLineEdit,
8
+ QMainWindow,
9
+ QSlider,
10
+ QTabWidget,
11
+ QSpacerItem,
12
+ QSizePolicy,
13
+ QComboBox,
14
+ QCheckBox,
15
+ QTextEdit,
16
+ QToolButton,
17
+ QFileDialog,
18
+ )
19
+ from PyQt5 import QtWidgets, QtCore
20
+ from PyQt5.QtGui import QPixmap, QDesktopServices
21
+ from PyQt5.QtCore import QSize, QThreadPool, Qt, QUrl
22
+
23
+ from PIL.ImageQt import ImageQt
24
+ from constants import (
25
+ LCM_DEFAULT_MODEL,
26
+ LCM_DEFAULT_MODEL_OPENVINO,
27
+ APP_NAME,
28
+ APP_VERSION,
29
+ )
30
+ from frontend.gui.image_generator_worker import ImageGeneratorWorker
31
+ from app_settings import AppSettings
32
+ from paths import FastStableDiffusionPaths
33
+ from frontend.utils import is_reshape_required
34
+ from context import Context
35
+ from models.interface_types import InterfaceType
36
+ from constants import DEVICE
37
+ from frontend.utils import enable_openvino_controls, get_valid_model_id
38
+ from backend.models.lcmdiffusion_setting import DiffusionTask
39
+
40
+ # DPI scale fix
41
+ QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
42
+ QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)
43
+
44
+
45
+ class MainWindow(QMainWindow):
46
+ def __init__(self, config: AppSettings):
47
+ super().__init__()
48
+ self.config = config
49
+ # Prevent saved LoRA and ControlNet settings from being used by
50
+ # default; in GUI mode, the user must explicitly enable those
51
+ if self.config.settings.lcm_diffusion_setting.lora:
52
+ self.config.settings.lcm_diffusion_setting.lora.enabled = False
53
+ if self.config.settings.lcm_diffusion_setting.controlnet:
54
+ self.config.settings.lcm_diffusion_setting.controlnet.enabled = False
55
+ self.setWindowTitle(APP_NAME)
56
+ self.setFixedSize(QSize(600, 670))
57
+ self.init_ui()
58
+ self.pipeline = None
59
+ self.threadpool = QThreadPool()
60
+ self.device = "cpu"
61
+ self.previous_width = 0
62
+ self.previous_height = 0
63
+ self.previous_model = ""
64
+ self.previous_num_of_images = 0
65
+ self.context = Context(InterfaceType.GUI)
66
+ self.init_ui_values()
67
+ self.gen_images = []
68
+ self.image_index = 0
69
+ print(f"Output path: {self.config.settings.generated_images.path}")
70
+
71
+ def init_ui_values(self):
72
+ self.lcm_model.setEnabled(
73
+ not self.config.settings.lcm_diffusion_setting.use_openvino
74
+ )
75
+ self.guidance.setValue(
76
+ int(self.config.settings.lcm_diffusion_setting.guidance_scale * 10)
77
+ )
78
+ self.seed_value.setEnabled(self.config.settings.lcm_diffusion_setting.use_seed)
79
+ self.safety_checker.setChecked(
80
+ self.config.settings.lcm_diffusion_setting.use_safety_checker
81
+ )
82
+ self.use_openvino_check.setChecked(
83
+ self.config.settings.lcm_diffusion_setting.use_openvino
84
+ )
85
+ self.width.setCurrentText(
86
+ str(self.config.settings.lcm_diffusion_setting.image_width)
87
+ )
88
+ self.height.setCurrentText(
89
+ str(self.config.settings.lcm_diffusion_setting.image_height)
90
+ )
91
+ self.inference_steps.setValue(
92
+ int(self.config.settings.lcm_diffusion_setting.inference_steps)
93
+ )
94
+ self.seed_check.setChecked(self.config.settings.lcm_diffusion_setting.use_seed)
95
+ self.seed_value.setText(str(self.config.settings.lcm_diffusion_setting.seed))
96
+ self.use_local_model_folder.setChecked(
97
+ self.config.settings.lcm_diffusion_setting.use_offline_model
98
+ )
99
+ self.results_path.setText(self.config.settings.generated_images.path)
100
+ self.num_images.setValue(
101
+ self.config.settings.lcm_diffusion_setting.number_of_images
102
+ )
103
+ self.use_tae_sd.setChecked(
104
+ self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder
105
+ )
106
+ self.use_lcm_lora.setChecked(
107
+ self.config.settings.lcm_diffusion_setting.use_lcm_lora
108
+ )
109
+ self.lcm_model.setCurrentText(
110
+ get_valid_model_id(
111
+ self.config.lcm_models,
112
+ self.config.settings.lcm_diffusion_setting.lcm_model_id,
113
+ LCM_DEFAULT_MODEL,
114
+ )
115
+ )
116
+ self.base_model_id.setCurrentText(
117
+ get_valid_model_id(
118
+ self.config.stable_diffsuion_models,
119
+ self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
120
+ )
121
+ )
122
+ self.lcm_lora_id.setCurrentText(
123
+ get_valid_model_id(
124
+ self.config.lcm_lora_models,
125
+ self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
126
+ )
127
+ )
128
+ self.openvino_lcm_model_id.setCurrentText(
129
+ get_valid_model_id(
130
+ self.config.openvino_lcm_models,
131
+ self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id,
132
+ LCM_DEFAULT_MODEL_OPENVINO,
133
+ )
134
+ )
135
+ self.neg_prompt.setEnabled(
136
+ self.config.settings.lcm_diffusion_setting.use_lcm_lora
137
+ or self.config.settings.lcm_diffusion_setting.use_openvino
138
+ )
139
+ self.openvino_lcm_model_id.setEnabled(
140
+ self.config.settings.lcm_diffusion_setting.use_openvino
141
+ )
142
+
143
+ def init_ui(self):
144
+ self.create_main_tab()
145
+ self.create_settings_tab()
146
+ self.create_about_tab()
147
+ self.show()
148
+
149
+ def create_main_tab(self):
150
+ self.img = QLabel("<<Image>>")
151
+ self.img.setAlignment(Qt.AlignCenter)
152
+ self.img.setFixedSize(QSize(512, 512))
153
+ self.vspacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
154
+
155
+ self.prompt = QTextEdit()
156
+ self.prompt.setPlaceholderText("A fantasy landscape")
157
+ self.prompt.setAcceptRichText(False)
158
+ self.neg_prompt = QTextEdit()
159
+ self.neg_prompt.setPlaceholderText("")
160
+ self.neg_prompt.setAcceptRichText(False)
161
+ self.neg_prompt_label = QLabel("Negative prompt (Set guidance scale > 1.0):")
162
+ self.generate = QPushButton("Generate")
163
+ self.generate.clicked.connect(self.text_to_image)
164
+ self.prompt.setFixedHeight(40)
165
+ self.neg_prompt.setFixedHeight(35)
166
+ self.browse_results = QPushButton("...")
167
+ self.browse_results.setFixedWidth(30)
168
+ self.browse_results.clicked.connect(self.on_open_results_folder)
169
+ self.browse_results.setToolTip("Open output folder")
170
+
171
+ hlayout = QHBoxLayout()
172
+ hlayout.addWidget(self.neg_prompt)
173
+ hlayout.addWidget(self.generate)
174
+ hlayout.addWidget(self.browse_results)
175
+
176
+ self.previous_img_btn = QToolButton()
177
+ self.previous_img_btn.setText("<")
178
+ self.previous_img_btn.clicked.connect(self.on_show_previous_image)
179
+ self.next_img_btn = QToolButton()
180
+ self.next_img_btn.setText(">")
181
+ self.next_img_btn.clicked.connect(self.on_show_next_image)
182
+ hlayout_nav = QHBoxLayout()
183
+ hlayout_nav.addWidget(self.previous_img_btn)
184
+ hlayout_nav.addWidget(self.img)
185
+ hlayout_nav.addWidget(self.next_img_btn)
186
+
187
+ vlayout = QVBoxLayout()
188
+ vlayout.addLayout(hlayout_nav)
189
+ vlayout.addItem(self.vspacer)
190
+ vlayout.addWidget(self.prompt)
191
+ vlayout.addWidget(self.neg_prompt_label)
192
+ vlayout.addLayout(hlayout)
193
+
194
+ self.tab_widget = QTabWidget(self)
195
+ self.tab_main = QWidget()
196
+ self.tab_settings = QWidget()
197
+ self.tab_about = QWidget()
198
+ self.tab_main.setLayout(vlayout)
199
+
200
+ self.tab_widget.addTab(self.tab_main, "Text to Image")
201
+ self.tab_widget.addTab(self.tab_settings, "Settings")
202
+ self.tab_widget.addTab(self.tab_about, "About")
203
+
204
+ self.setCentralWidget(self.tab_widget)
205
+ self.use_seed = False
206
+
207
+ def create_settings_tab(self):
208
+ self.lcm_model_label = QLabel("Latent Consistency Model:")
209
+ # self.lcm_model = QLineEdit(LCM_DEFAULT_MODEL)
210
+ self.lcm_model = QComboBox(self)
211
+ self.lcm_model.addItems(self.config.lcm_models)
212
+ self.lcm_model.currentIndexChanged.connect(self.on_lcm_model_changed)
213
+
214
+ self.use_lcm_lora = QCheckBox("Use LCM LoRA")
215
+ self.use_lcm_lora.setChecked(False)
216
+ self.use_lcm_lora.stateChanged.connect(self.use_lcm_lora_changed)
217
+
218
+ self.lora_base_model_id_label = QLabel("Lora base model ID :")
219
+ self.base_model_id = QComboBox(self)
220
+ self.base_model_id.addItems(self.config.stable_diffsuion_models)
221
+ self.base_model_id.currentIndexChanged.connect(self.on_base_model_id_changed)
222
+
223
+ self.lcm_lora_model_id_label = QLabel("LCM LoRA model ID :")
224
+ self.lcm_lora_id = QComboBox(self)
225
+ self.lcm_lora_id.addItems(self.config.lcm_lora_models)
226
+ self.lcm_lora_id.currentIndexChanged.connect(self.on_lcm_lora_id_changed)
227
+
228
+ self.inference_steps_value = QLabel("Number of inference steps: 4")
229
+ self.inference_steps = QSlider(orientation=Qt.Orientation.Horizontal)
230
+ self.inference_steps.setMaximum(25)
231
+ self.inference_steps.setMinimum(1)
232
+ self.inference_steps.setValue(4)
233
+ self.inference_steps.valueChanged.connect(self.update_steps_label)
234
+
235
+ self.num_images_value = QLabel("Number of images: 1")
236
+ self.num_images = QSlider(orientation=Qt.Orientation.Horizontal)
237
+ self.num_images.setMaximum(100)
238
+ self.num_images.setMinimum(1)
239
+ self.num_images.setValue(1)
240
+ self.num_images.valueChanged.connect(self.update_num_images_label)
241
+
242
+ self.guidance_value = QLabel("Guidance scale: 1")
243
+ self.guidance = QSlider(orientation=Qt.Orientation.Horizontal)
244
+ self.guidance.setMaximum(20)
245
+ self.guidance.setMinimum(10)
246
+ self.guidance.setValue(10)
247
+ self.guidance.valueChanged.connect(self.update_guidance_label)
248
+
249
+ self.width_value = QLabel("Width :")
250
+ self.width = QComboBox(self)
251
+ self.width.addItem("256")
252
+ self.width.addItem("512")
253
+ self.width.addItem("768")
254
+ self.width.addItem("1024")
255
+ self.width.setCurrentText("512")
256
+ self.width.currentIndexChanged.connect(self.on_width_changed)
257
+
258
+ self.height_value = QLabel("Height :")
259
+ self.height = QComboBox(self)
260
+ self.height.addItem("256")
261
+ self.height.addItem("512")
262
+ self.height.addItem("768")
263
+ self.height.addItem("1024")
264
+ self.height.setCurrentText("512")
265
+ self.height.currentIndexChanged.connect(self.on_height_changed)
266
+
267
+ self.seed_check = QCheckBox("Use seed")
268
+ self.seed_value = QLineEdit()
269
+ self.seed_value.setInputMask("9999999999")
270
+ self.seed_value.setText("123123")
271
+ self.seed_check.stateChanged.connect(self.seed_changed)
272
+
273
+ self.safety_checker = QCheckBox("Use safety checker")
274
+ self.safety_checker.setChecked(True)
275
+ self.safety_checker.stateChanged.connect(self.use_safety_checker_changed)
276
+
277
+ self.use_openvino_check = QCheckBox("Use OpenVINO")
278
+ self.use_openvino_check.setChecked(False)
279
+ self.openvino_model_label = QLabel("OpenVINO LCM model:")
280
+ self.use_local_model_folder = QCheckBox(
281
+ "Use locally cached model or downloaded model folder (offline)"
282
+ )
283
+ self.openvino_lcm_model_id = QComboBox(self)
284
+ self.openvino_lcm_model_id.addItems(self.config.openvino_lcm_models)
285
+ self.openvino_lcm_model_id.currentIndexChanged.connect(
286
+ self.on_openvino_lcm_model_id_changed
287
+ )
288
+
289
+ self.use_openvino_check.setEnabled(enable_openvino_controls())
290
+ self.use_local_model_folder.setChecked(False)
291
+ self.use_local_model_folder.stateChanged.connect(self.use_offline_model_changed)
292
+ self.use_openvino_check.stateChanged.connect(self.use_openvino_changed)
293
+
294
+ self.use_tae_sd = QCheckBox(
295
+ "Use Tiny Auto Encoder - TAESD (Fast, moderate quality)"
296
+ )
297
+ self.use_tae_sd.setChecked(False)
298
+ self.use_tae_sd.stateChanged.connect(self.use_tae_sd_changed)
299
+
300
+ hlayout = QHBoxLayout()
301
+ hlayout.addWidget(self.seed_check)
302
+ hlayout.addWidget(self.seed_value)
303
+ hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
304
+ slider_hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
305
+
306
+ self.results_path_label = QLabel("Output path:")
307
+ self.results_path = QLineEdit()
308
+ self.results_path.textChanged.connect(self.on_path_changed)
309
+ self.browse_folder_btn = QToolButton()
310
+ self.browse_folder_btn.setText("...")
311
+ self.browse_folder_btn.clicked.connect(self.on_browse_folder)
312
+
313
+ self.reset = QPushButton("Reset All")
314
+ self.reset.clicked.connect(self.reset_all_settings)
315
+
316
+ vlayout = QVBoxLayout()
317
+ vspacer = QSpacerItem(20, 20, QSizePolicy.Minimum, QSizePolicy.Expanding)
318
+ vlayout.addItem(hspacer)
319
+ vlayout.setSpacing(3)
320
+ vlayout.addWidget(self.lcm_model_label)
321
+ vlayout.addWidget(self.lcm_model)
322
+ vlayout.addWidget(self.use_local_model_folder)
323
+ vlayout.addWidget(self.use_lcm_lora)
324
+ vlayout.addWidget(self.lora_base_model_id_label)
325
+ vlayout.addWidget(self.base_model_id)
326
+ vlayout.addWidget(self.lcm_lora_model_id_label)
327
+ vlayout.addWidget(self.lcm_lora_id)
328
+ vlayout.addWidget(self.use_openvino_check)
329
+ vlayout.addWidget(self.openvino_model_label)
330
+ vlayout.addWidget(self.openvino_lcm_model_id)
331
+ vlayout.addWidget(self.use_tae_sd)
332
+ vlayout.addItem(slider_hspacer)
333
+ vlayout.addWidget(self.inference_steps_value)
334
+ vlayout.addWidget(self.inference_steps)
335
+ vlayout.addWidget(self.num_images_value)
336
+ vlayout.addWidget(self.num_images)
337
+ vlayout.addWidget(self.width_value)
338
+ vlayout.addWidget(self.width)
339
+ vlayout.addWidget(self.height_value)
340
+ vlayout.addWidget(self.height)
341
+ vlayout.addWidget(self.guidance_value)
342
+ vlayout.addWidget(self.guidance)
343
+ vlayout.addLayout(hlayout)
344
+ vlayout.addWidget(self.safety_checker)
345
+
346
+ vlayout.addWidget(self.results_path_label)
347
+ hlayout_path = QHBoxLayout()
348
+ hlayout_path.addWidget(self.results_path)
349
+ hlayout_path.addWidget(self.browse_folder_btn)
350
+ vlayout.addLayout(hlayout_path)
351
+ self.tab_settings.setLayout(vlayout)
352
+ hlayout_reset = QHBoxLayout()
353
+ hspacer = QSpacerItem(20, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
354
+ hlayout_reset.addItem(hspacer)
355
+ hlayout_reset.addWidget(self.reset)
356
+ vlayout.addLayout(hlayout_reset)
357
+ vlayout.addItem(vspacer)
358
+
359
+ def create_about_tab(self):
360
+ self.label = QLabel()
361
+ self.label.setAlignment(Qt.AlignCenter)
362
+ self.label.setText(
363
+ f"""<h1>FastSD CPU {APP_VERSION}</h1>
364
+ <h3>(c)2023 - 2024 Rupesh Sreeraman</h3>
365
+ <h3>Faster stable diffusion on CPU</h3>
366
+ <h3>Based on Latent Consistency Models</h3>
367
+ <h3>GitHub : https://github.com/rupeshs/fastsdcpu/</h3>"""
368
+ )
369
+
370
+ vlayout = QVBoxLayout()
371
+ vlayout.addWidget(self.label)
372
+ self.tab_about.setLayout(vlayout)
373
+
374
+ def show_image(self, pixmap):
375
+ image_width = self.config.settings.lcm_diffusion_setting.image_width
376
+ image_height = self.config.settings.lcm_diffusion_setting.image_height
377
+ if image_width > 512 or image_height > 512:
378
+ new_width = 512 if image_width > 512 else image_width
379
+ new_height = 512 if image_height > 512 else image_height
380
+ self.img.setPixmap(
381
+ pixmap.scaled(
382
+ new_width,
383
+ new_height,
384
+ Qt.KeepAspectRatio,
385
+ )
386
+ )
387
+ else:
388
+ self.img.setPixmap(pixmap)
389
+
390
+ def on_show_next_image(self):
391
+ if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
392
+ self.previous_img_btn.setEnabled(True)
393
+ self.image_index += 1
394
+ self.show_image(self.gen_images[self.image_index])
395
+ if self.image_index == len(self.gen_images) - 1:
396
+ self.next_img_btn.setEnabled(False)
397
+
398
+ def on_open_results_folder(self):
399
+ QDesktopServices.openUrl(
400
+ QUrl.fromLocalFile(self.config.settings.generated_images.path)
401
+ )
402
+
403
+ def on_show_previous_image(self):
404
+ if self.image_index != 0:
405
+ self.next_img_btn.setEnabled(True)
406
+ self.image_index -= 1
407
+ self.show_image(self.gen_images[self.image_index])
408
+ if self.image_index == 0:
409
+ self.previous_img_btn.setEnabled(False)
410
+
411
+ def on_path_changed(self, text):
412
+ self.config.settings.generated_images.path = text
413
+
414
+ def on_browse_folder(self):
415
+ options = QFileDialog.Options()
416
+ options |= QFileDialog.ShowDirsOnly
417
+
418
+ folder_path = QFileDialog.getExistingDirectory(
419
+ self, "Select a Folder", "", options=options
420
+ )
421
+
422
+ if folder_path:
423
+ self.config.settings.generated_images.path = folder_path
424
+ self.results_path.setText(folder_path)
425
+
426
+ def on_width_changed(self, index):
427
+ width_txt = self.width.itemText(index)
428
+ self.config.settings.lcm_diffusion_setting.image_width = int(width_txt)
429
+
430
+ def on_height_changed(self, index):
431
+ height_txt = self.height.itemText(index)
432
+ self.config.settings.lcm_diffusion_setting.image_height = int(height_txt)
433
+
434
+ def on_lcm_model_changed(self, index):
435
+ model_id = self.lcm_model.itemText(index)
436
+ self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id
437
+
438
+ def on_base_model_id_changed(self, index):
439
+ model_id = self.base_model_id.itemText(index)
440
+ self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id
441
+
442
+ def on_lcm_lora_id_changed(self, index):
443
+ model_id = self.lcm_lora_id.itemText(index)
444
+ self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id
445
+
446
+ def on_openvino_lcm_model_id_changed(self, index):
447
+ model_id = self.openvino_lcm_model_id.itemText(index)
448
+ self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
449
+
450
+ def use_openvino_changed(self, state):
451
+ if state == 2:
452
+ self.lcm_model.setEnabled(False)
453
+ self.use_lcm_lora.setEnabled(False)
454
+ self.lcm_lora_id.setEnabled(False)
455
+ self.base_model_id.setEnabled(False)
456
+ self.neg_prompt.setEnabled(True)
457
+ self.openvino_lcm_model_id.setEnabled(True)
458
+ self.config.settings.lcm_diffusion_setting.use_openvino = True
459
+ else:
460
+ self.lcm_model.setEnabled(True)
461
+ self.use_lcm_lora.setEnabled(True)
462
+ self.lcm_lora_id.setEnabled(True)
463
+ self.base_model_id.setEnabled(True)
464
+ self.neg_prompt.setEnabled(False)
465
+ self.openvino_lcm_model_id.setEnabled(False)
466
+ self.config.settings.lcm_diffusion_setting.use_openvino = False
467
+
468
+ def use_tae_sd_changed(self, state):
469
+ if state == 2:
470
+ self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = True
471
+ else:
472
+ self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = False
473
+
474
+ def use_offline_model_changed(self, state):
475
+ if state == 2:
476
+ self.config.settings.lcm_diffusion_setting.use_offline_model = True
477
+ else:
478
+ self.config.settings.lcm_diffusion_setting.use_offline_model = False
479
+
480
+ def use_lcm_lora_changed(self, state):
481
+ if state == 2:
482
+ self.lcm_model.setEnabled(False)
483
+ self.lcm_lora_id.setEnabled(True)
484
+ self.base_model_id.setEnabled(True)
485
+ self.neg_prompt.setEnabled(True)
486
+ self.config.settings.lcm_diffusion_setting.use_lcm_lora = True
487
+ else:
488
+ self.lcm_model.setEnabled(True)
489
+ self.lcm_lora_id.setEnabled(False)
490
+ self.base_model_id.setEnabled(False)
491
+ self.neg_prompt.setEnabled(False)
492
+ self.config.settings.lcm_diffusion_setting.use_lcm_lora = False
493
+
494
+ def use_safety_checker_changed(self, state):
495
+ if state == 2:
496
+ self.config.settings.lcm_diffusion_setting.use_safety_checker = True
497
+ else:
498
+ self.config.settings.lcm_diffusion_setting.use_safety_checker = False
499
+
500
+ def update_steps_label(self, value):
501
+ self.inference_steps_value.setText(f"Number of inference steps: {value}")
502
+ self.config.settings.lcm_diffusion_setting.inference_steps = value
503
+
504
+ def update_num_images_label(self, value):
505
+ self.num_images_value.setText(f"Number of images: {value}")
506
+ self.config.settings.lcm_diffusion_setting.number_of_images = value
507
+
508
+ def update_guidance_label(self, value):
509
+ val = round(int(value) / 10, 1)
510
+ self.guidance_value.setText(f"Guidance scale: {val}")
511
+ self.config.settings.lcm_diffusion_setting.guidance_scale = val
512
+
513
+ def seed_changed(self, state):
514
+ if state == 2:
515
+ self.seed_value.setEnabled(True)
516
+ self.config.settings.lcm_diffusion_setting.use_seed = True
517
+ else:
518
+ self.seed_value.setEnabled(False)
519
+ self.config.settings.lcm_diffusion_setting.use_seed = False
520
+
521
+ def get_seed_value(self) -> int:
522
+ use_seed = self.config.settings.lcm_diffusion_setting.use_seed
523
+ seed_value = int(self.seed_value.text()) if use_seed else -1
524
+ return seed_value
525
+
526
+ def generate_image(self):
527
+ self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
528
+ self.config.settings.lcm_diffusion_setting.prompt = self.prompt.toPlainText()
529
+ self.config.settings.lcm_diffusion_setting.negative_prompt = (
530
+ self.neg_prompt.toPlainText()
531
+ )
532
+ self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = (
533
+ self.lcm_lora_id.currentText()
534
+ )
535
+ self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = (
536
+ self.base_model_id.currentText()
537
+ )
538
+
539
+ if self.config.settings.lcm_diffusion_setting.use_openvino:
540
+ model_id = self.openvino_lcm_model_id.currentText()
541
+ self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
542
+ else:
543
+ model_id = self.lcm_model.currentText()
544
+ self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id
545
+
546
+ reshape_required = False
547
+ if self.config.settings.lcm_diffusion_setting.use_openvino:
548
+ # Detect dimension change
549
+ reshape_required = is_reshape_required(
550
+ self.previous_width,
551
+ self.config.settings.lcm_diffusion_setting.image_width,
552
+ self.previous_height,
553
+ self.config.settings.lcm_diffusion_setting.image_height,
554
+ self.previous_model,
555
+ model_id,
556
+ self.previous_num_of_images,
557
+ self.config.settings.lcm_diffusion_setting.number_of_images,
558
+ )
559
+ self.config.settings.lcm_diffusion_setting.diffusion_task = (
560
+ DiffusionTask.text_to_image.value
561
+ )
562
+ images = self.context.generate_text_to_image(
563
+ self.config.settings,
564
+ reshape_required,
565
+ DEVICE,
566
+ )
567
+ self.image_index = 0
568
+ self.gen_images = []
569
+ for img in images:
570
+ im = ImageQt(img).copy()
571
+ pixmap = QPixmap.fromImage(im)
572
+ self.gen_images.append(pixmap)
573
+
574
+ if len(self.gen_images) > 1:
575
+ self.next_img_btn.setEnabled(True)
576
+ self.previous_img_btn.setEnabled(False)
577
+ else:
578
+ self.next_img_btn.setEnabled(False)
579
+ self.previous_img_btn.setEnabled(False)
580
+
581
+ self.show_image(self.gen_images[0])
582
+
583
+ self.previous_width = self.config.settings.lcm_diffusion_setting.image_width
584
+ self.previous_height = self.config.settings.lcm_diffusion_setting.image_height
585
+ self.previous_model = model_id
586
+ self.previous_num_of_images = (
587
+ self.config.settings.lcm_diffusion_setting.number_of_images
588
+ )
589
+
590
+ def text_to_image(self):
591
+ self.img.setText("Please wait...")
592
+ worker = ImageGeneratorWorker(self.generate_image)
593
+ self.threadpool.start(worker)
594
+
595
+ def closeEvent(self, event):
596
+ self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
597
+ print(self.config.settings.lcm_diffusion_setting)
598
+ print("Saving settings")
599
+ self.config.save()
600
+
601
+ def reset_all_settings(self):
602
+ self.use_local_model_folder.setChecked(False)
603
+ self.width.setCurrentText("512")
604
+ self.height.setCurrentText("512")
605
+ self.inference_steps.setValue(4)
606
+ self.guidance.setValue(10)
607
+ self.use_openvino_check.setChecked(False)
608
+ self.seed_check.setChecked(False)
609
+ self.safety_checker.setChecked(False)
610
+ self.results_path.setText(FastStableDiffusionPaths().get_results_path())
611
+ self.use_tae_sd.setChecked(False)
612
+ self.use_lcm_lora.setChecked(False)
frontend/gui/image_generator_worker.py ADDED
@@ -0,0 +1,37 @@
1
+ from PyQt5.QtCore import (
2
+ pyqtSlot,
3
+ QRunnable,
4
+ pyqtSignal,
5
+ QObject,
6
+ )
7
+
8
+ import traceback
9
+ import sys
10
+
11
+
12
+ class WorkerSignals(QObject):
13
+ finished = pyqtSignal()
14
+ error = pyqtSignal(tuple)
15
+ result = pyqtSignal(object)
16
+
17
+
18
+ class ImageGeneratorWorker(QRunnable):
19
+ def __init__(self, fn, *args, **kwargs):
20
+ super(ImageGeneratorWorker, self).__init__()
21
+ self.fn = fn
22
+ self.args = args
23
+ self.kwargs = kwargs
24
+ self.signals = WorkerSignals()
25
+
26
+ @pyqtSlot()
27
+ def run(self):
28
+ try:
29
+ result = self.fn(*self.args, **self.kwargs)
30
+ except:
31
+ traceback.print_exc()
32
+ exctype, value = sys.exc_info()[:2]
33
+ self.signals.error.emit((exctype, value, traceback.format_exc()))
34
+ else:
35
+ self.signals.result.emit(result)
36
+ finally:
37
+ self.signals.finished.emit()
frontend/gui/ui.py ADDED
@@ -0,0 +1,15 @@
1
+ from typing import List
2
+ from frontend.gui.app_window import MainWindow
3
+ from PyQt5.QtWidgets import QApplication
4
+ import sys
5
+ from app_settings import AppSettings
6
+
7
+
8
+ def start_gui(
9
+ argv: List[str],
10
+ app_settings: AppSettings,
11
+ ):
12
+ app = QApplication(sys.argv)
13
+ window = MainWindow(app_settings)
14
+ window.show()
15
+ app.exec()
frontend/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import platform
2
+ from os import path
3
+ from typing import List
4
+
5
+ from backend.device import is_openvino_device
6
+ from constants import DEVICE
7
+ from paths import get_file_name
8
+
9
+
10
+ def is_reshape_required(
11
+ prev_width: int,
12
+ cur_width: int,
13
+ prev_height: int,
14
+ cur_height: int,
15
+ prev_model: str,
17
+ cur_model: str,
17
+ prev_num_of_images: int,
18
+ cur_num_of_images: int,
19
+ ) -> bool:
20
+ reshape_required = False
21
+ if (
22
+ prev_width != cur_width
23
+ or prev_height != cur_height
24
+ or prev_model != cur_model
25
+ or prev_num_of_images != cur_num_of_images
26
+ ):
27
+ print("Reshape and compile")
28
+ reshape_required = True
29
+
30
+ return reshape_required
31
+
32
+
33
+ def enable_openvino_controls() -> bool:
34
+ return is_openvino_device() and platform.system().lower() != "darwin" and platform.processor().lower() != "arm"
35
+
36
+
37
+
38
+ def get_valid_model_id(
39
+ models: List,
40
+ model_id: str,
41
+ default_model: str = "",
42
+ ) -> str:
43
+ if len(models) == 0:
44
+ print("Error: model configuration file is empty, please add some models.")
45
+ return ""
46
+ if model_id == "":
47
+ if default_model:
48
+ return default_model
49
+ else:
50
+ return models[0]
51
+
52
+ if model_id in models:
53
+ return model_id
54
+ else:
55
+ print(
56
+ f"Error: {model_id} model not found in configuration file, so using first model: {models[0]}"
57
+ )
58
+ return models[0]
59
+
60
+
61
+ def get_valid_lora_model(
62
+ models: List,
63
+ cur_model: str,
64
+ lora_models_dir: str,
65
+ ) -> str:
66
+ if cur_model == "" or cur_model is None:
67
+ print(
68
+ f"No lora models found, please add lora models to {lora_models_dir} directory"
69
+ )
70
+ return ""
71
+ else:
72
+ if path.exists(cur_model):
73
+ return get_file_name(cur_model)
74
+ else:
75
+ print(f"Lora model {cur_model} not found")
76
+ if len(models) > 0:
77
+ print(f"Fallback model - {models[0]}")
78
+ return get_file_name(models[0])
79
+ else:
80
+ print(
81
+ f"No lora models found, please add lora models to {lora_models_dir} directory"
82
+ )
83
+ return ""
frontend/webui/controlnet_ui.py ADDED
@@ -0,0 +1,194 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from backend.lora import get_lora_models
4
+ from state import get_settings
5
+ from backend.models.lcmdiffusion_setting import ControlNetSetting
6
+ from backend.annotators.image_control_factory import ImageControlFactory
7
+
8
+ _controlnet_models_map = None
9
+ _controlnet_enabled = False
10
+ _adapter_path = None
11
+
12
+ app_settings = get_settings()
13
+
14
+
15
+ def on_user_input(
16
+ enable: bool,
17
+ adapter_name: str,
18
+ conditioning_scale: float,
19
+ control_image: Image,
20
+ preprocessor: str,
21
+ ):
22
+ if not isinstance(adapter_name, str):
23
+ gr.Warning("Please select a valid ControlNet model")
24
+ return gr.Checkbox(value=False)
25
+
26
+ settings = app_settings.settings.lcm_diffusion_setting
27
+ if settings.controlnet is None:
28
+ settings.controlnet = ControlNetSetting()
29
+
30
+ if enable and (adapter_name is None or adapter_name == ""):
31
+ gr.Warning("Please select a valid ControlNet adapter")
32
+ return gr.Checkbox(value=False)
33
+ elif enable and not control_image:
34
+ gr.Warning("Please provide a ControlNet control image")
35
+ return gr.Checkbox(value=False)
36
+
37
+ if control_image is None:
38
+ return gr.Checkbox(value=enable)
39
+
40
+ if preprocessor == "None":
41
+ processed_control_image = control_image
42
+ else:
43
+ image_control_factory = ImageControlFactory()
44
+ control = image_control_factory.create_control(preprocessor)
45
+ processed_control_image = control.get_control_image(control_image)
46
+
47
+ if not enable:
48
+ settings.controlnet.enabled = False
49
+ else:
50
+ settings.controlnet.enabled = True
51
+ settings.controlnet.adapter_path = _controlnet_models_map[adapter_name]
52
+ settings.controlnet.conditioning_scale = float(conditioning_scale)
53
+ settings.controlnet._control_image = processed_control_image
54
+
55
+ # This code can be improved; currently, if the user clicks the
56
+ # "Enable ControlNet" checkbox or changes the currently selected
57
+ # ControlNet model, it will trigger a pipeline rebuild even if, in
58
+ # the end, the user leaves the same ControlNet settings
59
+ global _controlnet_enabled
60
+ global _adapter_path
61
+ if settings.controlnet.enabled != _controlnet_enabled or (
62
+ settings.controlnet.enabled
63
+ and settings.controlnet.adapter_path != _adapter_path
64
+ ):
65
+ settings.rebuild_pipeline = True
66
+ _controlnet_enabled = settings.controlnet.enabled
67
+ _adapter_path = settings.controlnet.adapter_path
68
+ return gr.Checkbox(value=enable)
69
+
70
+
71
+ def on_change_conditioning_scale(cond_scale):
72
+ print(cond_scale)
73
+ app_settings.settings.lcm_diffusion_setting.controlnet.conditioning_scale = (
74
+ cond_scale
75
+ )
76
+
77
+
78
+ def get_controlnet_ui() -> None:
79
+ with gr.Blocks() as ui:
80
+ gr.HTML(
81
+ 'Download a ControlNet v1.1 model from <a href="https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/tree/main">ControlNet v1.1</a> (723 MB files), place it in the <b>controlnet_models</b> folder, then restart the app'
82
+ )
83
+ with gr.Row():
84
+ with gr.Column():
85
+ with gr.Row():
86
+ global _controlnet_models_map
87
+ _controlnet_models_map = get_lora_models(
88
+ app_settings.settings.lcm_diffusion_setting.dirs["controlnet"]
89
+ )
90
+ controlnet_models = list(_controlnet_models_map.keys())
91
+ default_model = (
92
+ controlnet_models[0] if len(controlnet_models) else None
93
+ )
94
+
95
+ enabled_checkbox = gr.Checkbox(
96
+ label="Enable ControlNet",
97
+ info="Enable ControlNet",
98
+ show_label=True,
99
+ )
100
+ model_dropdown = gr.Dropdown(
101
+ _controlnet_models_map.keys(),
102
+ label="ControlNet model",
103
+ info="ControlNet model to load (.safetensors format)",
104
+ value=default_model,
105
+ interactive=True,
106
+ )
107
+ conditioning_scale_slider = gr.Slider(
108
+ 0.0,
109
+ 1.0,
110
+ value=0.5,
111
+ step=0.05,
112
+ label="ControlNet conditioning scale",
113
+ interactive=True,
114
+ )
115
+ control_image = gr.Image(
116
+ label="Control image",
117
+ type="pil",
118
+ )
119
+ preprocessor_radio = gr.Radio(
120
+ [
121
+ "Canny",
122
+ "Depth",
123
+ "LineArt",
124
+ "MLSD",
125
+ "NormalBAE",
126
+ "Pose",
127
+ "SoftEdge",
128
+ "Shuffle",
129
+ "None",
130
+ ],
131
+ label="Preprocessor",
132
+ info="Select the preprocessor for the control image",
133
+ value="Canny",
134
+ interactive=True,
135
+ )
136
+
137
+ enabled_checkbox.input(
138
+ fn=on_user_input,
139
+ inputs=[
140
+ enabled_checkbox,
141
+ model_dropdown,
142
+ conditioning_scale_slider,
143
+ control_image,
144
+ preprocessor_radio,
145
+ ],
146
+ outputs=[enabled_checkbox],
147
+ )
148
+ model_dropdown.input(
149
+ fn=on_user_input,
150
+ inputs=[
151
+ enabled_checkbox,
152
+ model_dropdown,
153
+ conditioning_scale_slider,
154
+ control_image,
155
+ preprocessor_radio,
156
+ ],
157
+ outputs=[enabled_checkbox],
158
+ )
159
+ conditioning_scale_slider.input(
160
+ fn=on_user_input,
161
+ inputs=[
162
+ enabled_checkbox,
163
+ model_dropdown,
164
+ conditioning_scale_slider,
165
+ control_image,
166
+ preprocessor_radio,
167
+ ],
168
+ outputs=[enabled_checkbox],
169
+ )
170
+ control_image.change(
171
+ fn=on_user_input,
172
+ inputs=[
173
+ enabled_checkbox,
174
+ model_dropdown,
175
+ conditioning_scale_slider,
176
+ control_image,
177
+ preprocessor_radio,
178
+ ],
179
+ outputs=[enabled_checkbox],
180
+ )
181
+ preprocessor_radio.change(
182
+ fn=on_user_input,
183
+ inputs=[
184
+ enabled_checkbox,
185
+ model_dropdown,
186
+ conditioning_scale_slider,
187
+ control_image,
188
+ preprocessor_radio,
189
+ ],
190
+ outputs=[enabled_checkbox],
191
+ )
192
+ conditioning_scale_slider.change(
193
+ on_change_conditioning_scale, conditioning_scale_slider
194
+ )
frontend/webui/css/style.css ADDED
@@ -0,0 +1,22 @@
1
+ footer {
2
+ visibility: hidden
3
+ }
4
+
5
+ #generate_button {
6
+ color: white;
7
+ border-color: #007bff;
8
+ background: #2563eb;
9
+
10
+ }
11
+
12
+ #save_button {
13
+ color: white;
14
+ border-color: #028b40;
15
+ background: #01b97c;
16
+ width: 200px;
17
+ }
18
+
19
+ #settings_header {
20
+ background: rgb(245, 105, 105);
21
+
22
+ }
frontend/webui/generation_settings_ui.py ADDED
@@ -0,0 +1,157 @@
1
+ import gradio as gr
2
+ from state import get_settings
3
+ from backend.models.gen_images import ImageFormat
4
+
5
+ app_settings = get_settings()
6
+
7
+
8
+ def on_change_inference_steps(steps):
9
+ app_settings.settings.lcm_diffusion_setting.inference_steps = steps
10
+
11
+
12
+ def on_change_image_width(img_width):
13
+ app_settings.settings.lcm_diffusion_setting.image_width = img_width
14
+
15
+
16
+ def on_change_image_height(img_height):
17
+ app_settings.settings.lcm_diffusion_setting.image_height = img_height
18
+
19
+
20
+ def on_change_num_images(num_images):
21
+ app_settings.settings.lcm_diffusion_setting.number_of_images = num_images
22
+
23
+
24
+ def on_change_guidance_scale(guidance_scale):
25
+ app_settings.settings.lcm_diffusion_setting.guidance_scale = guidance_scale
26
+
27
+
28
+ def on_change_seed_value(seed):
29
+ app_settings.settings.lcm_diffusion_setting.seed = seed
30
+
31
+
32
+ def on_change_seed_checkbox(seed_checkbox):
33
+ app_settings.settings.lcm_diffusion_setting.use_seed = seed_checkbox
34
+
35
+
36
+ def on_change_safety_checker_checkbox(safety_checker_checkbox):
37
+ app_settings.settings.lcm_diffusion_setting.use_safety_checker = (
38
+ safety_checker_checkbox
39
+ )
40
+
41
+
42
+ def on_change_tiny_auto_encoder_checkbox(tiny_auto_encoder_checkbox):
43
+ app_settings.settings.lcm_diffusion_setting.use_tiny_auto_encoder = (
44
+ tiny_auto_encoder_checkbox
45
+ )
46
+
47
+
48
+ def on_offline_checkbox(offline_checkbox):
49
+ app_settings.settings.lcm_diffusion_setting.use_offline_model = offline_checkbox
50
+
51
+
52
+ def on_change_image_format(image_format):
53
+ if image_format == "PNG":
54
+ app_settings.settings.generated_images.format = ImageFormat.PNG.value.upper()
55
+ else:
56
+ app_settings.settings.generated_images.format = ImageFormat.JPEG.value.upper()
57
+
58
+ app_settings.save()
59
+
60
+
61
+ def get_generation_settings_ui() -> None:
62
+ with gr.Blocks():
63
+ with gr.Row():
64
+ with gr.Column():
65
+ num_inference_steps = gr.Slider(
66
+ 1,
67
+ 25,
68
+ value=app_settings.settings.lcm_diffusion_setting.inference_steps,
69
+ step=1,
70
+ label="Inference Steps",
71
+ interactive=True,
72
+ )
73
+
74
+ image_height = gr.Slider(
75
+ 256,
76
+ 1024,
77
+ value=app_settings.settings.lcm_diffusion_setting.image_height,
78
+ step=256,
79
+ label="Image Height",
80
+ interactive=True,
81
+ )
82
+ image_width = gr.Slider(
83
+ 256,
84
+ 1024,
85
+ value=app_settings.settings.lcm_diffusion_setting.image_width,
86
+ step=256,
87
+ label="Image Width",
88
+ interactive=True,
89
+ )
90
+ num_images = gr.Slider(
91
+ 1,
92
+ 50,
93
+ value=app_settings.settings.lcm_diffusion_setting.number_of_images,
94
+ step=1,
95
+ label="Number of images to generate",
96
+ interactive=True,
97
+ )
98
+ guidance_scale = gr.Slider(
99
+ 1.0,
100
+ 10.0,
101
+ value=app_settings.settings.lcm_diffusion_setting.guidance_scale,
102
+ step=0.1,
103
+ label="Guidance Scale",
104
+ interactive=True,
105
+ )
106
+
107
+ seed = gr.Slider(
108
+ value=app_settings.settings.lcm_diffusion_setting.seed,
109
+ minimum=0,
110
+ maximum=999999999,
111
+ label="Seed",
112
+ step=1,
113
+ interactive=True,
114
+ )
115
+ seed_checkbox = gr.Checkbox(
116
+ label="Use seed",
117
+ value=app_settings.settings.lcm_diffusion_setting.use_seed,
118
+ interactive=True,
119
+ )
120
+
121
+ safety_checker_checkbox = gr.Checkbox(
122
+ label="Use Safety Checker",
123
+ value=app_settings.settings.lcm_diffusion_setting.use_safety_checker,
124
+ interactive=True,
125
+ )
126
+ tiny_auto_encoder_checkbox = gr.Checkbox(
127
+ label="Use tiny auto encoder for SD",
128
+ value=app_settings.settings.lcm_diffusion_setting.use_tiny_auto_encoder,
129
+ interactive=True,
130
+ )
131
+ offline_checkbox = gr.Checkbox(
132
+ label="Use locally cached model or downloaded model folder(offline)",
133
+ value=app_settings.settings.lcm_diffusion_setting.use_offline_model,
134
+ interactive=True,
135
+ )
136
+ img_format = gr.Radio(
137
+ label="Output image format",
138
+ choices=["PNG", "JPEG"],
139
+ value=app_settings.settings.generated_images.format,
140
+ interactive=True,
141
+ )
142
+
143
+ num_inference_steps.change(on_change_inference_steps, num_inference_steps)
144
+ image_height.change(on_change_image_height, image_height)
145
+ image_width.change(on_change_image_width, image_width)
146
+ num_images.change(on_change_num_images, num_images)
147
+ guidance_scale.change(on_change_guidance_scale, guidance_scale)
148
+ seed.change(on_change_seed_value, seed)
149
+ seed_checkbox.change(on_change_seed_checkbox, seed_checkbox)
150
+ safety_checker_checkbox.change(
151
+ on_change_safety_checker_checkbox, safety_checker_checkbox
152
+ )
153
+ tiny_auto_encoder_checkbox.change(
154
+ on_change_tiny_auto_encoder_checkbox, tiny_auto_encoder_checkbox
155
+ )
156
+ offline_checkbox.change(on_offline_checkbox, offline_checkbox)
157
+ img_format.change(on_change_image_format, img_format)
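The pattern in generation_settings_ui.py is one small callback per widget, each writing straight into the shared lcm_diffusion_setting object; only the image-format handler also persists settings with app_settings.save(). A minimal usage sketch, assuming this module is importable as frontend.webui.generation_settings_ui:

from frontend.webui import generation_settings_ui as gs
from state import get_settings

gs.on_change_inference_steps(4)   # same effect as moving the slider to 4
gs.on_change_image_format("PNG")  # also calls app_settings.save()
assert get_settings().settings.lcm_diffusion_setting.inference_steps == 4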
frontend/webui/image_to_image_ui.py ADDED
@@ -0,0 +1,120 @@
+ from typing import Any
+ import gradio as gr
+ from backend.models.lcmdiffusion_setting import DiffusionTask
+ from models.interface_types import InterfaceType
+ from frontend.utils import is_reshape_required
+ from constants import DEVICE
+ from state import get_settings, get_context
+ from concurrent.futures import ThreadPoolExecutor
+
+
+ app_settings = get_settings()
+
+ previous_width = 0
+ previous_height = 0
+ previous_model_id = ""
+ previous_num_of_images = 0
+
+
+ def generate_image_to_image(
+     prompt,
+     negative_prompt,
+     init_image,
+     strength,
+ ) -> Any:
+     context = get_context(InterfaceType.WEBUI)
+     global previous_height, previous_width, previous_model_id, previous_num_of_images, app_settings
+
+     app_settings.settings.lcm_diffusion_setting.prompt = prompt
+     app_settings.settings.lcm_diffusion_setting.negative_prompt = negative_prompt
+     app_settings.settings.lcm_diffusion_setting.init_image = init_image
+     app_settings.settings.lcm_diffusion_setting.strength = strength
+
+     app_settings.settings.lcm_diffusion_setting.diffusion_task = (
+         DiffusionTask.image_to_image.value
+     )
+     model_id = app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id
+     reshape = False
+     image_width = app_settings.settings.lcm_diffusion_setting.image_width
+     image_height = app_settings.settings.lcm_diffusion_setting.image_height
+     num_images = app_settings.settings.lcm_diffusion_setting.number_of_images
+     if app_settings.settings.lcm_diffusion_setting.use_openvino:
+         reshape = is_reshape_required(
+             previous_width,
+             image_width,
+             previous_height,
+             image_height,
+             previous_model_id,
+             model_id,
+             previous_num_of_images,
+             num_images,
+         )
+
+     with ThreadPoolExecutor(max_workers=1) as executor:
+         future = executor.submit(
+             context.generate_text_to_image,
+             app_settings.settings,
+             reshape,
+             DEVICE,
+         )
+         images = future.result()
+
+     previous_width = image_width
+     previous_height = image_height
+     previous_model_id = model_id
+     previous_num_of_images = num_images
+     return images
+
+
+ def get_image_to_image_ui() -> None:
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Init image", type="pil")
+                 with gr.Row():
+                     prompt = gr.Textbox(
+                         show_label=False,
+                         lines=3,
+                         placeholder="A fantasy landscape",
+                         container=False,
+                     )
+
+                     generate_btn = gr.Button(
+                         "Generate",
+                         elem_id="generate_button",
+                         scale=0,
+                     )
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt (Works in LCM-LoRA mode, set guidance > 1.0):",
+                     lines=1,
+                     placeholder="",
+                 )
+                 strength = gr.Slider(
+                     0.1,
+                     1,
+                     value=app_settings.settings.lcm_diffusion_setting.strength,
+                     step=0.01,
+                     label="Strength",
+                 )
+
+                 input_params = [
+                     prompt,
+                     negative_prompt,
+                     input_image,
+                     strength,
+                 ]
+
+             with gr.Column():
+                 output = gr.Gallery(
+                     label="Generated images",
+                     show_label=True,
+                     elem_id="gallery",
+                     columns=2,
+                     height=512,
+                 )
+
+         generate_btn.click(
+             fn=generate_image_to_image,
+             inputs=input_params,
+             outputs=output,
+         )
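generate_image_to_image only copies the UI values into app_settings and then runs context.generate_text_to_image on a single-worker ThreadPoolExecutor, so it can also be driven without the web UI. A headless usage sketch (file paths and prompt are placeholders; the app settings and context must already be initialized):

from PIL import Image
from frontend.webui.image_to_image_ui import generate_image_to_image

init = Image.open("input.png")  # placeholder input image
images = generate_image_to_image(
    prompt="A fantasy landscape",
    negative_prompt="",
    init_image=init,
    strength=0.6,
)
images[0].save("result.png")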
frontend/webui/image_variations_ui.py ADDED
@@ -0,0 +1,106 @@
+ from typing import Any
+ import gradio as gr
+ from backend.models.lcmdiffusion_setting import DiffusionTask
+ from context import Context
+ from models.interface_types import InterfaceType
+ from frontend.utils import is_reshape_required
+ from constants import DEVICE
+ from state import get_settings, get_context
+ from concurrent.futures import ThreadPoolExecutor
+
+ app_settings = get_settings()
+
+
+ previous_width = 0
+ previous_height = 0
+ previous_model_id = ""
+ previous_num_of_images = 0
+
+
+ def generate_image_variations(
+     init_image,
+     variation_strength,
+ ) -> Any:
+     context = get_context(InterfaceType.WEBUI)
+     global previous_height, previous_width, previous_model_id, previous_num_of_images, app_settings
+
+     app_settings.settings.lcm_diffusion_setting.init_image = init_image
+     app_settings.settings.lcm_diffusion_setting.strength = variation_strength
+     app_settings.settings.lcm_diffusion_setting.prompt = ""
+     app_settings.settings.lcm_diffusion_setting.negative_prompt = ""
+
+     app_settings.settings.lcm_diffusion_setting.diffusion_task = (
+         DiffusionTask.image_to_image.value
+     )
+     model_id = app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id
+     reshape = False
+     image_width = app_settings.settings.lcm_diffusion_setting.image_width
+     image_height = app_settings.settings.lcm_diffusion_setting.image_height
+     num_images = app_settings.settings.lcm_diffusion_setting.number_of_images
+     if app_settings.settings.lcm_diffusion_setting.use_openvino:
+         reshape = is_reshape_required(
+             previous_width,
+             image_width,
+             previous_height,
+             image_height,
+             previous_model_id,
+             model_id,
+             previous_num_of_images,
+             num_images,
+         )
+
+     with ThreadPoolExecutor(max_workers=1) as executor:
+         future = executor.submit(
+             context.generate_text_to_image,
+             app_settings.settings,
+             reshape,
+             DEVICE,
+         )
+         images = future.result()
+
+     previous_width = image_width
+     previous_height = image_height
+     previous_model_id = model_id
+     previous_num_of_images = num_images
+     return images
+
+
+ def get_image_variations_ui() -> None:
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Init image", type="pil")
+                 with gr.Row():
+                     generate_btn = gr.Button(
+                         "Generate",
+                         elem_id="generate_button",
+                         scale=0,
+                     )
+
+                 variation_strength = gr.Slider(
+                     0.1,
+                     1,
+                     value=0.4,
+                     step=0.01,
+                     label="Variations Strength",
+                 )
+
+                 input_params = [
+                     input_image,
+                     variation_strength,
+                 ]
+
+             with gr.Column():
+                 output = gr.Gallery(
+                     label="Generated images",
+                     show_label=True,
+                     elem_id="gallery",
+                     columns=2,
+                     height=512,
+                 )
+
+         generate_btn.click(
+             fn=generate_image_variations,
+             inputs=input_params,
+             outputs=output,
+         )
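Image variations reuse the same image-to-image path with an empty prompt; the only user-facing knob is the variation strength. A minimal sketch (placeholder paths, app context assumed initialized):

from PIL import Image
from frontend.webui.image_variations_ui import generate_image_variations

source = Image.open("portrait.png")  # placeholder input image
variations = generate_image_variations(source, 0.4)
for idx, image in enumerate(variations):
    image.save(f"variation_{idx}.png")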
frontend/webui/lora_models_ui.py ADDED
@@ -0,0 +1,185 @@
+ import gradio as gr
+ from os import path
+ from backend.lora import (
+     get_lora_models,
+     get_active_lora_weights,
+     update_lora_weights,
+     load_lora_weight,
+ )
+ from state import get_settings, get_context
+ from frontend.utils import get_valid_lora_model
+ from models.interface_types import InterfaceType
+ from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
+
+
+ _MAX_LORA_WEIGHTS = 5
+
+ _custom_lora_sliders = []
+ _custom_lora_names = []
+ _custom_lora_columns = []
+
+ app_settings = get_settings()
+
+
+ def on_click_update_weight(*lora_weights):
+     update_weights = []
+     active_weights = get_active_lora_weights()
+     if not len(active_weights):
+         gr.Warning("No active LoRAs, first you need to load LoRA model")
+         return
+     for idx, lora in enumerate(active_weights):
+         update_weights.append(
+             (
+                 lora[0],
+                 lora_weights[idx],
+             )
+         )
+     if len(update_weights) > 0:
+         update_lora_weights(
+             get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
+             app_settings.settings.lcm_diffusion_setting,
+             update_weights,
+         )
+
+
+ def on_click_load_lora(lora_name, lora_weight):
+     if app_settings.settings.lcm_diffusion_setting.use_openvino:
+         gr.Warning("Currently LoRA is not supported in OpenVINO.")
+         return
+     lora_models_map = get_lora_models(
+         app_settings.settings.lcm_diffusion_setting.lora.models_dir
+     )
+
+     # Load a new LoRA
+     settings = app_settings.settings.lcm_diffusion_setting
+     settings.lora.fuse = False
+     settings.lora.enabled = False
+     settings.lora.path = lora_models_map[lora_name]
+     settings.lora.weight = lora_weight
+     if not path.exists(settings.lora.path):
+         gr.Warning("Invalid LoRA model path!")
+         return
+     pipeline = get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline
+     if not pipeline:
+         gr.Warning("Pipeline not initialized. Please generate an image first.")
+         return
+     settings.lora.enabled = True
+     load_lora_weight(
+         get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
+         settings,
+     )
+
+     # Update Gradio LoRA UI
+     global _MAX_LORA_WEIGHTS
+     values = []
+     labels = []
+     rows = []
+     active_weights = get_active_lora_weights()
+     for idx, lora in enumerate(active_weights):
+         labels.append(f"{lora[0]}: ")
+         values.append(lora[1])
+         rows.append(gr.Row.update(visible=True))
+     for i in range(len(active_weights), _MAX_LORA_WEIGHTS):
+         labels.append(f"Update weight")
+         values.append(0.0)
+         rows.append(gr.Row.update(visible=False))
+     return labels + values + rows
+
+
+ def get_lora_models_ui() -> None:
+     with gr.Blocks() as ui:
+         gr.HTML(
+             "Download and place your LoRA model weights in <b>lora_models</b> folders and restart App"
+         )
+         with gr.Row():
+
+             with gr.Column():
+                 with gr.Row():
+                     lora_models_map = get_lora_models(
+                         app_settings.settings.lcm_diffusion_setting.lora.models_dir
+                     )
+                     valid_model = get_valid_lora_model(
+                         list(lora_models_map.values()),
+                         app_settings.settings.lcm_diffusion_setting.lora.path,
+                         app_settings.settings.lcm_diffusion_setting.lora.models_dir,
+                     )
+                     if valid_model != "":
+                         valid_model_path = lora_models_map[valid_model]
+                         app_settings.settings.lcm_diffusion_setting.lora.path = (
+                             valid_model_path
+                         )
+                     else:
+                         app_settings.settings.lcm_diffusion_setting.lora.path = ""
+
+                     lora_model = gr.Dropdown(
+                         lora_models_map.keys(),
+                         label="LoRA model",
+                         info="LoRA model weight to load (You can use Lora models from Civitai or Hugging Face .safetensors format)",
+                         value=valid_model,
+                         interactive=True,
+                     )
+
+                     lora_weight = gr.Slider(
+                         0.0,
+                         1.0,
+                         value=app_settings.settings.lcm_diffusion_setting.lora.weight,
+                         step=0.05,
+                         label="Initial Lora weight",
+                         interactive=True,
+                     )
+                     load_lora_btn = gr.Button(
+                         "Load selected LoRA",
+                         elem_id="load_lora_button",
+                         scale=0,
+                     )
+
+                 with gr.Row():
+                     gr.Markdown(
+                         "## Loaded LoRA models",
+                         show_label=False,
+                     )
+                     update_lora_weights_btn = gr.Button(
+                         "Update LoRA weights",
+                         elem_id="load_lora_button",
+                         scale=0,
+                     )
+
+                 global _MAX_LORA_WEIGHTS
+                 global _custom_lora_sliders
+                 global _custom_lora_names
+                 global _custom_lora_columns
+                 for i in range(0, _MAX_LORA_WEIGHTS):
+                     new_row = gr.Column(visible=False)
+                     _custom_lora_columns.append(new_row)
+                     with new_row:
+                         lora_name = gr.Markdown(
+                             "Lora Name",
+                             show_label=True,
+                         )
+                         lora_slider = gr.Slider(
+                             0.0,
+                             1.0,
+                             step=0.05,
+                             label="LoRA weight",
+                             interactive=True,
+                             visible=True,
+                         )
+
+                         _custom_lora_names.append(lora_name)
+                         _custom_lora_sliders.append(lora_slider)
+
+         load_lora_btn.click(
+             fn=on_click_load_lora,
+             inputs=[lora_model, lora_weight],
+             outputs=[
+                 *_custom_lora_names,
+                 *_custom_lora_sliders,
+                 *_custom_lora_columns,
+             ],
+         )
+
+         update_lora_weights_btn.click(
+             fn=on_click_update_weight,
+             inputs=[*_custom_lora_sliders],
+             outputs=None,
+         )
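on_click_update_weight pairs each active LoRA name with the value of its slider before handing the list to update_lora_weights. A standalone sketch of that pairing step (the names and weights below are illustrative, not real models):

# Illustrative data: what get_active_lora_weights() might return and the
# current slider positions collected by Gradio.
active_weights = [("pixel-art-lora", 0.7), ("detail-tweaker", 0.3)]
slider_values = (0.9, 0.2)

update_weights = [
    (name, slider_values[idx]) for idx, (name, _old) in enumerate(active_weights)
]
# update_weights == [("pixel-art-lora", 0.9), ("detail-tweaker", 0.2)]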
frontend/webui/models_ui.py ADDED
@@ -0,0 +1,85 @@
+ from app_settings import AppSettings
+ from typing import Any
+ import gradio as gr
+ from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
+ from state import get_settings
+ from frontend.utils import get_valid_model_id
+
+ app_settings = get_settings()
+ app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id = get_valid_model_id(
+     app_settings.openvino_lcm_models,
+     app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id,
+ )
+
+
+ def change_lcm_model_id(model_id):
+     app_settings.settings.lcm_diffusion_setting.lcm_model_id = model_id
+
+
+ def change_lcm_lora_model_id(model_id):
+     app_settings.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id
+
+
+ def change_lcm_lora_base_model_id(model_id):
+     app_settings.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id
+
+
+ def change_openvino_lcm_model_id(model_id):
+     app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
+
+
+ def get_models_ui() -> None:
+     with gr.Blocks():
+         with gr.Row():
+             lcm_model_id = gr.Dropdown(
+                 app_settings.lcm_models,
+                 label="LCM model",
+                 info="Diffusers LCM model ID",
+                 value=get_valid_model_id(
+                     app_settings.lcm_models,
+                     app_settings.settings.lcm_diffusion_setting.lcm_model_id,
+                     LCM_DEFAULT_MODEL,
+                 ),
+                 interactive=True,
+             )
+         with gr.Row():
+             lcm_lora_model_id = gr.Dropdown(
+                 app_settings.lcm_lora_models,
+                 label="LCM LoRA model",
+                 info="Diffusers LCM LoRA model ID",
+                 value=get_valid_model_id(
+                     app_settings.lcm_lora_models,
+                     app_settings.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
+                 ),
+                 interactive=True,
+             )
+             lcm_lora_base_model_id = gr.Dropdown(
+                 app_settings.stable_diffsuion_models,
+                 label="LCM LoRA base model",
+                 info="Diffusers LCM LoRA base model ID",
+                 value=get_valid_model_id(
+                     app_settings.stable_diffsuion_models,
+                     app_settings.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
+                 ),
+                 interactive=True,
+             )
+         with gr.Row():
+             lcm_openvino_model_id = gr.Dropdown(
+                 app_settings.openvino_lcm_models,
+                 label="LCM OpenVINO model",
+                 info="OpenVINO LCM-LoRA fused model ID",
+                 value=get_valid_model_id(
+                     app_settings.openvino_lcm_models,
+                     app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id,
+                 ),
+                 interactive=True,
+             )
+
+         lcm_model_id.change(change_lcm_model_id, lcm_model_id)
+         lcm_lora_model_id.change(change_lcm_lora_model_id, lcm_lora_model_id)
+         lcm_lora_base_model_id.change(
+             change_lcm_lora_base_model_id, lcm_lora_base_model_id
+         )
+         lcm_openvino_model_id.change(
+             change_openvino_lcm_model_id, lcm_openvino_model_id
+         )
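Each dropdown above is seeded through get_valid_model_id, which (as used here) returns the saved model ID when it is still present in the configured list and otherwise falls back to the optional default. A small usage sketch with example model IDs (the fallback behaviour is inferred from this usage, not shown in the diff):

from frontend.utils import get_valid_model_id
from constants import LCM_DEFAULT_MODEL

lcm_models = ["SimianLuo/LCM_Dreamshaper_v7", "latent-consistency/lcm-sdxl"]  # example list
saved_id = "a-model-that-was-removed"

model_id = get_valid_model_id(lcm_models, saved_id, LCM_DEFAULT_MODEL)
# Expected: LCM_DEFAULT_MODEL, since saved_id is no longer in lcm_models.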