lnyan committed on
Commit
5d56ba2
1 Parent(s): 65ade76

Update vae, remove autocast

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -16,6 +16,7 @@ from diffusers import (
16
  DDIMScheduler,
17
  LMSDiscreteScheduler,
18
  )
 
19
  from PIL import Image
20
  from PIL import ImageOps
21
  import gradio as gr
@@ -73,10 +74,9 @@ finally:
73
  else:
74
  device = "cpu"
75
 
76
- if device != "cuda":
77
- import contextlib
78
 
79
- autocast = contextlib.nullcontext
80
 
81
  with open("config.yaml", "r") as yaml_in:
82
  yaml_object = yaml.safe_load(yaml_in)
@@ -258,15 +258,17 @@ class StableDiffusionInpaint:
258
  model_name = os.path.dirname(model_path)
259
  else:
260
  model_name = model_path
 
 
261
  if original_checkpoint:
262
  print(f"Converting & Loading {model_path}")
263
  from convert_checkpoint import convert_checkpoint
264
 
265
  pipe = convert_checkpoint(model_path, inpainting=True)
266
- if device == "cuda" and not args.fp32:
267
  pipe.to(torch.float16)
268
  inpaint = StableDiffusionInpaintPipeline(
269
- vae=pipe.vae,
270
  text_encoder=pipe.text_encoder,
271
  tokenizer=pipe.tokenizer,
272
  unet=pipe.unet,
@@ -276,12 +278,13 @@ class StableDiffusionInpaint:
276
  )
277
  else:
278
  print(f"Loading {model_name}")
279
- if device == "cuda" and not args.fp32:
280
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
281
  model_name,
282
  revision="fp16",
283
  torch_dtype=torch.float16,
284
  use_auth_token=token,
 
285
  )
286
  else:
287
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
@@ -385,7 +388,7 @@ class StableDiffusionInpaint:
385
  init_image = Image.fromarray(img)
386
  mask_image = Image.fromarray(mask)
387
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
388
- with autocast("cuda"):
389
  images = inpaint_func(
390
  prompt=prompt,
391
  image=init_image.resize(
@@ -410,6 +413,8 @@ class StableDiffusion:
410
  ):
411
  self.token = token
412
  original_checkpoint = False
 
 
413
  if model_path and os.path.exists(model_path):
414
  if model_path.endswith(".ckpt"):
415
  original_checkpoint = True
@@ -432,6 +437,7 @@ class StableDiffusion:
432
  revision="fp16",
433
  torch_dtype=torch.float16,
434
  use_auth_token=token,
 
435
  )
436
  else:
437
  text2img = StableDiffusionPipeline.from_pretrained(
@@ -449,12 +455,13 @@ class StableDiffusion:
449
  import gc
450
 
451
  gc.collect()
452
- if device == "cuda" and not args.fp32:
453
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
454
  "runwayml/stable-diffusion-inpainting",
455
  revision="fp16",
456
  torch_dtype=torch.float16,
457
  use_auth_token=token,
 
458
  ).to(device)
459
  else:
460
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
@@ -604,7 +611,7 @@ class StableDiffusion:
604
  extra_kwargs["generator"] = generator
605
  if nmask.sum() < 1 and enable_img2img:
606
  init_image = Image.fromarray(img)
607
- with autocast("cuda"):
608
  images = img2img(
609
  prompt=prompt,
610
  init_image=init_image.resize(
@@ -631,7 +638,7 @@ class StableDiffusion:
631
  init_image = Image.fromarray(img)
632
  mask_image = Image.fromarray(mask)
633
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
634
- with autocast("cuda"):
635
  input_image = init_image.resize(
636
  (process_width, process_height), resample=SAMPLING_MODE
637
  )
@@ -645,7 +652,7 @@ class StableDiffusion:
645
  **extra_kwargs,
646
  )["images"]
647
  else:
648
- with autocast("cuda"):
649
  images = text2img(
650
  prompt=prompt,
651
  height=process_width,
 
16
  DDIMScheduler,
17
  LMSDiscreteScheduler,
18
  )
19
+ from diffusers.models import AutoencoderKL
20
  from PIL import Image
21
  from PIL import ImageOps
22
  import gradio as gr
 
74
  else:
75
  device = "cpu"
76
 
77
+ import contextlib
 
78
 
79
+ autocast = contextlib.nullcontext
80
 
81
  with open("config.yaml", "r") as yaml_in:
82
  yaml_object = yaml.safe_load(yaml_in)
 
258
  model_name = os.path.dirname(model_path)
259
  else:
260
  model_name = model_path
261
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
262
+ vae.to(torch.float16)
263
  if original_checkpoint:
264
  print(f"Converting & Loading {model_path}")
265
  from convert_checkpoint import convert_checkpoint
266
 
267
  pipe = convert_checkpoint(model_path, inpainting=True)
268
+ if device == "cuda":
269
  pipe.to(torch.float16)
270
  inpaint = StableDiffusionInpaintPipeline(
271
+ vae=vae,
272
  text_encoder=pipe.text_encoder,
273
  tokenizer=pipe.tokenizer,
274
  unet=pipe.unet,
 
278
  )
279
  else:
280
  print(f"Loading {model_name}")
281
+ if device == "cuda":
282
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
283
  model_name,
284
  revision="fp16",
285
  torch_dtype=torch.float16,
286
  use_auth_token=token,
287
+ vae=vae
288
  )
289
  else:
290
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
 
388
  init_image = Image.fromarray(img)
389
  mask_image = Image.fromarray(mask)
390
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
391
+ if True:
392
  images = inpaint_func(
393
  prompt=prompt,
394
  image=init_image.resize(
 
413
  ):
414
  self.token = token
415
  original_checkpoint = False
416
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
417
+ vae.to(torch.float16)
418
  if model_path and os.path.exists(model_path):
419
  if model_path.endswith(".ckpt"):
420
  original_checkpoint = True
 
437
  revision="fp16",
438
  torch_dtype=torch.float16,
439
  use_auth_token=token,
440
+ vae=vae
441
  )
442
  else:
443
  text2img = StableDiffusionPipeline.from_pretrained(
 
455
  import gc
456
 
457
  gc.collect()
458
+ if device == "cuda":
459
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
460
  "runwayml/stable-diffusion-inpainting",
461
  revision="fp16",
462
  torch_dtype=torch.float16,
463
  use_auth_token=token,
464
+ vae=vae
465
  ).to(device)
466
  else:
467
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
 
611
  extra_kwargs["generator"] = generator
612
  if nmask.sum() < 1 and enable_img2img:
613
  init_image = Image.fromarray(img)
614
+ if True:
615
  images = img2img(
616
  prompt=prompt,
617
  init_image=init_image.resize(
 
638
  init_image = Image.fromarray(img)
639
  mask_image = Image.fromarray(mask)
640
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
641
+ if True:
642
  input_image = init_image.resize(
643
  (process_width, process_height), resample=SAMPLING_MODE
644
  )
 
652
  **extra_kwargs,
653
  )["images"]
654
  else:
655
+ if True:
656
  images = text2img(
657
  prompt=prompt,
658
  height=process_width,