stanley committed on
Commit 6b83218 · 1 Parent(s): 028f63a

debuggin huggin

Files changed (1)
  1. app.py +319 -319
app.py CHANGED
@@ -477,325 +477,325 @@ class StableDiffusionInpaint:
         return images
 
 
-# class StableDiffusion:
-#     def __init__(
-#         self,
-#         token: str = "",
-#         model_name: str = "runwayml/stable-diffusion-v1-5",
-#         model_path: str = None,
-#         inpainting_model: bool = False,
-#         **kwargs,
-#     ):
-#         self.token = token
-#         original_checkpoint = False
-#         if device=="cpu" and onnx_available:
-#             from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
-#             text2img = OnnxStableDiffusionPipeline.from_pretrained(
-#                 model_name,
-#                 revision="onnx",
-#                 provider=onnx_providers[0] if onnx_providers else None
-#             )
-#             inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
-#                 vae_encoder=text2img.vae_encoder,
-#                 vae_decoder=text2img.vae_decoder,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             )
-#             img2img = OnnxStableDiffusionImg2ImgPipeline(
-#                 vae_encoder=text2img.vae_encoder,
-#                 vae_decoder=text2img.vae_decoder,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             )
-#         else:
-#             if model_path and os.path.exists(model_path):
-#                 if model_path.endswith(".ckpt"):
-#                     original_checkpoint = True
-#                 elif model_path.endswith(".json"):
-#                     model_name = os.path.dirname(model_path)
-#                 else:
-#                     model_name = model_path
-#             vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
-#             if device == "cuda" and not args.fp32:
-#                 vae.to(torch.float16)
-#             if original_checkpoint:
-#                 print(f"Converting & Loading {model_path}")
-#                 from convert_checkpoint import convert_checkpoint
-
-#                 pipe = convert_checkpoint(model_path)
-#                 if device == "cuda" and not args.fp32:
-#                     pipe.to(torch.float16)
-#                 text2img = StableDiffusionPipeline(
-#                     vae=vae,
-#                     text_encoder=pipe.text_encoder,
-#                     tokenizer=pipe.tokenizer,
-#                     unet=pipe.unet,
-#                     scheduler=pipe.scheduler,
-#                     safety_checker=pipe.safety_checker,
-#                     feature_extractor=pipe.feature_extractor,
-#                 )
-#             else:
-#                 print(f"Loading {model_name}")
-#                 if device == "cuda" and not args.fp32:
-#                     text2img = StableDiffusionPipeline.from_pretrained(
-#                         model_name,
-#                         revision="fp16",
-#                         torch_dtype=torch.float16,
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     )
-#                 else:
-#                     text2img = StableDiffusionPipeline.from_pretrained(
-#                         model_name, use_auth_token=token, vae=vae
-#                     )
-#             if inpainting_model:
-#                 # can reduce vRAM by reusing models except unet
-#                 text2img_unet = text2img.unet
-#                 del text2img.vae
-#                 del text2img.text_encoder
-#                 del text2img.tokenizer
-#                 del text2img.scheduler
-#                 del text2img.safety_checker
-#                 del text2img.feature_extractor
-#                 import gc
-
-#                 gc.collect()
-#                 if device == "cuda" and not args.fp32:
-#                     inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-#                         "runwayml/stable-diffusion-inpainting",
-#                         revision="fp16",
-#                         torch_dtype=torch.float16,
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     ).to(device)
-#                 else:
-#                     inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-#                         "runwayml/stable-diffusion-inpainting",
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     ).to(device)
-#                 text2img_unet.to(device)
-#                 text2img = StableDiffusionPipeline(
-#                     vae=inpaint.vae,
-#                     text_encoder=inpaint.text_encoder,
-#                     tokenizer=inpaint.tokenizer,
-#                     unet=text2img_unet,
-#                     scheduler=inpaint.scheduler,
-#                     safety_checker=inpaint.safety_checker,
-#                     feature_extractor=inpaint.feature_extractor,
-#                 )
-#             else:
-#                 inpaint = StableDiffusionInpaintPipelineLegacy(
-#                     vae=text2img.vae,
-#                     text_encoder=text2img.text_encoder,
-#                     tokenizer=text2img.tokenizer,
-#                     unet=text2img.unet,
-#                     scheduler=text2img.scheduler,
-#                     safety_checker=text2img.safety_checker,
-#                     feature_extractor=text2img.feature_extractor,
-#                 ).to(device)
-#             text_encoder = text2img.text_encoder
-#             tokenizer = text2img.tokenizer
-#             if os.path.exists("./embeddings"):
-#                 for item in os.listdir("./embeddings"):
-#                     if item.endswith(".bin"):
-#                         load_learned_embed_in_clip(
-#                             os.path.join("./embeddings", item),
-#                             text2img.text_encoder,
-#                             text2img.tokenizer,
-#                         )
-#             text2img.to(device)
-#             if device == "mps":
-#                 _ = text2img("", num_inference_steps=1)
-#             img2img = StableDiffusionImg2ImgPipeline(
-#                 vae=text2img.vae,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             ).to(device)
-#         scheduler_dict["PLMS"] = text2img.scheduler
-#         scheduler_dict["DDIM"] = prepare_scheduler(
-#             DDIMScheduler(
-#                 beta_start=0.00085,
-#                 beta_end=0.012,
-#                 beta_schedule="scaled_linear",
-#                 clip_sample=False,
-#                 set_alpha_to_one=False,
-#             )
-#         )
-#         scheduler_dict["K-LMS"] = prepare_scheduler(
-#             LMSDiscreteScheduler(
-#                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-#             )
-#         )
-#         scheduler_dict["PNDM"] = prepare_scheduler(
-#             PNDMScheduler(
-#                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
-#                 skip_prk_steps=True
-#             )
-#         )
-#         scheduler_dict["DPM"] = prepare_scheduler(
-#             DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
-#         )
-#         self.safety_checker = text2img.safety_checker
-#         save_token(token)
-#         try:
-#             total_memory = torch.cuda.get_device_properties(0).total_memory // (
-#                 1024 ** 3
-#             )
-#             if total_memory <= 5 or args.lowvram:
-#                 inpaint.enable_attention_slicing()
-#                 inpaint.enable_sequential_cpu_offload()
-#                 if inpainting_model:
-#                     text2img.enable_attention_slicing()
-#                     text2img.enable_sequential_cpu_offload()
-#         except:
-#             pass
-#         self.text2img = text2img
-#         self.inpaint = inpaint
-#         self.img2img = img2img
-#         if True:
-#             self.unified = inpaint
-#         else:
-#             self.unified = UnifiedPipeline(
-#                 vae=text2img.vae,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             ).to(device)
-#         self.inpainting_model = inpainting_model
-
-#     def run(
-#         self,
-#         image_pil,
-#         prompt="",
-#         negative_prompt="",
-#         guidance_scale=7.5,
-#         resize_check=True,
-#         enable_safety=True,
-#         fill_mode="patchmatch",
-#         strength=0.75,
-#         step=50,
-#         enable_img2img=False,
-#         use_seed=False,
-#         seed_val=-1,
-#         generate_num=1,
-#         scheduler="",
-#         scheduler_eta=0.0,
-#         **kwargs,
-#     ):
-#         text2img, inpaint, img2img, unified = (
-#             self.text2img,
-#             self.inpaint,
-#             self.img2img,
-#             self.unified,
-#         )
-#         selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
-#         for item in [text2img, inpaint, img2img, unified]:
-#             item.scheduler = selected_scheduler
-#             if enable_safety or self.safety_checker is None:
-#                 item.safety_checker = self.safety_checker
-#             else:
-#                 item.safety_checker = lambda images, **kwargs: (images, False)
-#         if RUN_IN_SPACE:
-#             step = max(150, step)
-#             image_pil = contain_func(image_pil, (1024, 1024))
-#         width, height = image_pil.size
-#         sel_buffer = np.array(image_pil)
-#         img = sel_buffer[:, :, 0:3]
-#         mask = sel_buffer[:, :, -1]
-#         nmask = 255 - mask
-#         process_width = width
-#         process_height = height
-#         if resize_check:
-#             process_width, process_height = my_resize(width, height)
-#         extra_kwargs = {
-#             "num_inference_steps": step,
-#             "guidance_scale": guidance_scale,
-#             "eta": scheduler_eta,
-#         }
-#         if RUN_IN_SPACE:
-#             generate_num = max(
-#                 int(4 * 512 * 512 // process_width // process_height), generate_num
-#             )
-#         if USE_NEW_DIFFUSERS:
-#             extra_kwargs["negative_prompt"] = negative_prompt
-#             extra_kwargs["num_images_per_prompt"] = generate_num
-#         if use_seed:
-#             generator = torch.Generator(text2img.device).manual_seed(seed_val)
-#             extra_kwargs["generator"] = generator
-#         if nmask.sum() < 1 and enable_img2img:
-#             init_image = Image.fromarray(img)
-#             if True:
-#                 images = img2img(
-#                     prompt=prompt,
-#                     image=init_image.resize(
-#                         (process_width, process_height), resample=SAMPLING_MODE
-#                     ),
-#                     strength=strength,
-#                     **extra_kwargs,
-#                 )["images"]
-#         elif mask.sum() > 0:
-#             if fill_mode == "g_diffuser" and not self.inpainting_model:
-#                 mask = 255 - mask
-#                 mask = mask[:, :, np.newaxis].repeat(3, axis=2)
-#                 img, mask = functbl[fill_mode](img, mask)
-#                 extra_kwargs["strength"] = 1.0
-#                 extra_kwargs["out_mask"] = Image.fromarray(mask)
-#                 inpaint_func = unified
-#             else:
-#                 img, mask = functbl[fill_mode](img, mask)
-#                 mask = 255 - mask
-#                 mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-#                 mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-#                 inpaint_func = inpaint
-#             init_image = Image.fromarray(img)
-#             mask_image = Image.fromarray(mask)
-#             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-#             input_image = init_image.resize(
-#                 (process_width, process_height), resample=SAMPLING_MODE
-#             )
-#             if self.inpainting_model:
-#                 images = inpaint_func(
-#                     prompt=prompt,
-#                     image=input_image,
-#                     width=process_width,
-#                     height=process_height,
-#                     mask_image=mask_image.resize((process_width, process_height)),
-#                     **extra_kwargs,
-#                 )["images"]
-#             else:
-#                 extra_kwargs["strength"] = strength
-#                 if True:
-#                     images = inpaint_func(
-#                         prompt=prompt,
-#                         image=input_image,
-#                         mask_image=mask_image.resize((process_width, process_height)),
-#                         **extra_kwargs,
-#                     )["images"]
-#         else:
-#             if True:
-#                 images = text2img(
-#                     prompt=prompt,
-#                     height=process_width,
-#                     width=process_height,
-#                     **extra_kwargs,
-#                 )["images"]
-#         return images
+class StableDiffusion:
+    def __init__(
+        self,
+        token: str = "",
+        model_name: str = "runwayml/stable-diffusion-v1-5",
+        model_path: str = None,
+        inpainting_model: bool = False,
+        **kwargs,
+    ):
+        self.token = token
+        original_checkpoint = False
+        if device=="cpu" and onnx_available:
+            from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
+            text2img = OnnxStableDiffusionPipeline.from_pretrained(
+                model_name,
+                revision="onnx",
+                provider=onnx_providers[0] if onnx_providers else None
+            )
+            inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
+                vae_encoder=text2img.vae_encoder,
+                vae_decoder=text2img.vae_decoder,
+                text_encoder=text2img.text_encoder,
+                tokenizer=text2img.tokenizer,
+                unet=text2img.unet,
+                scheduler=text2img.scheduler,
+                safety_checker=text2img.safety_checker,
+                feature_extractor=text2img.feature_extractor,
+            )
+            img2img = OnnxStableDiffusionImg2ImgPipeline(
+                vae_encoder=text2img.vae_encoder,
+                vae_decoder=text2img.vae_decoder,
+                text_encoder=text2img.text_encoder,
+                tokenizer=text2img.tokenizer,
+                unet=text2img.unet,
+                scheduler=text2img.scheduler,
+                safety_checker=text2img.safety_checker,
+                feature_extractor=text2img.feature_extractor,
+            )
+        else:
+            if model_path and os.path.exists(model_path):
+                if model_path.endswith(".ckpt"):
+                    original_checkpoint = True
+                elif model_path.endswith(".json"):
+                    model_name = os.path.dirname(model_path)
+                else:
+                    model_name = model_path
+            vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+            if device == "cuda" and not args.fp32:
+                vae.to(torch.float16)
+            if original_checkpoint:
+                print(f"Converting & Loading {model_path}")
+                from convert_checkpoint import convert_checkpoint
+
+                pipe = convert_checkpoint(model_path)
+                if device == "cuda" and not args.fp32:
+                    pipe.to(torch.float16)
+                text2img = StableDiffusionPipeline(
+                    vae=vae,
+                    text_encoder=pipe.text_encoder,
+                    tokenizer=pipe.tokenizer,
+                    unet=pipe.unet,
+                    scheduler=pipe.scheduler,
+                    safety_checker=pipe.safety_checker,
+                    feature_extractor=pipe.feature_extractor,
+                )
+            else:
+                print(f"Loading {model_name}")
+                if device == "cuda" and not args.fp32:
+                    text2img = StableDiffusionPipeline.from_pretrained(
+                        model_name,
+                        revision="fp16",
+                        torch_dtype=torch.float16,
+                        use_auth_token=token,
+                        vae=vae,
+                    )
+                else:
+                    text2img = StableDiffusionPipeline.from_pretrained(
+                        model_name, use_auth_token=token, vae=vae
+                    )
+            if inpainting_model:
+                # can reduce vRAM by reusing models except unet
+                text2img_unet = text2img.unet
+                del text2img.vae
+                del text2img.text_encoder
+                del text2img.tokenizer
+                del text2img.scheduler
+                del text2img.safety_checker
+                del text2img.feature_extractor
+                import gc
+
+                gc.collect()
+                if device == "cuda" and not args.fp32:
+                    inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+                        "runwayml/stable-diffusion-inpainting",
+                        revision="fp16",
+                        torch_dtype=torch.float16,
+                        use_auth_token=token,
+                        vae=vae,
+                    ).to(device)
+                else:
+                    inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+                        "runwayml/stable-diffusion-inpainting",
+                        use_auth_token=token,
+                        vae=vae,
+                    ).to(device)
+                text2img_unet.to(device)
+                text2img = StableDiffusionPipeline(
+                    vae=inpaint.vae,
+                    text_encoder=inpaint.text_encoder,
+                    tokenizer=inpaint.tokenizer,
+                    unet=text2img_unet,
+                    scheduler=inpaint.scheduler,
+                    safety_checker=inpaint.safety_checker,
+                    feature_extractor=inpaint.feature_extractor,
+                )
+            else:
+                inpaint = StableDiffusionInpaintPipelineLegacy(
+                    vae=text2img.vae,
+                    text_encoder=text2img.text_encoder,
+                    tokenizer=text2img.tokenizer,
+                    unet=text2img.unet,
+                    scheduler=text2img.scheduler,
+                    safety_checker=text2img.safety_checker,
+                    feature_extractor=text2img.feature_extractor,
+                ).to(device)
+            text_encoder = text2img.text_encoder
+            tokenizer = text2img.tokenizer
+            if os.path.exists("./embeddings"):
+                for item in os.listdir("./embeddings"):
+                    if item.endswith(".bin"):
+                        load_learned_embed_in_clip(
+                            os.path.join("./embeddings", item),
+                            text2img.text_encoder,
+                            text2img.tokenizer,
+                        )
+            text2img.to(device)
+            if device == "mps":
+                _ = text2img("", num_inference_steps=1)
+            img2img = StableDiffusionImg2ImgPipeline(
+                vae=text2img.vae,
+                text_encoder=text2img.text_encoder,
+                tokenizer=text2img.tokenizer,
+                unet=text2img.unet,
+                scheduler=text2img.scheduler,
+                safety_checker=text2img.safety_checker,
+                feature_extractor=text2img.feature_extractor,
+            ).to(device)
+        scheduler_dict["PLMS"] = text2img.scheduler
+        scheduler_dict["DDIM"] = prepare_scheduler(
+            DDIMScheduler(
+                beta_start=0.00085,
+                beta_end=0.012,
+                beta_schedule="scaled_linear",
+                clip_sample=False,
+                set_alpha_to_one=False,
+            )
+        )
+        scheduler_dict["K-LMS"] = prepare_scheduler(
+            LMSDiscreteScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+        )
+        scheduler_dict["PNDM"] = prepare_scheduler(
+            PNDMScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                skip_prk_steps=True
+            )
+        )
+        scheduler_dict["DPM"] = prepare_scheduler(
+            DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
+        )
+        self.safety_checker = text2img.safety_checker
+        save_token(token)
+        try:
+            total_memory = torch.cuda.get_device_properties(0).total_memory // (
+                1024 ** 3
+            )
+            if total_memory <= 5 or args.lowvram:
+                inpaint.enable_attention_slicing()
+                inpaint.enable_sequential_cpu_offload()
+                if inpainting_model:
+                    text2img.enable_attention_slicing()
+                    text2img.enable_sequential_cpu_offload()
+        except:
+            pass
+        self.text2img = text2img
+        self.inpaint = inpaint
+        self.img2img = img2img
+        if True:
+            self.unified = inpaint
+        else:
+            self.unified = UnifiedPipeline(
+                vae=text2img.vae,
+                text_encoder=text2img.text_encoder,
+                tokenizer=text2img.tokenizer,
+                unet=text2img.unet,
+                scheduler=text2img.scheduler,
+                safety_checker=text2img.safety_checker,
+                feature_extractor=text2img.feature_extractor,
+            ).to(device)
+        self.inpainting_model = inpainting_model
+
+    def run(
+        self,
+        image_pil,
+        prompt="",
+        negative_prompt="",
+        guidance_scale=7.5,
+        resize_check=True,
+        enable_safety=True,
+        fill_mode="patchmatch",
+        strength=0.75,
+        step=50,
+        enable_img2img=False,
+        use_seed=False,
+        seed_val=-1,
+        generate_num=1,
+        scheduler="",
+        scheduler_eta=0.0,
+        **kwargs,
+    ):
+        text2img, inpaint, img2img, unified = (
+            self.text2img,
+            self.inpaint,
+            self.img2img,
+            self.unified,
+        )
+        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
+        for item in [text2img, inpaint, img2img, unified]:
+            item.scheduler = selected_scheduler
+            if enable_safety or self.safety_checker is None:
+                item.safety_checker = self.safety_checker
+            else:
+                item.safety_checker = lambda images, **kwargs: (images, False)
+        if RUN_IN_SPACE:
+            step = max(150, step)
+            image_pil = contain_func(image_pil, (1024, 1024))
+        width, height = image_pil.size
+        sel_buffer = np.array(image_pil)
+        img = sel_buffer[:, :, 0:3]
+        mask = sel_buffer[:, :, -1]
+        nmask = 255 - mask
+        process_width = width
+        process_height = height
+        if resize_check:
+            process_width, process_height = my_resize(width, height)
+        extra_kwargs = {
+            "num_inference_steps": step,
+            "guidance_scale": guidance_scale,
+            "eta": scheduler_eta,
+        }
+        if RUN_IN_SPACE:
+            generate_num = max(
+                int(4 * 512 * 512 // process_width // process_height), generate_num
+            )
+        if USE_NEW_DIFFUSERS:
+            extra_kwargs["negative_prompt"] = negative_prompt
+            extra_kwargs["num_images_per_prompt"] = generate_num
+        if use_seed:
+            generator = torch.Generator(text2img.device).manual_seed(seed_val)
+            extra_kwargs["generator"] = generator
+        if nmask.sum() < 1 and enable_img2img:
+            init_image = Image.fromarray(img)
+            if True:
+                images = img2img(
+                    prompt=prompt,
+                    image=init_image.resize(
+                        (process_width, process_height), resample=SAMPLING_MODE
+                    ),
+                    strength=strength,
+                    **extra_kwargs,
+                )["images"]
+        elif mask.sum() > 0:
+            if fill_mode == "g_diffuser" and not self.inpainting_model:
+                mask = 255 - mask
+                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
+                img, mask = functbl[fill_mode](img, mask)
+                extra_kwargs["strength"] = 1.0
+                extra_kwargs["out_mask"] = Image.fromarray(mask)
+                inpaint_func = unified
+            else:
+                img, mask = functbl[fill_mode](img, mask)
+                mask = 255 - mask
+                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+                inpaint_func = inpaint
+            init_image = Image.fromarray(img)
+            mask_image = Image.fromarray(mask)
+            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
+            input_image = init_image.resize(
+                (process_width, process_height), resample=SAMPLING_MODE
+            )
+            if self.inpainting_model:
+                images = inpaint_func(
+                    prompt=prompt,
+                    image=input_image,
+                    width=process_width,
+                    height=process_height,
+                    mask_image=mask_image.resize((process_width, process_height)),
+                    **extra_kwargs,
+                )["images"]
+            else:
+                extra_kwargs["strength"] = strength
+                if True:
+                    images = inpaint_func(
+                        prompt=prompt,
+                        image=input_image,
+                        mask_image=mask_image.resize((process_width, process_height)),
+                        **extra_kwargs,
+                    )["images"]
+        else:
+            if True:
+                images = text2img(
+                    prompt=prompt,
+                    height=process_width,
+                    width=process_height,
+                    **extra_kwargs,
+                )["images"]
+        return images
 
 
 def get_model(token="", model_choice="", model_path=""):
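
For reference, a minimal usage sketch of the class this commit re-enables. This is an illustration only, not part of the commit: it assumes it runs inside app.py, where the module-level globals the class reads (device, args, scheduler_dict, functbl, and so on) are already initialized, and the prompt and canvas size are made up.

from PIL import Image

# Hypothetical driver code: exercise the re-enabled StableDiffusion
# class through the same public surface get_model() would hand out.
model = StableDiffusion(
    token="",                                      # optional HF auth token
    model_name="runwayml/stable-diffusion-v1-5",
    inpainting_model=False,
)

# run() expects an RGBA canvas; the alpha channel marks pixels that
# already hold content. A fully transparent canvas falls through to the
# plain text2img branch, a partially painted one takes the inpaint path.
canvas = Image.new("RGBA", (512, 512), (0, 0, 0, 0))

images = model.run(
    canvas,
    prompt="a watercolor landscape",
    guidance_scale=7.5,
    step=50,
    fill_mode="patchmatch",
    scheduler="PLMS",
)
images[0].save("out.png")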