Create inpainting.py

utils/inpainting.py · ADDED (+177 lines)
import sys
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat
from imwatermark import WatermarkEncoder
from pathlib import Path

from .ddim import DDIMSampler
from .util import instantiate_from_config


# Inference only: gradients are never needed.
torch.set_grad_enabled(False)

def put_watermark(img, wm_encoder=None):
    # Embed an invisible watermark: RGB -> BGR for OpenCV, encode, back to RGB.
    if wm_encoder is not None:
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = wm_encoder.encode(img, 'dwtDct')
        img = Image.fromarray(img[:, :, ::-1])
    return img

def initialize_model(config, ckpt):
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)

    model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    return sampler

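# Illustrative call (the paths below are hypothetical, not part of this commit):
#   sampler = initialize_model("configs/stable-diffusion/v2-inpainting-inference.yaml",
#                              "checkpoints/512-inpainting-ema.ckpt")
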
def make_batch_sd(
        image,
        mask,
        txt,
        device,
        num_samples=1):
    # Image to NCHW float tensor in [-1, 1].
    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # Binarize the mask: 1 = region to inpaint, 0 = keep.
    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Zero out the region to be inpainted.
    masked_image = image * (mask < 0.5)

    batch = {
        "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
        "txt": num_samples * [txt],
        "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
        "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
    }
    return batch

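# Shape check (assuming a 512x512 input and num_samples=2): "image" and
# "masked_image" are (2, 3, 512, 512) tensors, "mask" is (2, 1, 512, 512),
# and "txt" is a list of two copies of the prompt.
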
@torch.no_grad()
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = sampler.model

    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "SDV2"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    # Seeded starting noise in latent space (4 channels at 1/8 resolution).
    prng = np.random.RandomState(seed)
    start_code = prng.randn(num_samples, 4, h // 8, w // 8)
    start_code = torch.from_numpy(start_code).to(
        device=device, dtype=torch.float32)

    with torch.no_grad(), \
            torch.autocast("cuda"):
        batch = make_batch_sd(image, mask, txt=prompt,
                              device=device, num_samples=num_samples)

        # Text conditioning for cross-attention.
        c = model.cond_stage_model.encode(batch["txt"])

        # Concat conditioning: downsampled mask plus the VAE-encoded masked image.
        c_cat = list()
        for ck in model.concat_keys:
            cc = batch[ck].float()
            if ck != model.masked_image_key:
                bchw = [num_samples, 4, h // 8, w // 8]
                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
            else:
                cc = model.get_first_stage_encoding(
                    model.encode_first_stage(cc))
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)

        # Conditional inputs.
        cond = {"c_concat": [c_cat], "c_crossattn": [c]}

        # Unconditional counterpart (empty prompt) shares the concat conditioning.
        uc_cross = model.get_unconditional_conditioning(num_samples, "")
        uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

        shape = [model.channels, h // 8, w // 8]
        samples_cfg, intermediates = sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=1.0,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc_full,
            x_T=start_code,
        )
        x_samples_ddim = model.decode_first_stage(samples_cfg)

        # Map decoded samples from [-1, 1] back to [0, 1], then to uint8 HWC.
        result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                             min=0.0, max=1.0)

        result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
    return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]

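# Note: unconditional_guidance_scale drives classifier-free guidance. Inside
# sampler.sample (not in this file) each DDIM step conceptually combines the
# two predictions as eps = eps_uncond + scale * (eps_cond - eps_uncond).
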
def pad_image(input_image):
    # Pad right/bottom with edge values so each dimension is a multiple of 64
    # (minimum 128), as the model expects.
    pad_w, pad_h = np.max(((2, 2), np.ceil(
        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
    im_padded = Image.fromarray(
        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
    return im_padded

def crop_image(input_image):
    # Crop to the largest size whose dimensions are multiples of 64.
    crop_w, crop_h = np.floor(np.array(input_image.size) / 64).astype(int) * 64
    im_cropped = Image.fromarray(np.array(input_image)[:crop_h, :crop_w])
    return im_cropped

# sampler = initialize_model(sys.argv[1], sys.argv[2])
@torch.no_grad()
def predict(model, input_image, prompt, ddim_steps, num_samples, scale, seed):
    """Run text-guided inpainting on a Gradio image-plus-mask input.

    Args:
        input_image (dict):
            - image: PIL.Image. Input image.
            - mask: PIL.Image. Mask image.
        prompt (str): text prompt guiding the inpainting.
        ddim_steps (int): number of DDIM sampling steps, typically 45.
        num_samples (int): number of images to generate, typically 4.
        scale (float): classifier-free guidance scale, typically 10.0.
        seed (int): random seed, e.g. 1529160519.
    """
    init_image = input_image["image"].convert("RGB")
    init_mask = input_image["mask"].convert("RGB")
    image = pad_image(init_image)  # pad to an integer multiple of 64
    mask = pad_image(init_mask)  # pad to an integer multiple of 64
    width, height = image.size
    print("Inpainting...", width, height)

    result = inpaint(
        sampler=model,
        image=image,
        mask=mask,
        prompt=prompt,
        seed=seed,
        scale=scale,
        ddim_steps=ddim_steps,
        num_samples=num_samples,
        h=height, w=width
    )

    return result
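
The module imports gradio but never builds an interface, so wiring is left to the caller. Below is a minimal driver sketch, not part of this commit: the config and checkpoint paths are hypothetical placeholders, and it assumes Gradio 3.x, where gr.Image(tool="sketch", type="pil") returns the {"image", "mask"} dict that predict expects.

# Hypothetical driver script (assumes this repo layout and Gradio 3.x).
import gradio as gr
from utils.inpainting import initialize_model, predict

# Hypothetical paths: substitute the real SD2 inpainting config and checkpoint.
sampler = initialize_model(
    "configs/stable-diffusion/v2-inpainting-inference.yaml",
    "checkpoints/512-inpainting-ema.ckpt",
)

demo = gr.Interface(
    # Fix the sampling knobs at the "typical" values from predict's docstring.
    fn=lambda img, prompt: predict(sampler, img, prompt, ddim_steps=45,
                                   num_samples=4, scale=10.0, seed=1529160519),
    inputs=[gr.Image(source="upload", tool="sketch", type="pil"),
            gr.Textbox(label="Prompt")],
    outputs=gr.Gallery(),
)
demo.launch()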