prithivMLmods commited on
Commit
ab6b5e5
1 Parent(s): 6204f5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +291 -24
app.py CHANGED
@@ -8,10 +8,10 @@ from PIL import Image
8
  import spaces
9
  import torch
10
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
11
- from diffusers import AuraFlowPipeline
12
-
13
 
14
  DESCRIPTIONx = """
 
 
15
  """
16
 
17
  css = '''
@@ -22,10 +22,17 @@ footer {
22
  }
23
  '''
24
 
 
 
 
 
 
 
 
 
25
  MODEL_OPTIONS = {
26
  "Lightning": "SG161222/RealVisXL_V4.0_Lightning",
27
  "Realvision": "SG161222/RealVisXL_V4.0",
28
- "AuraFlow": "fal/AuraFlow",
29
  }
30
 
31
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -36,29 +43,23 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
36
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
 
38
  def load_and_prepare_model(model_id):
39
- if model_id == "fal/AuraFlow":
40
- pipe = AuraFlowPipeline.from_pretrained(
41
- model_id,
42
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
43
- ).to(device)
44
- else:
45
- pipe = StableDiffusionXLPipeline.from_pretrained(
46
- model_id,
47
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
48
- use_safetensors=True,
49
- add_watermarker=False,
50
- ).to(device)
51
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
52
-
53
  if USE_TORCH_COMPILE:
54
  pipe.compile()
55
-
56
  if ENABLE_CPU_OFFLOAD:
57
  pipe.enable_model_cpu_offload()
58
-
59
  return pipe
60
 
61
- # Preload and compile all models
62
  models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
63
 
64
  MAX_SEED = np.iinfo(np.int32).max
@@ -91,7 +92,7 @@ def generate(
91
  ):
92
  global models
93
  pipe = models[model_choice]
94
-
95
  seed = int(randomize_seed_fn(seed, randomize_seed))
96
  generator = torch.Generator(device=device).manual_seed(seed)
97
 
@@ -138,7 +139,7 @@ def load_predefined_images():
138
  return predefined_images
139
 
140
  with gr.Blocks(css=css) as demo:
141
- gr.Markdown(DESCRIPTIONx)
142
  with gr.Row():
143
  prompt = gr.Text(
144
  label="Prompt",
@@ -149,7 +150,7 @@ with gr.Blocks(css=css) as demo:
149
  container=False,
150
  )
151
  run_button = gr.Button("Run⚡", scale=0)
152
- result = gr.Gallery(label="Result", columns=1, show_label=False)
153
 
154
  with gr.Row():
155
  model_choice = gr.Dropdown(
@@ -216,13 +217,19 @@ with gr.Blocks(css=css) as demo:
216
  value=20,
217
  )
218
 
 
 
 
 
 
 
219
  use_negative_prompt.change(
220
  fn=lambda x: gr.update(visible=x),
221
  inputs=use_negative_prompt,
222
  outputs=negative_prompt,
223
  api_name=False,
224
  )
