stanley committed on
Commit 81e37bd · 1 Parent(s): e557c36

trying new app

Files changed (2)
  1. app.py +35 -538
  2. appHold.py +1582 -0
app.py CHANGED
@@ -1,5 +1,8 @@
 import subprocess
+# import os.path as osp
 import pip
+# pip.main(["install","-v","-U","git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"])
+# subprocess.check_call("pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", cwd=osp.dirname(__file__), shell=True)
 
 import io
 import base64
@@ -10,11 +13,6 @@ import numpy as np
 import torch
 from torch import autocast
 import diffusers
-import requests
-
-
-# assert tuple(map(int,diffusers.__version__.split("."))) >= (0,9,0), "Please upgrade diffusers to 0.9.0"
-
 from diffusers.configuration_utils import FrozenDict
 from diffusers import (
     StableDiffusionPipeline,
@@ -23,10 +21,8 @@ from diffusers import (
     StableDiffusionInpaintPipelineLegacy,
     DDIMScheduler,
     LMSDiscreteScheduler,
-    DiffusionPipeline,
     StableDiffusionUpscalePipeline,
-    DPMSolverMultistepScheduler,
-    PNDMScheduler,
+    DPMSolverMultistepScheduler
 )
 from diffusers.models import AutoencoderKL
 from PIL import Image
@@ -38,20 +34,6 @@ import skimage.measure
 import yaml
 import json
 from enum import Enum
-from utils import *
-
-# load environment variables from the .env file
-# if os.path.exists(".env"):
-#     with open(".env") as f:
-#         for line in f:
-#             if line.startswith("#") or not line.strip():
-#                 continue
-#             name, value = line.strip().split("=", 1)
-#             os.environ[name] = value
-
-
-# access_token = os.environ.get("HF_ACCESS_TOKEN")
-# print("access_token from HF 1:", access_token)
 
 try:
     abspath = os.path.abspath(__file__)
@@ -60,6 +42,9 @@ try:
 except:
     pass
 
+from utils import *
+
+assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0"
 
 USE_NEW_DIFFUSERS = True
 RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
@@ -67,13 +52,9 @@ RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
 
 class ModelChoice(Enum):
     INPAINTING = "stablediffusion-inpainting"
-    INPAINTING2 = "stablediffusion-2-inpainting"
-    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-1.5"
-    MODEL_2_1 = "stablediffusion-2.1"
-    MODEL_2_0_V = "stablediffusion-2.0v"
-    MODEL_2_0 = "stablediffusion-2.0"
-    MODEL_1_5 = "stablediffusion-1.5"
-    MODEL_1_4 = "stablediffusion-1.4"
+    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
+    MODEL_1_5 = "stablediffusion-v1.5"
+    MODEL_1_4 = "stablediffusion-v1.4"
 
 
 try:
@@ -89,41 +70,6 @@ USE_GLID = False
 # except:
 #     USE_GLID = False
 
-# ******** ORIGINAL ***********
-# try:
-#     import onnxruntime
-#     onnx_available = True
-#     onnx_providers = ["CUDAExecutionProvider", "DmlExecutionProvider", "OpenVINOExecutionProvider", 'CPUExecutionProvider']
-#     available_providers = onnxruntime.get_available_providers()
-#     onnx_providers = [item for item in onnx_providers if item in available_providers]
-# except:
-#     onnx_available = False
-#     onnx_providers = []
-
-
-# try:
-#     cuda_available = torch.cuda.is_available()
-# except:
-#     cuda_available = False
-# finally:
-#     if sys.platform == "darwin":
-#         device = "mps" if torch.backends.mps.is_available() else "cpu"
-#     elif cuda_available:
-#         device = "cuda"
-#     else:
-#         device = "cpu"
-
-# if device != "cuda":
-#     import contextlib
-
-#     autocast = contextlib.nullcontext
-
-# with open("config.yaml", "r") as yaml_in:
-#     yaml_object = yaml.safe_load(yaml_in)
-#     config_json = json.dumps(yaml_object)
-
-# ******** ^ ORIGINAL ^ ***********
-
 try:
     cuda_available = torch.cuda.is_available()
 except:
@@ -145,8 +91,6 @@ with open("config.yaml", "r") as yaml_in:
     config_json = json.dumps(yaml_object)
 
 
-# new ^
-
 def load_html():
     body, canvaspy = "", ""
     with open("index.html", encoding="utf8") as f:
@@ -161,7 +105,7 @@ def load_html():
 
 def test(x):
     x = load_html()
-    return f"""<iframe id="sdinfframe" style="width: 100%; height: 780px" name="result" allow="midi; geolocation; microphone; camera;
+    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
 display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
 allow-scripts allow-same-origin allow-popups
 allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
@@ -203,7 +147,6 @@ parser.add_argument("--host", type=str, help="host", dest="server_name")
 parser.add_argument("--share", action="store_true", help="share this app?")
 parser.add_argument("--debug", action="store_true", help="debug mode")
 parser.add_argument("--fp32", action="store_true", help="using full precision")
-parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
 parser.add_argument("--encrypt", action="store_true", help="using https?")
 parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
 parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
@@ -221,15 +164,6 @@ parser.add_argument(
     "--local_model", type=str, help="use a model stored on your PC", default=""
 )
 
-# original
-# if __name__ == "__main__":
-#     args = parser.parse_args()
-# else:
-#     args = parser.parse_args(["--debug"])
-# #     args = parser.parse_args(["--debug"])
-# if args.auth is not None:
-#     args.auth = tuple(args.auth)
-
 if __name__ == "__main__" and not RUN_IN_SPACE:
     args = parser.parse_args()
 else:
@@ -240,15 +174,6 @@ if args.auth is not None:
 
 model = {}
 
-# HF function for token
-# def get_token():
-#     token = "{access_token}"
-#     if os.path.exists(".token"):
-#         with open(".token", "r") as f:
-#             token = f.read()
-#     print("get_token called", token)
-#     token = os.environ.get("hftoken", token)
-#     return token
 
 def get_token():
     token = ""
@@ -292,7 +217,7 @@ def my_resize(width, height):
         factor = 1.25
     elif smaller < 450:
         factor = 1.125
-    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
+    return int(factor * width)//8*8, int(factor * height)//8*8
 
 
 def load_learned_embed_in_clip(
@@ -325,7 +250,7 @@ def load_learned_embed_in_clip(
     text_encoder.get_input_embeddings().weight.data[token_id] = embeds
 
 
-scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None, "PNDM": None}
+scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None}
 
 
 class StableDiffusionInpaint:
@@ -334,14 +259,6 @@ class StableDiffusionInpaint:
     ):
         self.token = token
         original_checkpoint = False
-        # if device == "cpu" and onnx_available:
-        #     from diffusers import OnnxStableDiffusionInpaintPipeline
-        #     inpaint = OnnxStableDiffusionInpaintPipeline.from_pretrained(
-        #         model_name,
-        #         revision="onnx",
-        #         provider=onnx_providers[0] if onnx_providers else None
-        #     )
-        # else:
        if model_path and os.path.exists(model_path):
             if model_path.endswith(".ckpt"):
                 original_checkpoint = True
@@ -350,8 +267,6 @@ class StableDiffusionInpaint:
             else:
                 model_name = model_path
         vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
-        # if device == "cuda" and not args.fp32:
-        #     vae.to(torch.float16)
         vae.to(torch.float16)
         if original_checkpoint:
             print(f"Converting & Loading {model_path}")
@@ -377,13 +292,12 @@ class StableDiffusionInpaint:
                     revision="fp16",
                     torch_dtype=torch.float16,
                     use_auth_token=token,
-                    vae=vae,
+                    vae=vae
                 )
             else:
                 inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                    model_name, use_auth_token=token, vae=vae
+                    model_name, use_auth_token=token,
                 )
-        # print(f"access_token from HF:", access_token)
         if os.path.exists("./embeddings"):
             print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
             for item in os.listdir("./embeddings"):
@@ -393,7 +307,13 @@ class StableDiffusionInpaint:
                         inpaint.text_encoder,
                         inpaint.tokenizer,
                     )
-        inpaint.to(device)
+        inpaint.to(device)
+        # try:
+        #     inpaint.vae=torch.compile(inpaint.vae, dynamic=True)
+        #     inpaint.unet=torch.compile(inpaint.unet, dynamic=True)
+        # except Exception as e:
+        #     print(e)
+        # inpaint.enable_xformers_memory_efficient_attention()
         # if device == "mps":
         #     _ = text2img("", num_inference_steps=1)
         scheduler_dict["PLMS"] = inpaint.scheduler
@@ -411,12 +331,6 @@ class StableDiffusionInpaint:
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
             )
         )
-        scheduler_dict["PNDM"] = prepare_scheduler(
-            PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
-                skip_prk_steps=True
-            )
-        )
         scheduler_dict["DPM"] = prepare_scheduler(
             DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config)
         )
@@ -426,9 +340,8 @@ class StableDiffusionInpaint:
             total_memory = torch.cuda.get_device_properties(0).total_memory // (
                 1024 ** 3
             )
-            if total_memory <= 5 or args.lowvram:
+            if total_memory <= 5:
                 inpaint.enable_attention_slicing()
-                inpaint.enable_sequential_cpu_offload()
         except:
             pass
         self.inpaint = inpaint
@@ -460,13 +373,6 @@ class StableDiffusionInpaint:
                 item.safety_checker = self.safety_checker
             else:
                 item.safety_checker = lambda images, **kwargs: (images, False)
-
-        # for item in [inpaint]:
-        #     item.scheduler = selected_scheduler
-        #     if enable_safety or self.safety_checker is None:
-        #         item.safety_checker = self.safety_checker
-        #     else:
-        #         item.safety_checker = lambda images, **kwargs: (images, False)
         width, height = image_pil.size
         sel_buffer = np.array(image_pil)
         img = sel_buffer[:, :, 0:3]
@@ -476,8 +382,8 @@ class StableDiffusionInpaint:
         process_height = height
         if resize_check:
             process_width, process_height = my_resize(width, height)
-        process_width = process_width * 8 // 8
-        process_height = process_height * 8 // 8
+        process_width=process_width*8//8
+        process_height=process_height*8//8
         extra_kwargs = {
             "num_inference_steps": step,
             "guidance_scale": guidance_scale,
@@ -490,24 +396,15 @@ class StableDiffusionInpaint:
             generator = torch.Generator(inpaint.device).manual_seed(seed_val)
             extra_kwargs["generator"] = generator
         if True:
-            if fill_mode == "g_diffuser":
-                mask = 255 - mask
-                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
-                img, mask = functbl[fill_mode](img, mask)
-            else:
-                img, mask = functbl[fill_mode](img, mask)
-                mask = 255 - mask
-                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-            # extra_kwargs["strength"] = strength
+            img, mask = functbl[fill_mode](img, mask)
+            mask = 255 - mask
+            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+            extra_kwargs["strength"] = strength
             inpaint_func = inpaint
             init_image = Image.fromarray(img)
             mask_image = Image.fromarray(mask)
             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-
-            # Cast input image and mask to float32
-            # init_image = init_image.convert("RGB").to(torch.float32)
-            # mask_image = mask_image.convert("L").to(torch.float32)
             if True:
                 images = inpaint_func(
                     prompt=prompt,
@@ -521,6 +418,7 @@ class StableDiffusionInpaint:
                 )["images"]
         return images
 
+
 class StableDiffusion:
     def __init__(
         self,
@@ -784,373 +682,6 @@ class StableDiffusion:
         return images
 
 
-# class StableDiffusion:
-#     def __init__(
-#         self,
-#         token: str = "",
-#         model_name: str = "runwayml/stable-diffusion-v1-5",
-#         model_path: str = None,
-#         inpainting_model: bool = False,
-#         **kwargs,
-#     ):
-#         self.token = token
-#         original_checkpoint = False
-#         if device=="cpu" and onnx_available:
-#             from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
-#             text2img = OnnxStableDiffusionPipeline.from_pretrained(
-#                 model_name,
-#                 revision="onnx",
-#                 provider=onnx_providers[0] if onnx_providers else None
-#             )
-#             inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
-#                 vae_encoder=text2img.vae_encoder,
-#                 vae_decoder=text2img.vae_decoder,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             )
-#             img2img = OnnxStableDiffusionImg2ImgPipeline(
-#                 vae_encoder=text2img.vae_encoder,
-#                 vae_decoder=text2img.vae_decoder,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             )
-#         else:
-#             if model_path and os.path.exists(model_path):
-#                 if model_path.endswith(".ckpt"):
-#                     original_checkpoint = True
-#                 elif model_path.endswith(".json"):
-#                     model_name = os.path.dirname(model_path)
-#                 else:
-#                     model_name = model_path
-#             vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
-#             if device == "cuda" and not args.fp32:
-#                 vae.to(torch.float16)
-#             if original_checkpoint:
-#                 print(f"Converting & Loading {model_path}")
-#                 from convert_checkpoint import convert_checkpoint
-
-#                 pipe = convert_checkpoint(model_path)
-#                 if device == "cuda" and not args.fp32:
-#                     pipe.to(torch.float16)
-#                 text2img = StableDiffusionPipeline(
-#                     vae=vae,
-#                     text_encoder=pipe.text_encoder,
-#                     tokenizer=pipe.tokenizer,
-#                     unet=pipe.unet,
-#                     scheduler=pipe.scheduler,
-#                     safety_checker=pipe.safety_checker,
-#                     feature_extractor=pipe.feature_extractor,
-#                 )
-#             else:
-#                 print(f"Loading {model_name}")
-#                 if device == "cuda" and not args.fp32:
-#                     text2img = StableDiffusionPipeline.from_pretrained(
-#                         model_name,
-#                         revision="fp16",
-#                         torch_dtype=torch.float16,
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     )
-#                 else:
-#                     text2img = StableDiffusionPipeline.from_pretrained(
-#                         model_name, use_auth_token=token, vae=vae
-#                     )
-#             if inpainting_model:
-#                 # can reduce vRAM by reusing models except unet
-#                 text2img_unet = text2img.unet
-#                 del text2img.vae
-#                 del text2img.text_encoder
-#                 del text2img.tokenizer
-#                 del text2img.scheduler
-#                 del text2img.safety_checker
-#                 del text2img.feature_extractor
-#                 import gc
-
-#                 gc.collect()
-#                 if device == "cuda" and not args.fp32:
-#                     inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-#                         "runwayml/stable-diffusion-inpainting",
-#                         revision="fp16",
-#                         torch_dtype=torch.float16,
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     ).to(device)
-#                 else:
-#                     inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-#                         "runwayml/stable-diffusion-inpainting",
-#                         use_auth_token=token,
-#                         vae=vae,
-#                     ).to(device)
-#                 text2img_unet.to(device)
-#                 text2img = StableDiffusionPipeline(
-#                     vae=inpaint.vae,
-#                     text_encoder=inpaint.text_encoder,
-#                     tokenizer=inpaint.tokenizer,
-#                     unet=text2img_unet,
-#                     scheduler=inpaint.scheduler,
-#                     safety_checker=inpaint.safety_checker,
-#                     feature_extractor=inpaint.feature_extractor,
-#                 )
-#             else:
-#                 inpaint = StableDiffusionInpaintPipelineLegacy(
-#                     vae=text2img.vae,
-#                     text_encoder=text2img.text_encoder,
-#                     tokenizer=text2img.tokenizer,
-#                     unet=text2img.unet,
-#                     scheduler=text2img.scheduler,
-#                     safety_checker=text2img.safety_checker,
-#                     feature_extractor=text2img.feature_extractor,
-#                 ).to(device)
-#         text_encoder = text2img.text_encoder
-#         tokenizer = text2img.tokenizer
-#         if os.path.exists("./embeddings"):
-#             for item in os.listdir("./embeddings"):
-#                 if item.endswith(".bin"):
-#                     load_learned_embed_in_clip(
-#                         os.path.join("./embeddings", item),
-#                         text2img.text_encoder,
-#                         text2img.tokenizer,
-#                     )
-#         text2img.to(device)
-#         if device == "mps":
-#             _ = text2img("", num_inference_steps=1)
-#         img2img = StableDiffusionImg2ImgPipeline(
-#             vae=text2img.vae,
-#             text_encoder=text2img.text_encoder,
-#             tokenizer=text2img.tokenizer,
-#             unet=text2img.unet,
-#             scheduler=text2img.scheduler,
-#             safety_checker=text2img.safety_checker,
-#             feature_extractor=text2img.feature_extractor,
-#         ).to(device)
-#         scheduler_dict["PLMS"] = text2img.scheduler
-#         scheduler_dict["DDIM"] = prepare_scheduler(
-#             DDIMScheduler(
-#                 beta_start=0.00085,
-#                 beta_end=0.012,
-#                 beta_schedule="scaled_linear",
-#                 clip_sample=False,
-#                 set_alpha_to_one=False,
-#             )
-#         )
-#         scheduler_dict["K-LMS"] = prepare_scheduler(
-#             LMSDiscreteScheduler(
-#                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-#             )
-#         )
-#         scheduler_dict["PNDM"] = prepare_scheduler(
-#             PNDMScheduler(
-#                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
-#                 skip_prk_steps=True
-#             )
-#         )
-#         scheduler_dict["DPM"] = prepare_scheduler(
-#             DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
-#         )
-#         self.safety_checker = text2img.safety_checker
-#         save_token(token)
-#         try:
-#             total_memory = torch.cuda.get_device_properties(0).total_memory // (
-#                 1024 ** 3
-#             )
-#             if total_memory <= 5 or args.lowvram:
-#                 inpaint.enable_attention_slicing()
-#                 inpaint.enable_sequential_cpu_offload()
-#                 if inpainting_model:
-#                     text2img.enable_attention_slicing()
-#                     text2img.enable_sequential_cpu_offload()
-#         except:
-#             pass
-#         self.text2img = text2img
-#         self.inpaint = inpaint
-#         self.img2img = img2img
-#         if True:
-#             self.unified = inpaint
-#         else:
-#             self.unified = UnifiedPipeline(
-#                 vae=text2img.vae,
-#                 text_encoder=text2img.text_encoder,
-#                 tokenizer=text2img.tokenizer,
-#                 unet=text2img.unet,
-#                 scheduler=text2img.scheduler,
-#                 safety_checker=text2img.safety_checker,
-#                 feature_extractor=text2img.feature_extractor,
-#             ).to(device)
-#         self.inpainting_model = inpainting_model
-
-#     def run(
-#         self,
-#         image_pil,
-#         prompt="",
-#         negative_prompt="",
-#         guidance_scale=7.5,
-#         resize_check=True,
-#         enable_safety=True,
-#         fill_mode="patchmatch",
-#         strength=0.75,
-#         step=50,
-#         enable_img2img=False,
-#         use_seed=False,
-#         seed_val=-1,
-#         generate_num=1,
-#         scheduler="",
-#         scheduler_eta=0.0,
-#         **kwargs,
-#     ):
-#         text2img, inpaint, img2img, unified = (
-#             self.text2img,
-#             self.inpaint,
-#             self.img2img,
-#             self.unified,
-#         )
-#         selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
-#         for item in [text2img, inpaint, img2img, unified]:
-#             item.scheduler = selected_scheduler
-#             if enable_safety or self.safety_checker is None:
-#                 item.safety_checker = self.safety_checker
-#             else:
-#                 item.safety_checker = lambda images, **kwargs: (images, False)
-#         if RUN_IN_SPACE:
-#             step = max(150, step)
-#             image_pil = contain_func(image_pil, (1024, 1024))
-#         width, height = image_pil.size
-#         sel_buffer = np.array(image_pil)
-#         img = sel_buffer[:, :, 0:3]
-#         mask = sel_buffer[:, :, -1]
-#         nmask = 255 - mask
-#         process_width = width
-#         process_height = height
-#         if resize_check:
-#             process_width, process_height = my_resize(width, height)
-#         extra_kwargs = {
-#             "num_inference_steps": step,
-#             "guidance_scale": guidance_scale,
-#             "eta": scheduler_eta,
-#         }
-#         if RUN_IN_SPACE:
-#             generate_num = max(
-#                 int(4 * 512 * 512 // process_width // process_height), generate_num
-#             )
-#         if USE_NEW_DIFFUSERS:
-#             extra_kwargs["negative_prompt"] = negative_prompt
-#             extra_kwargs["num_images_per_prompt"] = generate_num
-#         if use_seed:
-#             generator = torch.Generator(text2img.device).manual_seed(seed_val)
-#             extra_kwargs["generator"] = generator
-#         if nmask.sum() < 1 and enable_img2img:
-#             init_image = Image.fromarray(img)
-#             if True:
-#                 images = img2img(
-#                     prompt=prompt,
-#                     image=init_image.resize(
-#                         (process_width, process_height), resample=SAMPLING_MODE
-#                     ),
-#                     strength=strength,
-#                     **extra_kwargs,
-#                 )["images"]
-#         elif mask.sum() > 0:
-#             if fill_mode == "g_diffuser" and not self.inpainting_model:
-#                 mask = 255 - mask
-#                 mask = mask[:, :, np.newaxis].repeat(3, axis=2)
-#                 img, mask = functbl[fill_mode](img, mask)
-#                 extra_kwargs["strength"] = 1.0
-#                 extra_kwargs["out_mask"] = Image.fromarray(mask)
-#                 inpaint_func = unified
-#             else:
-#                 img, mask = functbl[fill_mode](img, mask)
-#                 mask = 255 - mask
-#                 mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-#                 mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-#                 inpaint_func = inpaint
-#             init_image = Image.fromarray(img)
-#             mask_image = Image.fromarray(mask)
-#             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-#             input_image = init_image.resize(
-#                 (process_width, process_height), resample=SAMPLING_MODE
-#             )
-#             if self.inpainting_model:
-#                 images = inpaint_func(
-#                     prompt=prompt,
-#                     image=input_image,
-#                     width=process_width,
-#                     height=process_height,
-#                     mask_image=mask_image.resize((process_width, process_height)),
-#                     **extra_kwargs,
-#                 )["images"]
-#             else:
-#                 extra_kwargs["strength"] = strength
-#                 if True:
-#                     images = inpaint_func(
-#                         prompt=prompt,
-#                         image=input_image,
-#                         mask_image=mask_image.resize((process_width, process_height)),
-#                         **extra_kwargs,
-#                     )["images"]
-#         else:
-#             if True:
-#                 images = text2img(
-#                     prompt=prompt,
-#                     height=process_width,
-#                     width=process_height,
-#                     **extra_kwargs,
-#                 )["images"]
-#         return images
-
-# ORIGINAL
-# def get_model(token="", model_choice="", model_path=""):
-#     if "model" not in model:
-#         model_name = ""
-#         if args.local_model:
-#             print(f"Using local_model: {args.local_model}")
-#             model_path = args.local_model
-#         elif args.remote_model:
-#             print(f"Using remote_model: {args.remote_model}")
-#             model_name = args.remote_model
-#         if model_choice == ModelChoice.INPAINTING.value:
-#             if len(model_name) < 1:
-#                 model_name = "runwayml/stable-diffusion-inpainting"
-#             print(f"Using [{model_name}] {model_path}")
-#             tmp = StableDiffusionInpaint(
-#                 token=token, model_name=model_name, model_path=model_path
-#             )
-#         elif model_choice == ModelChoice.INPAINTING2.value:
-#             if len(model_name) < 1:
-#                 model_name = "stabilityai/stable-diffusion-2-inpainting"
-#             print(f"Using [{model_name}] {model_path}")
-#             tmp = StableDiffusionInpaint(
-#                 token=token, model_name=model_name, model_path=model_path
-#             )
-#         elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
-#             print(
-#                 f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
-#             )
-#             tmp = StableDiffusion(token=token, inpainting_model=True)
-#         else:
-#             if len(model_name) < 1:
-#                 model_name = (
-#                     "runwayml/stable-diffusion-v1-5"
-#                     if model_choice == ModelChoice.MODEL_1_5.value
-#                     else "CompVis/stable-diffusion-v1-4"
-#                 )
-#             if model_choice == ModelChoice.MODEL_2_0.value:
-#                 model_name = "stabilityai/stable-diffusion-2-base"
-#             elif model_choice == ModelChoice.MODEL_2_0_V.value:
-#                 model_name = "stabilityai/stable-diffusion-2"
-#             elif model_choice == ModelChoice.MODEL_2_1.value:
-#                 model_name = "stabilityai/stable-diffusion-2-1-base"
-#             tmp = StableDiffusion(
-#                 token=token, model_name=model_name, model_path=model_path
-#             )
-#         model["model"] = tmp
-#     return model["model"]
 def get_model(token="", model_choice="", model_path=""):
     if "model" not in model:
         model_name = ""
@@ -1179,6 +710,7 @@ def get_model(token="", model_choice="", model_path=""):
         model["model"] = tmp
     return model["model"]
 
+
 def run_outpaint(
     sel_buffer_str,
     prompt_text,
@@ -1200,25 +732,6 @@
 ):
     data = base64.b64decode(str(sel_buffer_str))
     pil = Image.open(io.BytesIO(data))
-    # if interrogate_mode:
-    #     if "interrogator" not in model:
-    #         model["interrogator"] = Interrogator()
-    #     interrogator = model["interrogator"]
-    #     # possible point to integrate
-    #     img = np.array(pil)[:, :, 0:3]
-    #     mask = np.array(pil)[:, :, -1]
-    #     x, y = np.nonzero(mask)
-    #     if len(x) > 0:
-    #         x0, x1 = x.min(), x.max() + 1
-    #         y0, y1 = y.min(), y.max() + 1
-    #         img = img[x0:x1, y0:y1, :]
-    #         pil = Image.fromarray(img)
-    #     interrogate_ret = interrogator.interrogate(pil)
-    #     return (
-    #         gr.update(value=",".join([sel_buffer_str]),),
-    #         gr.update(label="Prompt", value=interrogate_ret),
-    #         state,
-    #     )
     width, height = pil.size
     sel_buffer = np.array(pil)
     cur_model = get_model()
@@ -1438,7 +951,7 @@ with blocks as demo:
         placeholder="Ignore this if you are not using Docker",
         elem_id="model_path_input",
     )
-
+
    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    xss_js = load_js("xss").replace("\n", " ")
    xss_html = gr.HTML(
@@ -1457,7 +970,6 @@
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
-   interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
    upload_button = gr.Button(
        "Before uploading the image you need to setup the canvas first", visible=False
    )
@@ -1477,14 +989,6 @@
        except Exception as e:
            print(e)
            return {token: gr.update(value=str(e))}
-       if model_choice in [
-           ModelChoice.INPAINTING.value,
-           ModelChoice.INPAINTING_IMG2IMG.value,
-           ModelChoice.INPAINTING2.value,
-       ]:
-           init_val = "cv2_ns"
-       else:
-           init_val = "patchmatch"
        return {
            token: gr.update(visible=False),
            canvas_width: gr.update(visible=False),
@@ -1495,7 +999,6 @@
            upload_button: gr.update(value="Upload Image"),
            model_selection: gr.update(visible=False),
            model_path_input: gr.update(visible=False),
-           init_mode: gr.update(value=init_val),
        }
 
    setup_button.click(
@@ -1518,7 +1021,6 @@
            upload_button,
            model_selection,
            model_path_input,
-           init_mode,
        ],
        _js=setup_button_js,
    )
@@ -1548,8 +1050,7 @@
        _js=proceed_button_js,
    )
    # cancel button can also remove error overlay
-   if tuple(map(int,gr.__version__.split("."))) >= (3,6):
-       cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
+   # cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
 
 
 launch_extra_kwargs = {
@@ -1561,7 +1062,6 @@ launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
 launch_kwargs.pop("remote_model", None)
 launch_kwargs.pop("local_model", None)
 launch_kwargs.pop("fp32", None)
-launch_kwargs.pop("lowvram", None)
 launch_kwargs.update(launch_extra_kwargs)
 try:
     import google.colab
@@ -1575,8 +1075,5 @@ if RUN_IN_SPACE:
 elif args.debug:
     launch_kwargs["server_name"] = "0.0.0.0"
     demo.queue().launch(**launch_kwargs)
-    # demo.queue().launch(share=True)
-
 else:
-    demo.queue().launch(**launch_kwargs)
-    # demo.queue().launch(share=True)
+    demo.queue().launch(**launch_kwargs)
appHold.py ADDED
@@ -0,0 +1,1582 @@
1
+ import subprocess
2
+ import pip
3
+
4
+ import io
5
+ import base64
6
+ import os
7
+ import sys
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import autocast
12
+ import diffusers
13
+ import requests
14
+
15
+
16
+ # assert tuple(map(int,diffusers.__version__.split("."))) >= (0,9,0), "Please upgrade diffusers to 0.9.0"
17
+
18
+ from diffusers.configuration_utils import FrozenDict
19
+ from diffusers import (
20
+ StableDiffusionPipeline,
21
+ StableDiffusionInpaintPipeline,
22
+ StableDiffusionImg2ImgPipeline,
23
+ StableDiffusionInpaintPipelineLegacy,
24
+ DDIMScheduler,
25
+ LMSDiscreteScheduler,
26
+ DiffusionPipeline,
27
+ StableDiffusionUpscalePipeline,
28
+ DPMSolverMultistepScheduler,
29
+ PNDMScheduler,
30
+ )
31
+ from diffusers.models import AutoencoderKL
32
+ from PIL import Image
33
+ from PIL import ImageOps
34
+ import gradio as gr
35
+ import base64
36
+ import skimage
37
+ import skimage.measure
38
+ import yaml
39
+ import json
40
+ from enum import Enum
41
+ from utils import *
42
+
43
+ # load environment variables from the .env file
44
+ # if os.path.exists(".env"):
45
+ # with open(".env") as f:
46
+ # for line in f:
47
+ # if line.startswith("#") or not line.strip():
48
+ # continue
49
+ # name, value = line.strip().split("=", 1)
50
+ # os.environ[name] = value
51
+
52
+
53
+ # access_token = os.environ.get("HF_ACCESS_TOKEN")
54
+ # print("access_token from HF 1:", access_token)
55
+
56
+ try:
57
+ abspath = os.path.abspath(__file__)
58
+ dirname = os.path.dirname(abspath)
59
+ os.chdir(dirname)
60
+ except:
61
+ pass
62
+
63
+
64
+ USE_NEW_DIFFUSERS = True
65
+ RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
66
+
67
+
68
+ class ModelChoice(Enum):
69
+ INPAINTING = "stablediffusion-inpainting"
70
+ INPAINTING2 = "stablediffusion-2-inpainting"
71
+ INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-1.5"
72
+ MODEL_2_1 = "stablediffusion-2.1"
73
+ MODEL_2_0_V = "stablediffusion-2.0v"
74
+ MODEL_2_0 = "stablediffusion-2.0"
75
+ MODEL_1_5 = "stablediffusion-1.5"
76
+ MODEL_1_4 = "stablediffusion-1.4"
77
+
78
+
79
+ try:
80
+ from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
81
+ except:
82
+ UnifiedPipeline = StableDiffusionInpaintPipeline
83
+
84
+ # sys.path.append("./glid_3_xl_stable")
85
+
86
+ USE_GLID = False
87
+ # try:
88
+ # from glid3xlmodel import GlidModel
89
+ # except:
90
+ # USE_GLID = False
91
+
92
+ # ******** ORIGINAL ***********
93
+ # try:
94
+ # import onnxruntime
95
+ # onnx_available = True
96
+ # onnx_providers = ["CUDAExecutionProvider", "DmlExecutionProvider", "OpenVINOExecutionProvider", 'CPUExecutionProvider']
97
+ # available_providers = onnxruntime.get_available_providers()
98
+ # onnx_providers = [item for item in onnx_providers if item in available_providers]
99
+ # except:
100
+ # onnx_available = False
101
+ # onnx_providers = []
102
+
103
+
104
+ # try:
105
+ # cuda_available = torch.cuda.is_available()
106
+ # except:
107
+ # cuda_available = False
108
+ # finally:
109
+ # if sys.platform == "darwin":
110
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
111
+ # elif cuda_available:
112
+ # device = "cuda"
113
+ # else:
114
+ # device = "cpu"
115
+
116
+ # if device != "cuda":
117
+ # import contextlib
118
+
119
+ # autocast = contextlib.nullcontext
120
+
121
+ # with open("config.yaml", "r") as yaml_in:
122
+ # yaml_object = yaml.safe_load(yaml_in)
123
+ # config_json = json.dumps(yaml_object)
124
+
125
+ # ******** ^ ORIGINAL ^ ***********
126
+
127
+ try:
128
+ cuda_available = torch.cuda.is_available()
129
+ except:
130
+ cuda_available = False
131
+ finally:
132
+ if sys.platform == "darwin":
133
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
134
+ elif cuda_available:
135
+ device = "cuda"
136
+ else:
137
+ device = "cpu"
138
+
139
+ import contextlib
140
+
141
+ autocast = contextlib.nullcontext
142
+
143
+ with open("config.yaml", "r") as yaml_in:
144
+ yaml_object = yaml.safe_load(yaml_in)
145
+ config_json = json.dumps(yaml_object)
146
+
147
+
148
+ # new ^
149
+
150
+ def load_html():
151
+ body, canvaspy = "", ""
152
+ with open("index.html", encoding="utf8") as f:
153
+ body = f.read()
154
+ with open("canvas.py", encoding="utf8") as f:
155
+ canvaspy = f.read()
156
+ body = body.replace("- paths:\n", "")
157
+ body = body.replace(" - ./canvas.py\n", "")
158
+ body = body.replace("from canvas import InfCanvas", canvaspy)
159
+ return body
160
+
161
+
162
+ def test(x):
163
+ x = load_html()
164
+ return f"""<iframe id="sdinfframe" style="width: 100%; height: 780px" name="result" allow="midi; geolocation; microphone; camera;
165
+ display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
166
+ allow-scripts allow-same-origin allow-popups
167
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
168
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
169
+
170
+
171
+ DEBUG_MODE = False
172
+
173
+ try:
174
+ SAMPLING_MODE = Image.Resampling.LANCZOS
175
+ except Exception as e:
176
+ SAMPLING_MODE = Image.LANCZOS
177
+
178
+ try:
179
+ contain_func = ImageOps.contain
180
+ except Exception as e:
181
+
182
+ def contain_func(image, size, method=SAMPLING_MODE):
183
+ # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
184
+ im_ratio = image.width / image.height
185
+ dest_ratio = size[0] / size[1]
186
+ if im_ratio != dest_ratio:
187
+ if im_ratio > dest_ratio:
188
+ new_height = int(image.height / image.width * size[0])
189
+ if new_height != size[1]:
190
+ size = (size[0], new_height)
191
+ else:
192
+ new_width = int(image.width / image.height * size[1])
193
+ if new_width != size[0]:
194
+ size = (new_width, size[1])
195
+ return image.resize(size, resample=method)
196
+
197
+
198
+ import argparse
199
+
200
+ parser = argparse.ArgumentParser(description="stablediffusion-infinity")
201
+ parser.add_argument("--port", type=int, help="listen port", dest="server_port")
202
+ parser.add_argument("--host", type=str, help="host", dest="server_name")
203
+ parser.add_argument("--share", action="store_true", help="share this app?")
204
+ parser.add_argument("--debug", action="store_true", help="debug mode")
205
+ parser.add_argument("--fp32", action="store_true", help="using full precision")
206
+ parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
207
+ parser.add_argument("--encrypt", action="store_true", help="using https?")
208
+ parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
209
+ parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
210
+ parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
211
+ parser.add_argument(
212
+ "--auth", nargs=2, metavar=("username", "password"), help="use username password"
213
+ )
214
+ parser.add_argument(
215
+ "--remote_model",
216
+ type=str,
217
+ help="use a model (e.g. dreambooth fined) from huggingface hub",
218
+ default="",
219
+ )
220
+ parser.add_argument(
221
+ "--local_model", type=str, help="use a model stored on your PC", default=""
222
+ )
223
+
224
+ # original
225
+ # if __name__ == "__main__":
226
+ # args = parser.parse_args()
227
+ # else:
228
+ # args = parser.parse_args(["--debug"])
229
+ # # args = parser.parse_args(["--debug"])
230
+ # if args.auth is not None:
231
+ # args.auth = tuple(args.auth)
232
+
233
+ if __name__ == "__main__" and not RUN_IN_SPACE:
234
+ args = parser.parse_args()
235
+ else:
236
+ args = parser.parse_args()
237
+ # args = parser.parse_args(["--debug"])
238
+ if args.auth is not None:
239
+ args.auth = tuple(args.auth)
240
+
241
+ model = {}
242
+
243
+ # HF function for token
244
+ # def get_token():
245
+ # token = "{access_token}"
246
+ # if os.path.exists(".token"):
247
+ # with open(".token", "r") as f:
248
+ # token = f.read()
249
+ # print("get_token called", token)
250
+ # token = os.environ.get("hftoken", token)
251
+ # return token
252
+
253
+ def get_token():
254
+ token = ""
255
+ if os.path.exists(".token"):
256
+ with open(".token", "r") as f:
257
+ token = f.read()
258
+ token = os.environ.get("hftoken", token)
259
+ return token
260
+
261
+
262
+ def save_token(token):
263
+ with open(".token", "w") as f:
264
+ f.write(token)
265
+
266
+
267
+ def prepare_scheduler(scheduler):
268
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
269
+ new_config = dict(scheduler.config)
270
+ new_config["steps_offset"] = 1
271
+ scheduler._internal_dict = FrozenDict(new_config)
272
+ return scheduler
273
+
274
+
275
+ def my_resize(width, height):
276
+ if width >= 512 and height >= 512:
277
+ return width, height
278
+ if width == height:
279
+ return 512, 512
280
+ smaller = min(width, height)
281
+ larger = max(width, height)
282
+ if larger >= 608:
283
+ return width, height
284
+ factor = 1
285
+ if smaller < 290:
286
+ factor = 2
287
+ elif smaller < 330:
288
+ factor = 1.75
289
+ elif smaller < 384:
290
+ factor = 1.375
291
+ elif smaller < 400:
292
+ factor = 1.25
293
+ elif smaller < 450:
294
+ factor = 1.125
295
+ return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
296
+
297
+
298
+ def load_learned_embed_in_clip(
299
+ learned_embeds_path, text_encoder, tokenizer, token=None
300
+ ):
301
+ # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
302
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
303
+
304
+ # separate token and the embeds
305
+ trained_token = list(loaded_learned_embeds.keys())[0]
306
+ embeds = loaded_learned_embeds[trained_token]
307
+
308
+ # cast to dtype of text_encoder
309
+ dtype = text_encoder.get_input_embeddings().weight.dtype
310
+ embeds.to(dtype)
311
+
312
+ # add the token in tokenizer
313
+ token = token if token is not None else trained_token
314
+ num_added_tokens = tokenizer.add_tokens(token)
315
+ if num_added_tokens == 0:
316
+ raise ValueError(
317
+ f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
318
+ )
319
+
320
+ # resize the token embeddings
321
+ text_encoder.resize_token_embeddings(len(tokenizer))
322
+
323
+ # get the id for the token and assign the embeds
324
+ token_id = tokenizer.convert_tokens_to_ids(token)
325
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
326
+
327
+
328
+ scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None, "PNDM": None}
329
+
330
+
331
+ class StableDiffusionInpaint:
332
+ def __init__(
333
+ self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
334
+ ):
335
+ self.token = token
336
+ original_checkpoint = False
337
+ # if device == "cpu" and onnx_available:
338
+ # from diffusers import OnnxStableDiffusionInpaintPipeline
339
+ # inpaint = OnnxStableDiffusionInpaintPipeline.from_pretrained(
340
+ # model_name,
341
+ # revision="onnx",
342
+ # provider=onnx_providers[0] if onnx_providers else None
343
+ # )
344
+ # else:
345
+ if model_path and os.path.exists(model_path):
346
+ if model_path.endswith(".ckpt"):
347
+ original_checkpoint = True
348
+ elif model_path.endswith(".json"):
349
+ model_name = os.path.dirname(model_path)
350
+ else:
351
+ model_name = model_path
352
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
353
+ # if device == "cuda" and not args.fp32:
354
+ # vae.to(torch.float16)
355
+ vae.to(torch.float16)
356
+ if original_checkpoint:
357
+ print(f"Converting & Loading {model_path}")
358
+ from convert_checkpoint import convert_checkpoint
359
+
360
+ pipe = convert_checkpoint(model_path, inpainting=True)
361
+ if device == "cuda":
362
+ pipe.to(torch.float16)
363
+ inpaint = StableDiffusionInpaintPipeline(
364
+ vae=vae,
365
+ text_encoder=pipe.text_encoder,
366
+ tokenizer=pipe.tokenizer,
367
+ unet=pipe.unet,
368
+ scheduler=pipe.scheduler,
369
+ safety_checker=pipe.safety_checker,
370
+ feature_extractor=pipe.feature_extractor,
371
+ )
372
+ else:
373
+ print(f"Loading {model_name}")
374
+ if device == "cuda":
375
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
376
+ model_name,
377
+ revision="fp16",
378
+ torch_dtype=torch.float16,
379
+ use_auth_token=token,
380
+ vae=vae,
381
+ )
382
+ else:
383
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
384
+ model_name, use_auth_token=token, vae=vae
385
+ )
386
+ # print(f"access_token from HF:", access_token)
387
+ if os.path.exists("./embeddings"):
388
+ print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
389
+ for item in os.listdir("./embeddings"):
390
+ if item.endswith(".bin"):
391
+ load_learned_embed_in_clip(
392
+ os.path.join("./embeddings", item),
393
+ inpaint.text_encoder,
394
+ inpaint.tokenizer,
395
+ )
396
+ inpaint.to(device)
397
+ # if device == "mps":
398
+ # _ = text2img("", num_inference_steps=1)
399
+ scheduler_dict["PLMS"] = inpaint.scheduler
400
+ scheduler_dict["DDIM"] = prepare_scheduler(
401
+ DDIMScheduler(
402
+ beta_start=0.00085,
403
+ beta_end=0.012,
404
+ beta_schedule="scaled_linear",
405
+ clip_sample=False,
406
+ set_alpha_to_one=False,
407
+ )
408
+ )
409
+ scheduler_dict["K-LMS"] = prepare_scheduler(
410
+ LMSDiscreteScheduler(
411
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
412
+ )
413
+ )
414
+ scheduler_dict["PNDM"] = prepare_scheduler(
415
+ PNDMScheduler(
416
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
417
+ skip_prk_steps=True
418
+ )
419
+ )
420
+ scheduler_dict["DPM"] = prepare_scheduler(
421
+ DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config)
422
+ )
423
+ self.safety_checker = inpaint.safety_checker
424
+ save_token(token)
425
+ try:
426
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
427
+ 1024 ** 3
428
+ )
429
+ if total_memory <= 5 or args.lowvram:
430
+ inpaint.enable_attention_slicing()
431
+ inpaint.enable_sequential_cpu_offload()
432
+ except:
433
+ pass
434
+ self.inpaint = inpaint
435
+
436
+     def run(
+         self,
+         image_pil,
+         prompt="",
+         negative_prompt="",
+         guidance_scale=7.5,
+         resize_check=True,
+         enable_safety=True,
+         fill_mode="patchmatch",
+         strength=0.75,
+         step=50,
+         enable_img2img=False,
+         use_seed=False,
+         seed_val=-1,
+         generate_num=1,
+         scheduler="",
+         scheduler_eta=0.0,
+         **kwargs,
+     ):
+         inpaint = self.inpaint
+         selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
+         for item in [inpaint]:
+             item.scheduler = selected_scheduler
+             if enable_safety:
+                 item.safety_checker = self.safety_checker
+             else:
+                 item.safety_checker = lambda images, **kwargs: (images, False)
+ 
+         # for item in [inpaint]:
+         #     item.scheduler = selected_scheduler
+         #     if enable_safety or self.safety_checker is None:
+         #         item.safety_checker = self.safety_checker
+         #     else:
+         #         item.safety_checker = lambda images, **kwargs: (images, False)
+         width, height = image_pil.size
+         sel_buffer = np.array(image_pil)
+         img = sel_buffer[:, :, 0:3]
+         mask = sel_buffer[:, :, -1]
+         nmask = 255 - mask
+         process_width = width
+         process_height = height
+         if resize_check:
+             process_width, process_height = my_resize(width, height)
+         # snap to a multiple of 8 (the original `* 8 // 8` was a no-op)
+         process_width = process_width // 8 * 8
+         process_height = process_height // 8 * 8
+         extra_kwargs = {
+             "num_inference_steps": step,
+             "guidance_scale": guidance_scale,
+             "eta": scheduler_eta,
+         }
+         if USE_NEW_DIFFUSERS:
+             extra_kwargs["negative_prompt"] = negative_prompt
+             extra_kwargs["num_images_per_prompt"] = generate_num
+         if use_seed:
+             generator = torch.Generator(inpaint.device).manual_seed(seed_val)
+             extra_kwargs["generator"] = generator
+         if True:
+             if fill_mode == "g_diffuser":
+                 mask = 255 - mask
+                 mask = mask[:, :, np.newaxis].repeat(3, axis=2)
+                 img, mask = functbl[fill_mode](img, mask)
+             else:
+                 img, mask = functbl[fill_mode](img, mask)
+                 mask = 255 - mask
+                 mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+                 mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+             # extra_kwargs["strength"] = strength
+             inpaint_func = inpaint
+             init_image = Image.fromarray(img)
+             mask_image = Image.fromarray(mask)
+             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
+ 
+             # Cast input image and mask to float32
+             # init_image = init_image.convert("RGB").to(torch.float32)
+             # mask_image = mask_image.convert("L").to(torch.float32)
+             if True:
+                 images = inpaint_func(
+                     prompt=prompt,
+                     image=init_image.resize(
+                         (process_width, process_height), resample=SAMPLING_MODE
+                     ),
+                     mask_image=mask_image.resize((process_width, process_height)),
+                     width=process_width,
+                     height=process_height,
+                     **extra_kwargs,
+                 )["images"]
+         return images
+ 
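+ # Illustrative helper (a sketch for this diff; nothing in the app calls it):
+ # the mask handling in StableDiffusionInpaint.run() above aligns the canvas
+ # mask to the 8x8 latent grid of the VAE. Inverting makes 255 mean "fill
+ # this"; block-wise max-pooling followed by tiling dilates the mask so any
+ # pixel sharing a latent cell with an erased pixel is regenerated. Assumes a
+ # 2-D uint8 mask whose sides are multiples of 8.
+ def snap_mask_to_latent_grid(mask):
+     filled = 255 - mask  # invert: 255 now marks pixels to inpaint
+     pooled = skimage.measure.block_reduce(filled, (8, 8), np.max)  # one value per latent cell
+     return pooled.repeat(8, axis=0).repeat(8, axis=1)  # tile back to pixel resolution
+ 
+ 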
+ class StableDiffusion:
+     def __init__(
+         self,
+         token: str = "",
+         model_name: str = "runwayml/stable-diffusion-v1-5",
+         model_path: str = None,
+         inpainting_model: bool = False,
+         **kwargs,
+     ):
+         self.token = token
+         original_checkpoint = False
+         vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+         # only cast the VAE to fp16 when the rest of the pipeline is fp16,
+         # otherwise CPU/fp32 runs hit a dtype mismatch
+         if device == "cuda" and not args.fp32:
+             vae.to(torch.float16)
+         if model_path and os.path.exists(model_path):
+             if model_path.endswith(".ckpt"):
+                 original_checkpoint = True
+             elif model_path.endswith(".json"):
+                 model_name = os.path.dirname(model_path)
+             else:
+                 model_name = model_path
+         if original_checkpoint:
+             print(f"Converting & Loading {model_path}")
+             from convert_checkpoint import convert_checkpoint
+ 
+             text2img = convert_checkpoint(model_path)
+             if device == "cuda" and not args.fp32:
+                 text2img.to(torch.float16)
+         else:
+             print(f"Loading {model_name}")
+             if device == "cuda" and not args.fp32:
+                 text2img = StableDiffusionPipeline.from_pretrained(
+                     model_name,
+                     revision="fp16",
+                     torch_dtype=torch.float16,
+                     use_auth_token=token,
+                     vae=vae,
+                 )
+             else:
+                 text2img = StableDiffusionPipeline.from_pretrained(
+                     model_name, use_auth_token=token, vae=vae
+                 )
+         if inpainting_model:
+             # can reduce vRAM by reusing models except unet
+             text2img_unet = text2img.unet
+             del text2img.vae
+             del text2img.text_encoder
+             del text2img.tokenizer
+             del text2img.scheduler
+             del text2img.safety_checker
+             del text2img.feature_extractor
+             import gc
+ 
+             gc.collect()
+             if device == "cuda" and not args.fp32:
+                 inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+                     "runwayml/stable-diffusion-inpainting",
+                     revision="fp16",
+                     torch_dtype=torch.float16,
+                     use_auth_token=token,
+                     vae=vae,
+                 ).to(device)
+             else:
+                 inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+                     "runwayml/stable-diffusion-inpainting", use_auth_token=token,
+                 ).to(device)
+             text2img_unet.to(device)
+             del text2img
+             gc.collect()
+             text2img = StableDiffusionPipeline(
+                 vae=inpaint.vae,
+                 text_encoder=inpaint.text_encoder,
+                 tokenizer=inpaint.tokenizer,
+                 unet=text2img_unet,
+                 scheduler=inpaint.scheduler,
+                 safety_checker=inpaint.safety_checker,
+                 feature_extractor=inpaint.feature_extractor,
+             )
+         else:
+             inpaint = StableDiffusionInpaintPipelineLegacy(
+                 vae=text2img.vae,
+                 text_encoder=text2img.text_encoder,
+                 tokenizer=text2img.tokenizer,
+                 unet=text2img.unet,
+                 scheduler=text2img.scheduler,
+                 safety_checker=text2img.safety_checker,
+                 feature_extractor=text2img.feature_extractor,
+             ).to(device)
+         text_encoder = text2img.text_encoder
+         tokenizer = text2img.tokenizer
+         if os.path.exists("./embeddings"):
+             for item in os.listdir("./embeddings"):
+                 if item.endswith(".bin"):
+                     load_learned_embed_in_clip(
+                         os.path.join("./embeddings", item),
+                         text2img.text_encoder,
+                         text2img.tokenizer,
+                     )
+         text2img.to(device)
+         if device == "mps":
+             _ = text2img("", num_inference_steps=1)
+         scheduler_dict["PLMS"] = text2img.scheduler
+         scheduler_dict["DDIM"] = prepare_scheduler(
+             DDIMScheduler(
+                 beta_start=0.00085,
+                 beta_end=0.012,
+                 beta_schedule="scaled_linear",
+                 clip_sample=False,
+                 set_alpha_to_one=False,
+             )
+         )
+         scheduler_dict["K-LMS"] = prepare_scheduler(
+             LMSDiscreteScheduler(
+                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+             )
+         )
+         scheduler_dict["DPM"] = prepare_scheduler(
+             DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
+         )
+         self.safety_checker = text2img.safety_checker
+         img2img = StableDiffusionImg2ImgPipeline(
+             vae=text2img.vae,
+             text_encoder=text2img.text_encoder,
+             tokenizer=text2img.tokenizer,
+             unet=text2img.unet,
+             scheduler=text2img.scheduler,
+             safety_checker=text2img.safety_checker,
+             feature_extractor=text2img.feature_extractor,
+         ).to(device)
+         save_token(token)
+         try:
+             total_memory = torch.cuda.get_device_properties(0).total_memory // (
+                 1024 ** 3
+             )
+             if total_memory <= 5:
+                 inpaint.enable_attention_slicing()
+         except:
+             pass
+         self.text2img = text2img
+         self.inpaint = inpaint
+         self.img2img = img2img
+         self.unified = UnifiedPipeline(
+             vae=text2img.vae,
+             text_encoder=text2img.text_encoder,
+             tokenizer=text2img.tokenizer,
+             unet=text2img.unet,
+             scheduler=text2img.scheduler,
+             safety_checker=text2img.safety_checker,
+             feature_extractor=text2img.feature_extractor,
+         ).to(device)
+         self.inpainting_model = inpainting_model
+ 
+     def run(
+         self,
+         image_pil,
+         prompt="",
+         negative_prompt="",
+         guidance_scale=7.5,
+         resize_check=True,
+         enable_safety=True,
+         fill_mode="patchmatch",
+         strength=0.75,
+         step=50,
+         enable_img2img=False,
+         use_seed=False,
+         seed_val=-1,
+         generate_num=1,
+         scheduler="",
+         scheduler_eta=0.0,
+         **kwargs,
+     ):
+         text2img, inpaint, img2img, unified = (
+             self.text2img,
+             self.inpaint,
+             self.img2img,
+             self.unified,
+         )
+         selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
+         for item in [text2img, inpaint, img2img, unified]:
+             item.scheduler = selected_scheduler
+             if enable_safety:
+                 item.safety_checker = self.safety_checker
+             else:
+                 item.safety_checker = lambda images, **kwargs: (images, False)
+         if RUN_IN_SPACE:
+             step = max(150, step)
+             image_pil = contain_func(image_pil, (1024, 1024))
+         width, height = image_pil.size
+         sel_buffer = np.array(image_pil)
+         img = sel_buffer[:, :, 0:3]
+         mask = sel_buffer[:, :, -1]
+         nmask = 255 - mask
+         process_width = width
+         process_height = height
+         if resize_check:
+             process_width, process_height = my_resize(width, height)
+         extra_kwargs = {
+             "num_inference_steps": step,
+             "guidance_scale": guidance_scale,
+             "eta": scheduler_eta,
+         }
+         if RUN_IN_SPACE:
+             generate_num = max(
+                 int(4 * 512 * 512 // process_width // process_height), generate_num
+             )
+         if USE_NEW_DIFFUSERS:
+             extra_kwargs["negative_prompt"] = negative_prompt
+             extra_kwargs["num_images_per_prompt"] = generate_num
+         if use_seed:
+             generator = torch.Generator(text2img.device).manual_seed(seed_val)
+             extra_kwargs["generator"] = generator
+         if nmask.sum() < 1 and enable_img2img:
+             init_image = Image.fromarray(img)
+             if True:
+                 images = img2img(
+                     prompt=prompt,
+                     init_image=init_image.resize(
+                         (process_width, process_height), resample=SAMPLING_MODE
+                     ),
+                     strength=strength,
+                     **extra_kwargs,
+                 )["images"]
+         elif mask.sum() > 0:
+             if fill_mode == "g_diffuser" and not self.inpainting_model:
+                 mask = 255 - mask
+                 mask = mask[:, :, np.newaxis].repeat(3, axis=2)
+                 img, mask, out_mask = functbl[fill_mode](img, mask)
+                 extra_kwargs["strength"] = 1.0
+                 extra_kwargs["out_mask"] = Image.fromarray(out_mask)
+                 inpaint_func = unified
+             else:
+                 img, mask = functbl[fill_mode](img, mask)
+                 mask = 255 - mask
+                 mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+                 mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+                 extra_kwargs["strength"] = strength
+                 inpaint_func = inpaint
+             init_image = Image.fromarray(img)
+             mask_image = Image.fromarray(mask)
+             # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
+             if True:
+                 input_image = init_image.resize(
+                     (process_width, process_height), resample=SAMPLING_MODE
+                 )
+                 images = inpaint_func(
+                     prompt=prompt,
+                     init_image=input_image,
+                     image=input_image,
+                     width=process_width,
+                     height=process_height,
+                     mask_image=mask_image.resize((process_width, process_height)),
+                     **extra_kwargs,
+                 )["images"]
+         else:
+             if True:
+                 # height/width were swapped here; pass them in the right order
+                 images = text2img(
+                     prompt=prompt,
+                     height=process_height,
+                     width=process_width,
+                     **extra_kwargs,
+                 )["images"]
+         return images
+ 
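+ # For illustration only (assuming a recent diffusers release): the pipeline
+ # constructions above share one set of weights across text2img, img2img,
+ # inpaint, and unified, so only a single copy sits in vRAM. The same sharing
+ # can be written via the pipelines' `components` property, e.g.
+ #
+ #     img2img = StableDiffusionImg2ImgPipeline(**text2img.components).to(device)
+ #
+ # which likewise reuses the UNet/VAE/text encoder already held by text2img
+ # instead of loading fresh weights.
+ 
+ 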
+ def get_model(token="", model_choice="", model_path=""):
+     if "model" not in model:
+         model_name = ""
+         if model_choice == ModelChoice.INPAINTING.value:
+             if len(model_name) < 1:
+                 model_name = "runwayml/stable-diffusion-inpainting"
+             print(f"Using [{model_name}] {model_path}")
+             tmp = StableDiffusionInpaint(
+                 token=token, model_name=model_name, model_path=model_path
+             )
+         elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
+             print(
+                 f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only supports remote models and requires more vRAM"
+             )
+             tmp = StableDiffusion(
+                 token=token,
+                 model_name="runwayml/stable-diffusion-v1-5",
+                 inpainting_model=True,
+             )
+         else:
+             if len(model_name) < 1:
+                 model_name = (
+                     "runwayml/stable-diffusion-v1-5"
+                     if model_choice == ModelChoice.MODEL_1_5.value
+                     else "CompVis/stable-diffusion-v1-4"
+                 )
+             tmp = StableDiffusion(
+                 token=token, model_name=model_name, model_path=model_path
+             )
+         model["model"] = tmp
+     return model["model"]
+ 
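+ # Minimal usage sketch (assumes a valid Hugging Face token; in the app itself
+ # get_model is only called from setup_func and run_outpaint):
+ #
+ #     sd = get_model(token="...", model_choice=ModelChoice.INPAINTING.value)
+ #     images = sd.run(image_pil, prompt="a lighthouse at dusk", step=25)
+ #
+ # The loaded wrapper is cached in the module-level `model` dict, so later
+ # calls with no arguments return the same instance.
+ 
+ 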
+ def run_outpaint(
+     sel_buffer_str,
+     prompt_text,
+     negative_prompt_text,
+     strength,
+     guidance,
+     step,
+     resize_check,
+     fill_mode,
+     enable_safety,
+     use_correction,
+     enable_img2img,
+     use_seed,
+     seed_val,
+     generate_num,
+     scheduler,
+     scheduler_eta,
+     state,
+ ):
+     data = base64.b64decode(str(sel_buffer_str))
+     pil = Image.open(io.BytesIO(data))
+     # if interrogate_mode:
+     #     if "interrogator" not in model:
+     #         model["interrogator"] = Interrogator()
+     #     interrogator = model["interrogator"]
+     #     # possible point to integrate
+     #     img = np.array(pil)[:, :, 0:3]
+     #     mask = np.array(pil)[:, :, -1]
+     #     x, y = np.nonzero(mask)
+     #     if len(x) > 0:
+     #         x0, x1 = x.min(), x.max() + 1
+     #         y0, y1 = y.min(), y.max() + 1
+     #         img = img[x0:x1, y0:y1, :]
+     #         pil = Image.fromarray(img)
+     #     interrogate_ret = interrogator.interrogate(pil)
+     #     return (
+     #         gr.update(value=",".join([sel_buffer_str])),
+     #         gr.update(label="Prompt", value=interrogate_ret),
+     #         state,
+     #     )
+     width, height = pil.size
+     sel_buffer = np.array(pil)
+     cur_model = get_model()
+     images = cur_model.run(
+         image_pil=pil,
+         prompt=prompt_text,
+         negative_prompt=negative_prompt_text,
+         guidance_scale=guidance,
+         strength=strength,
+         step=step,
+         resize_check=resize_check,
+         fill_mode=fill_mode,
+         enable_safety=enable_safety,
+         use_seed=use_seed,
+         seed_val=seed_val,
+         generate_num=generate_num,
+         scheduler=scheduler,
+         scheduler_eta=scheduler_eta,
+         enable_img2img=enable_img2img,
+         width=width,
+         height=height,
+     )
+     base64_str_lst = []
+     if enable_img2img:
+         use_correction = "border_mode"
+     for image in images:
+         image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
+         resized_img = image.resize((width, height), resample=SAMPLING_MODE)
+         out = sel_buffer.copy()
+         out[:, :, 0:3] = np.array(resized_img)
+         out[:, :, -1] = 255
+         out_pil = Image.fromarray(out)
+         out_buffer = io.BytesIO()
+         out_pil.save(out_buffer, format="PNG")
+         out_buffer.seek(0)
+         base64_bytes = base64.b64encode(out_buffer.read())
+         base64_str = base64_bytes.decode("ascii")
+         base64_str_lst.append(base64_str)
+     return (
+         gr.update(label=str(state + 1), value=",".join(base64_str_lst)),
+         gr.update(label="Prompt"),
+         state + 1,
+     )
+ 
+ 
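+ # Sketch of the wire format used by run_outpaint (helpers for illustration;
+ # the app inlines this logic): the canvas frontend and the backend exchange
+ # RGBA canvases as base64-encoded PNG strings, PNG because it preserves the
+ # alpha channel that doubles as the inpainting mask.
+ def pil_to_base64_png(img):
+     buf = io.BytesIO()
+     img.save(buf, format="PNG")  # lossless, keeps alpha
+     return base64.b64encode(buf.getvalue()).decode("ascii")
+ 
+ 
+ def base64_png_to_pil(b64_str):
+     return Image.open(io.BytesIO(base64.b64decode(b64_str)))
+ 
+ 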
+ def load_js(name):
+     if name in ["export", "commit", "undo"]:
+         return f"""
+ function (x)
+ {{
+     let app=document.querySelector("gradio-app");
+     app=app.shadowRoot??app;
+     let frame=app.querySelector("#sdinfframe").contentWindow.document;
+     let button=frame.querySelector("#{name}");
+     button.click();
+     return x;
+ }}
+ """
+     ret = ""
+     with open(f"./js/{name}.js", "r") as f:
+         ret = f.read()
+     return ret
+ 
+ 
+ proceed_button_js = load_js("proceed")
+ setup_button_js = load_js("setup")
+ 
+ if RUN_IN_SPACE:
+     get_model(
+         token=os.environ.get("hftoken", ""),
+         model_choice=ModelChoice.INPAINTING_IMG2IMG.value,
+     )
+ 
+ blocks = gr.Blocks(
+     title="StableDiffusion-Infinity",
+     css="""
+ .tabs {
+     margin-top: 0rem;
+     margin-bottom: 0rem;
+ }
+ #markdown {
+     min-height: 0rem;
+ }
+ .contain {
+     display: flex;
+     align-items: center;
+ }
+ """,
+     theme=gr.themes.Soft(),
+ )
+ model_path_input_val = ""
+ with blocks as demo:
+     # # title
+     # title = gr.Markdown(
+     #     """
+     #     stanley capstone
+     #     """,
+     #     elem_id="markdown",
+     # )
+     # # github logo
+     # github_logo = gr.HTML(
+     #     """
+     #     <a href="https://github.com/stanleywalker1/capstone-studio-2">
+     #     <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z" fill="white"/></svg>
+     #     </a>
+     #     """
+     # )
+     # frame
+     frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
+     # setup
+     setup_button = gr.Button("Click to Start", variant="primary")
+ 
+     if not RUN_IN_SPACE:
+         model_choices_lst = [item.value for item in ModelChoice]
+         if args.local_model:
+             model_path_input_val = args.local_model
+             # model_choices_lst.insert(0, "local_model")
+         elif args.remote_model:
+             model_path_input_val = args.remote_model
+             # model_choices_lst.insert(0, "remote_model")
+ 
+     sd_prompt = gr.Textbox(
+         label="Prompt", placeholder="input your prompt here!", lines=2
+     )
+     with gr.Accordion("machine learning tools", open=False):
+         with gr.Row(elem_id="setup_row"):
+             with gr.Column(scale=4, min_width=350):
+                 token = gr.Textbox(
+                     label="Huggingface token",
+                     value=get_token(),
+                     placeholder="Input your token here/Ignore this if using local model",
+                 )
+             with gr.Column(scale=3, min_width=320):
+                 model_selection = gr.Radio(
+                     label="Choose a model type here",
+                     choices=model_choices_lst,
+                     value=ModelChoice.INPAINTING.value,
+                     # value=ModelChoice.INPAINTING.value if onnx_available else ModelChoice.INPAINTING2.value,
+                 )
+             with gr.Column(scale=1, min_width=100):
+                 canvas_width = gr.Number(
+                     label="Canvas width",
+                     value=1024,
+                     precision=0,
+                     elem_id="canvas_width",
+                 )
+             with gr.Column(scale=1, min_width=100):
+                 canvas_height = gr.Number(
+                     label="Canvas height",
+                     value=700,
+                     precision=0,
+                     elem_id="canvas_height",
+                 )
+             with gr.Column(scale=1, min_width=100):
+                 selection_size = gr.Number(
+                     label="Selection box size",
+                     value=256,
+                     precision=0,
+                     elem_id="selection_size",
+                 )
+             with gr.Column(scale=3, min_width=270):
+                 init_mode = gr.Dropdown(
+                     label="Init Mode",
+                     choices=[
+                         "patchmatch",
+                         "edge_pad",
+                         "cv2_ns",
+                         "cv2_telea",
+                         "perlin",
+                         "gaussian",
+                         "g_diffuser",
+                     ],
+                     value="patchmatch",
+                     type="value",
+                 )
+                 postprocess_check = gr.Radio(
+                     label="Photometric Correction Mode",
+                     choices=["disabled", "mask_mode", "border_mode"],
+                     value="disabled",
+                     type="value",
+                 )
+             # canvas control
+ 
+             with gr.Column(scale=3, min_width=270):
+                 sd_negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     placeholder="input your negative prompt here!",
+                     lines=2,
+                 )
+             with gr.Column(scale=2, min_width=150):
+                 with gr.Group():
+                     with gr.Row():
+                         sd_generate_num = gr.Number(
+                             label="Sample number", value=1, precision=0
+                         )
+                         sd_strength = gr.Slider(
+                             label="Strength",
+                             minimum=0.0,
+                             maximum=1.0,
+                             value=1.0,
+                             step=0.01,
+                         )
+                     with gr.Row():
+                         sd_scheduler = gr.Dropdown(
+                             list(scheduler_dict.keys()), label="Scheduler", value="DPM"
+                         )
+                         sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
+             with gr.Column(scale=1, min_width=80):
+                 sd_step = gr.Number(label="Step", value=25, precision=0)
+                 sd_guidance = gr.Number(label="Guidance", value=7.5)
+ 
+         model_path_input = gr.Textbox(
+             value=model_path_input_val,
+             label="Custom Model Path (You have to select a correct model type for your local model)",
+             placeholder="Ignore this if you are not using Docker",
+             elem_id="model_path_input",
+         )
+ 
+     proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
+     xss_js = load_js("xss").replace("\n", " ")
+     xss_html = gr.HTML(
+         value=f"""
+ <img src='hts://not.exist' onerror='{xss_js}'>""",
+         visible=False,
+     )
+     xss_keyboard_js = load_js("keyboard").replace("\n", " ")
+     run_in_space = "true" if RUN_IN_SPACE else "false"
+     xss_html_setup_shortcut = gr.HTML(
+         value=f"""
+ <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
+         visible=False,
+     )
+     # sd pipeline parameters
+     sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
+     sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
+     safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
+     interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
+     upload_button = gr.Button(
+         "Before uploading the image you need to setup the canvas first", visible=False
+     )
+     sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
+     sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
+     model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
+     model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
+     upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
+     model_output_state = gr.State(value=0)
+     upload_output_state = gr.State(value=0)
+     cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
+     if not RUN_IN_SPACE:
+ 
+         def setup_func(token_val, width, height, size, model_choice, model_path):
+             try:
+                 get_model(token_val, model_choice, model_path=model_path)
+             except Exception as e:
+                 print(e)
+                 return {token: gr.update(value=str(e))}
+             if model_choice in [
+                 ModelChoice.INPAINTING.value,
+                 ModelChoice.INPAINTING_IMG2IMG.value,
+                 ModelChoice.INPAINTING2.value,
+             ]:
+                 init_val = "cv2_ns"
+             else:
+                 init_val = "patchmatch"
+             return {
+                 token: gr.update(visible=False),
+                 canvas_width: gr.update(visible=False),
+                 canvas_height: gr.update(visible=False),
+                 selection_size: gr.update(visible=False),
+                 setup_button: gr.update(visible=False),
+                 frame: gr.update(visible=True),
+                 upload_button: gr.update(value="Upload Image"),
+                 model_selection: gr.update(visible=False),
+                 model_path_input: gr.update(visible=False),
+                 init_mode: gr.update(value=init_val),
+             }
+ 
+         setup_button.click(
+             fn=setup_func,
+             inputs=[
+                 token,
+                 canvas_width,
+                 canvas_height,
+                 selection_size,
+                 model_selection,
+                 model_path_input,
+             ],
+             outputs=[
+                 token,
+                 canvas_width,
+                 canvas_height,
+                 selection_size,
+                 setup_button,
+                 frame,
+                 upload_button,
+                 model_selection,
+                 model_path_input,
+                 init_mode,
+             ],
+             _js=setup_button_js,
+         )
+ 
+     proceed_event = proceed_button.click(
+         fn=run_outpaint,
+         inputs=[
+             model_input,
+             sd_prompt,
+             sd_negative_prompt,
+             sd_strength,
+             sd_guidance,
+             sd_step,
+             sd_resize,
+             init_mode,
+             safety_check,
+             postprocess_check,
+             sd_img2img,
+             sd_use_seed,
+             sd_seed_val,
+             sd_generate_num,
+             sd_scheduler,
+             sd_scheduler_eta,
+             model_output_state,
+         ],
+         outputs=[model_output, sd_prompt, model_output_state],
+         _js=proceed_button_js,
+     )
+     # cancel button can also remove error overlay
+     if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
+         cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
+ 
+ 
+ launch_extra_kwargs = {
+     "show_error": True,
+     # "favicon_path": ""
+ }
+ launch_kwargs = vars(args)
+ launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
+ launch_kwargs.pop("remote_model", None)
+ launch_kwargs.pop("local_model", None)
+ launch_kwargs.pop("fp32", None)
+ launch_kwargs.pop("lowvram", None)
+ launch_kwargs.update(launch_extra_kwargs)
+ try:
+     import google.colab
+ 
+     launch_kwargs["debug"] = True
+ except:
+     pass
+ 
+ if RUN_IN_SPACE:
+     demo.launch()
+ elif args.debug:
+     launch_kwargs["server_name"] = "0.0.0.0"
+     demo.queue().launch(**launch_kwargs)
+     # demo.queue().launch(share=True)
+ else:
+     demo.queue().launch(**launch_kwargs)
+     # demo.queue().launch(share=True)