225
-
226
  gr.on(
227
  triggers=[
228
  prompt.submit,
@@ -246,6 +253,266 @@ with gr.Blocks(css=css) as demo:
246
  outputs=[result, seed],
247
  api_name="run",
248
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  if __name__ == "__main__":
251
  demo.queue(max_size=40).launch(show_api=False)
 
8
  import spaces
9
  import torch
10
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 
11
 
12
  DESCRIPTIONx = """
13
+
14
+
15
  """
16
 
17
  css = '''
 
22
  }
23
  '''
24
 
25
+ #examples = [
26
+ # "3d image, cute girl, in the style of Pixar --ar 1:2 --stylize 750, 4K resolution highlights, Sharp focus, octane render, ray tracing, Ultra-High-Definition, 8k, UHD, HDR, (Masterpiece:1.5), (best quality:1.5)",
27
+ # "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic oil --ar 2:3 --q 2 --s 750 --v 5 --ar 2:3 --q 2 --s 750 --v 5",
28
+ # "Illustration of A starry night camp in the mountains. Low-angle view, Minimal background, Geometric shapes theme, Pottery, Split-complementary colors, Bicolored light, UHD",
29
+ # "Man in brown leather jacket posing for camera, in the style of sleek and stylized, clockpunk, subtle shades, exacting precision, ferrania p30 --ar 67:101 --v 5",
30
+ # "Commercial photography, giant burger, white lighting, studio light, 8k octane rendering, high resolution photography, insanely detailed, fine details, on white isolated plain, 8k, commercial photography, stock photo, professional color grading, --v 4 --ar 9:16 "
31
+ #]
32
+
33
  MODEL_OPTIONS = {
34
  "Lightning": "SG161222/RealVisXL_V4.0_Lightning",
35
  "Realvision": "SG161222/RealVisXL_V4.0",
 
36
  }
37
 
38
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 
43
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
 
45
  def load_and_prepare_model(model_id):
46
+ pipe = StableDiffusionXLPipeline.from_pretrained(
47
+ model_id,
48
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
49
+ use_safetensors=True,
50
+ add_watermarker=False,
51
+ ).to(device)
52
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
53
+
 
 
 
 
 
 
54
  if USE_TORCH_COMPILE:
55
  pipe.compile()
56
+
57
  if ENABLE_CPU_OFFLOAD:
58
  pipe.enable_model_cpu_offload()
59
+
60
  return pipe
61
 
62
+ # Preload and compile both models
63
  models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
64
 
65
  MAX_SEED = np.iinfo(np.int32).max
 
92
  ):
93
  global models
94
  pipe = models[model_choice]
95
+
96
  seed = int(randomize_seed_fn(seed, randomize_seed))
97
  generator = torch.Generator(device=device).manual_seed(seed)
98
 
 
139
  return predefined_images
140
 
141
  with gr.Blocks(css=css) as demo:
142
+ gr.Markdown(DESCRIPTIONx)
143
  with gr.Row():
144
  prompt = gr.Text(
145
  label="Prompt",
 
150
  container=False,
151
  )
152
  run_button = gr.Button("Run⚡", scale=0)
153
+ result = gr.Gallery(label="Result", columns=1, show_label=False)
154
 
155
  with gr.Row():
156
  model_choice = gr.Dropdown(
 
217
  value=20,
218
  )
219
 
220
+ # gr.Examples(
221
+ # examples=examples,
222
+ # inputs=prompt,
223
+ # cache_examples=False
224
+ #)
225
+
226
  use_negative_prompt.change(
227
  fn=lambda x: gr.update(visible=x),
228
  inputs=use_negative_prompt,
229
  outputs=negative_prompt,
230
  api_name=False,
231
  )
232
+
233
  gr.on(
234
  triggers=[
235
  prompt.submit,
 
253
  outputs=[result, seed],
254
  api_name="run",
255
  )
256
+ #!/usr/bin/env python
257
+ import os
258
+ import random
259
+ import uuid
260
+ import gradio as gr
261
+ import numpy as np
262
+ from PIL import Image
263
+ import spaces
264
+ import torch
265
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
266
+
267
+ DESCRIPTIONx = """
268
+
269
+
270
+ """
271
+
272
+ css = '''
273
+ .gradio-container{max-width: 570px !important}
274
+ h1{text-align:center}
275
+ footer {
276
+ visibility: hidden
277
+ }
278
+ '''
279
+
280
+ examples = [
281
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
282
+ "Chocolate dripping from a donut against a yellow background, 8k",
283
+ "Illustration of A starry night camp in the mountains, 4k",
284
+ "A photo of a lavender cat, hdr, 4k",
285
+ "A delicious ceviche cheesecake slice, 4k"
286
+ ]
287
+
288
+ MODEL_OPTIONS = {
289
+ "Lightning": "SG161222/RealVisXL_V4.0_Lightning",
290
+ "Realvision": "SG161222/RealVisXL_V4.0",
291
+ }
292
+
293
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
294
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
295
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
296
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
297
+
298
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
299
+
300
+ def load_and_prepare_model(model_id):
301
+ pipe = StableDiffusionXLPipeline.from_pretrained(
302
+ model_id,
303
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
304
+ use_safetensors=True,
305
+ add_watermarker=False,
306
+ ).to(device)
307
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
308
+
309
+ if USE_TORCH_COMPILE:
310
+ pipe.compile()
311
+
312
+ if ENABLE_CPU_OFFLOAD:
313
+ pipe.enable_model_cpu_offload()
314
+
315
+ return pipe
316
+
317
+ # Preload and compile both models
318
+ models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
319
+
320
+ MAX_SEED = np.iinfo(np.int32).max
321
 
322
+ def save_image(img):
323
+ unique_name = str(uuid.uuid4()) + ".png"
324
+ img.save(unique_name)
325
+ return unique_name
326
+
327
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
328
+ if randomize_seed:
329
+ seed = random.randint(0, MAX_SEED)
330
+ return seed
331
+
332
+ @spaces.GPU(duration=60, enable_queue=True)
333
+ def generate(
334
+ model_choice: str,
335
+ prompt: str,
336
+ negative_prompt: str = "",
337
+ use_negative_prompt: bool = False,
338
+ seed: int = 1,
339
+ width: int = 1024,
340
+ height: int = 1024,
341
+ guidance_scale: float = 3,
342
+ num_inference_steps: int = 25,
343
+ randomize_seed: bool = False,
344
+ use_resolution_binning: bool = True,
345
+ num_images: int = 1,
346
+ progress=gr.Progress(track_tqdm=True),
347
+ ):
348
+ global models
349
+ pipe = models[model_choice]
350
+
351
+ seed = int(randomize_seed_fn(seed, randomize_seed))
352
+ generator = torch.Generator(device=device).manual_seed(seed)
353
+
354
+ options = {
355
+ "prompt": [prompt] * num_images,
356
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
357
+ "width": width,
358
+ "height": height,
359
+ "guidance_scale": guidance_scale,
360
+ "num_inference_steps": num_inference_steps,
361
+ "generator": generator,
362
+ "output_type": "pil",
363
+ }
364
+
365
+ if use_resolution_binning:
366
+ options["use_resolution_binning"] = True
367
+
368
+ images = []
369
+ for i in range(0, num_images, BATCH_SIZE):
370
+ batch_options = options.copy()
371
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
372
+ if "negative_prompt" in batch_options:
373
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
374
+ images.extend(pipe(**batch_options).images)
375
+
376
+ image_paths = [save_image(img) for img in images]
377
+ return image_paths, seed
378
+
379
+ #def load_predefined_images():
380
+ # predefined_images = [
381
+ # "assets/1.png",
382
+ # "assets/2.png",
383
+ # "assets/3.png",
384
+ # "assets/4.png",
385
+ # "assets/5.png",
386
+ # "assets/6.png",
387
+ # "assets/7.png",
388
+ # "assets/8.png",
389
+ # "assets/9.png",
390
+ # "assets/10.png",
391
+ # "assets/11.png",
392
+ # "assets/12.png",
393
+ # ]
394
+ # return predefined_images
395
+
396
+ with gr.Blocks(css=css) as demo:
397
+ gr.Markdown(DESCRIPTIONx)
398
+ with gr.Row():
399
+ prompt = gr.Text(
400
+ label="Prompt",
401
+ show_label=False,
402
+ max_lines=1,
403
+ placeholder="Enter your prompt",
404
+ value="Chocolate dripping from a donut against a yellow background, 8k",
405
+ container=False,
406
+ )
407
+ run_button = gr.Button("Run⚡", scale=0)
408
+ result = gr.Gallery(label="Result", columns=1, show_label=False)
409
+
410
+ with gr.Row():
411
+ model_choice = gr.Dropdown(
412
+ label="Model Selection",
413
+ choices=list(MODEL_OPTIONS.keys()),
414
+ value="Lightning"
415
+ )
416
+
417
+ with gr.Accordion("Advanced options", open=True, visible=False):
418
+ num_images = gr.Slider(
419
+ label="Number of Images",
420
+ minimum=1,
421
+ maximum=1,
422
+ step=1,
423
+ value=1,
424
+ )
425
+ with gr.Row():
426
+ with gr.Column(scale=1):
427
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
428
+ negative_prompt = gr.Text(
429
+ label="Negative prompt",
430
+ max_lines=5,
431
+ lines=4,
432
+ placeholder="Enter a negative prompt",
433
+ value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
434
+ visible=True,
435
+ )
436
+ seed = gr.Slider(
437
+ label="Seed",
438
+ minimum=0,
439
+ maximum=MAX_SEED,
440
+ step=1,
441
+ value=0,
442
+ )
443
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
444
+ with gr.Row():
445
+ width = gr.Slider(
446
+ label="Width",
447
+ minimum=512,
448
+ maximum=MAX_IMAGE_SIZE,
449
+ step=64,
450
+ value=1024,
451
+ )
452
+ height = gr.Slider(
453
+ label="Height",
454
+ minimum=512,
455
+ maximum=MAX_IMAGE_SIZE,
456
+ step=64,
457
+ value=1024,
458
+ )
459
+ with gr.Row():
460
+ guidance_scale = gr.Slider(
461
+ label="Guidance Scale",
462
+ minimum=0.1,
463
+ maximum=6,
464
+ step=0.1,
465
+ value=3.0,
466
+ )
467
+ num_inference_steps = gr.Slider(
468
+ label="Number of inference steps",
469
+ minimum=1,
470
+ maximum=35,
471
+ step=1,
472
+ value=20,
473
+ )
474
+
475
+ gr.Examples(
476
+ examples=examples,
477
+ inputs=prompt,
478
+ cache_examples=False
479
+ )
480
+
481
+ use_negative_prompt.change(
482
+ fn=lambda x: gr.update(visible=x),
483
+ inputs=use_negative_prompt,
484
+ outputs=negative_prompt,
485
+ api_name=False,
486
+ )
487
+
488
+ gr.on(
489
+ triggers=[
490
+ prompt.submit,
491
+ negative_prompt.submit,
492
+ run_button.click,
493
+ ],
494
+ fn=generate,
495
+ inputs=[
496
+ model_choice,
497
+ prompt,
498
+ negative_prompt,
499
+ use_negative_prompt,
500
+ seed,
501
+ width,
502
+ height,
503
+ guidance_scale,
504
+ num_inference_steps,
505
+ randomize_seed,
506
+ num_images
507
+ ],
508
+ outputs=[result, seed],
509
+ api_name="run",
510
+ )
511
+ # with gr.Column(scale=3):
512
+ # gr.Markdown("### Image Gallery")
513
+ # predefined_gallery = gr.Gallery(label="Image Gallery", columns=4, show_label=False, value=load_predefined_images())
514
+ if __name__ == "__main__":
515
+ demo.queue(max_size=40).launch(show_api=False)
516
+
517
  if __name__ == "__main__":
518
  demo.queue(max_size=40).launch(show_api=False)