Spaces:
Runtime error
Runtime error
update code
Browse files- .gitattributes +34 -0
- app.py +393 -338
- assets/GIF.gif +0 -0
- assets/Teaser_Small.png +3 -0
- assets/examples/Lancia.webp +3 -0
- assets/examples/car.jpeg +3 -0
- assets/examples/car1.webp +3 -0
- assets/examples/carpet2.webp +3 -0
- assets/examples/chair.jpeg +3 -0
- assets/examples/chair1.jpeg +3 -0
- assets/examples/dog.jpeg +3 -0
- assets/examples/door.jpeg +3 -0
- assets/examples/door2.jpeg +3 -0
- assets/examples/grasslands-national-park.jpeg +3 -0
- assets/examples/house.jpeg +3 -0
- assets/examples/house2.jpeg +3 -0
- assets/examples/ian.jpeg +3 -0
- assets/examples/park.webp +3 -0
- assets/examples/ran.webp +3 -0
- assets/hulk.jpeg +0 -0
- assets/ironman.webp +0 -0
- assets/lava.jpg +0 -0
- assets/ski.jpg +0 -0
- assets/truck.png +0 -0
- assets/truck2.jpeg +0 -0
- cldm/appearance_networks.py +75 -0
- cldm/cldm.py +115 -118
- cldm/controlnet.py +306 -0
- cldm/ddim_hacked.py +2 -3
- cldm/logger.py +10 -10
- configs/{sap_fixed_hintnet_v15.yaml → pair_diff.yaml} +20 -7
- ldm/ldm/util.py +197 -0
- ldm/models/diffusion/ddim.py +15 -4
- ldm/modules/attention.py +61 -15
- ldm/modules/diffusionmodules/openaimodel.py +16 -4
- ldm/modules/diffusionmodules/util.py +2 -1
- ldm/modules/encoders/modules.py +3 -3
- pair_diff_demo.py +516 -0
- requirements.txt +2 -1
.gitattributes
CHANGED
@@ -32,3 +32,37 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
assets/examples/ian.jpeg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/examples/resized_anm_38.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/examples/anm_8.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/examples/house.jpeg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/examples/door2.jpeg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/examples/door.jpeg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/examples/frn_38.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/examples/park.webp filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/examples/car1.webp filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/examples/car.jpeg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/examples/house2.jpeg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/examples/Lancia.webp filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/examples/obj_11.jpg filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/examples/resized_anm_8.jpg filter=lfs diff=lfs merge=lfs -text
|
49 |
+
assets/examples/resized_frn_38.jpg filter=lfs diff=lfs merge=lfs -text
|
50 |
+
assets/examples/resized_obj_11.jpg filter=lfs diff=lfs merge=lfs -text
|
51 |
+
assets/examples/dog.jpeg filter=lfs diff=lfs merge=lfs -text
|
52 |
+
assets/examples/grasslands-national-park.jpeg filter=lfs diff=lfs merge=lfs -text
|
53 |
+
assets/examples/resized_obj_38.jpg filter=lfs diff=lfs merge=lfs -text
|
54 |
+
assets/examples/chair1.jpeg filter=lfs diff=lfs merge=lfs -text
|
55 |
+
assets/examples/chair.jpeg filter=lfs diff=lfs merge=lfs -text
|
56 |
+
assets/examples/obj_38.jpg filter=lfs diff=lfs merge=lfs -text
|
57 |
+
assets/examples/ran.webp filter=lfs diff=lfs merge=lfs -text
|
58 |
+
assets/examples/anm_38.jpg filter=lfs diff=lfs merge=lfs -text
|
59 |
+
assets/examples/carpet2.webp filter=lfs diff=lfs merge=lfs -text
|
60 |
+
assets/ironman.webp filter=lfs diff=lfs merge=lfs -text
|
61 |
+
assets/truck2.jpeg filter=lfs diff=lfs merge=lfs -text
|
62 |
+
assets/truck.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
assets/ski.jpg filter=lfs diff=lfs merge=lfs -text
|
64 |
+
assets/Teaser_Small.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
assets/examples filter=lfs diff=lfs merge=lfs -text
|
66 |
+
assets/GIF.gif filter=lfs diff=lfs merge=lfs -text
|
67 |
+
assets/hulk.jpeg filter=lfs diff=lfs merge=lfs -text
|
68 |
+
assets/lava.jpg filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -1,429 +1,484 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
import einops
|
4 |
import gradio as gr
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import random
|
8 |
-
import os
|
9 |
-
import subprocess
|
10 |
-
import shlex
|
11 |
-
|
12 |
-
from huggingface_hub import hf_hub_url, hf_hub_download
|
13 |
-
from share import *
|
14 |
-
|
15 |
-
from pytorch_lightning import seed_everything
|
16 |
-
from annotator.util import resize_image, HWC3
|
17 |
-
from annotator.OneFormer import OneformerSegmenter
|
18 |
-
from cldm.model import create_model, load_state_dict
|
19 |
-
from cldm.ddim_hacked import DDIMSamplerSpaCFG
|
20 |
-
from ldm.models.autoencoder import DiagonalGaussianDistribution
|
21 |
-
|
22 |
-
urls = {
|
23 |
-
'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
|
24 |
-
'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['pair_diffusion_epoch62.ckpt']
|
25 |
-
}
|
26 |
-
|
27 |
-
WTS_DICT = {
|
28 |
-
|
29 |
-
}
|
30 |
-
|
31 |
-
if os.path.exists('checkpoints') == False:
|
32 |
-
os.mkdir('checkpoints')
|
33 |
-
for repo in urls:
|
34 |
-
files = urls[repo]
|
35 |
-
for file in files:
|
36 |
-
url = hf_hub_url(repo, file)
|
37 |
-
name_ckp = url.split('/')[-1]
|
38 |
-
WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file, token=os.environ.get("ACCESS_TOKEN"))
|
39 |
-
|
40 |
-
print(WTS_DICT)
|
41 |
-
apply_segmentor = OneformerSegmenter(WTS_DICT['shi-labs/oneformer_coco_swin_large'])
|
42 |
-
|
43 |
-
model = create_model('./configs/sap_fixed_hintnet_v15.yaml').cpu()
|
44 |
-
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
|
45 |
-
model = model.cuda()
|
46 |
-
ddim_sampler = DDIMSamplerSpaCFG(model)
|
47 |
-
_COLORS = []
|
48 |
-
save_memory = False
|
49 |
-
|
50 |
-
def gen_color():
|
51 |
-
color = tuple(np.round(np.random.choice(range(256), size=3), 3))
|
52 |
-
if color not in _COLORS and np.mean(color) != 0.0:
|
53 |
-
_COLORS.append(color)
|
54 |
-
else:
|
55 |
-
gen_color()
|
56 |
-
|
57 |
-
|
58 |
-
for _ in range(300):
|
59 |
-
gen_color()
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
def __init__(self, edit_operation):
|
64 |
-
self.input_img = None
|
65 |
-
self.input_pmask = None
|
66 |
-
self.input_segmask = None
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
self.ref_segmask = None
|
71 |
-
|
72 |
-
self.H = None
|
73 |
-
self.W = None
|
74 |
-
self.baseoutput = None
|
75 |
-
self.kernel = np.ones((5, 5), np.uint8)
|
76 |
-
self.edit_operation = edit_operation
|
77 |
-
|
78 |
-
def init_input_canvas(self, img):
|
79 |
-
img = HWC3(img)
|
80 |
-
img = resize_image(img, 512)
|
81 |
-
detected_mask = apply_segmentor(img, 'panoptic')[0]
|
82 |
-
detected_seg = apply_segmentor(img, 'semantic')
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
self.input_segmask = detected_seg
|
87 |
-
self.H = img.shape[0]
|
88 |
-
self.W = img.shape[1]
|
89 |
-
|
90 |
-
detected_mask = detected_mask.cpu().numpy()
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
for i in uni:
|
95 |
-
color_mask[detected_mask == i] = _COLORS[i]
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
return self.baseoutput
|
100 |
-
|
101 |
-
def init_ref_canvas(self, img):
|
102 |
-
img = HWC3(img)
|
103 |
-
img = resize_image(img, 512)
|
104 |
-
detected_mask = apply_segmentor(img, 'panoptic')[0]
|
105 |
-
detected_seg = apply_segmentor(img, 'semantic')
|
106 |
-
|
107 |
-
self.ref_img = img
|
108 |
-
self.ref_pmask = detected_mask
|
109 |
-
self.ref_segmask = detected_seg
|
110 |
-
|
111 |
-
detected_mask = detected_mask.cpu().numpy()
|
112 |
-
|
113 |
-
uni = np.unique(detected_mask)
|
114 |
-
color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3))
|
115 |
-
for i in uni:
|
116 |
-
color_mask[detected_mask == i] = _COLORS[i]
|
117 |
-
|
118 |
-
output = color_mask*0.8 + img * 0.2
|
119 |
-
self.baseoutput = output.astype(np.uint8)
|
120 |
-
return self.baseoutput
|
121 |
-
|
122 |
-
def _process_mask(self, mask, panoptic_mask, segmask):
|
123 |
-
panoptic_mask_ = panoptic_mask + 1
|
124 |
-
mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
|
125 |
-
mask_ = torch.tensor(mask_)
|
126 |
-
maski = torch.zeros_like(mask_).cuda()
|
127 |
-
maski[mask_ > 127] = 1
|
128 |
-
mask = maski * panoptic_mask_
|
129 |
-
unique_ids, counts = torch.unique(mask, return_counts=True)
|
130 |
-
mask_id = unique_ids[torch.argmax(counts[1:]) + 1]
|
131 |
-
final_mask = torch.zeros(mask.shape).cuda()
|
132 |
-
final_mask[panoptic_mask_ == mask_id] = 1
|
133 |
-
|
134 |
-
obj_class = maski * (segmask + 1)
|
135 |
-
unique_ids, counts = torch.unique(obj_class, return_counts=True)
|
136 |
-
obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
|
137 |
-
return final_mask, obj_class
|
138 |
-
|
139 |
-
|
140 |
-
def _edit_app(self, input_mask, ref_mask, whole_ref):
|
141 |
-
input_pmask = self.input_pmask
|
142 |
-
input_segmask = self.input_segmask
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
else:
|
147 |
-
reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
input_pmask[edit_mask == 1] = ma + 1
|
152 |
-
return reference_mask, input_pmask, input_segmask, edit_mask, ma
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
input_img = (self.input_img/127.5 - 1)
|
157 |
-
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
158 |
|
159 |
-
|
160 |
-
|
161 |
|
162 |
-
|
|
|
163 |
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
166 |
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
169 |
|
170 |
-
if mean_feat_ref.shape[1] > 1:
|
171 |
-
mean_feat_inpt[:, ma + 1] = (1 - inter) * mean_feat_inpt[:, ma + 1] + inter*mean_feat_ref[:, 1]
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
|
|
179 |
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
whole_ref=whole_ref, inter=inter)
|
187 |
-
|
188 |
-
null_structure = torch.zeros(structure.shape).cuda() - 1
|
189 |
-
null_appearance = torch.zeros(appearance.shape).cuda()
|
190 |
-
|
191 |
-
null_control = torch.cat([null_structure, null_appearance], dim=1)
|
192 |
-
structure_control = torch.cat([structure, null_appearance], dim=1)
|
193 |
-
full_control = torch.cat([structure, appearance], dim=1)
|
194 |
-
|
195 |
-
null_control = torch.cat([null_control for _ in range(num_samples)], dim=0)
|
196 |
-
structure_control = torch.cat([structure_control for _ in range(num_samples)], dim=0)
|
197 |
-
full_control = torch.cat([full_control for _ in range(num_samples)], dim=0)
|
198 |
-
|
199 |
-
#Masking for local edit
|
200 |
-
if not masking:
|
201 |
-
mask, x0 = None, None
|
202 |
-
else:
|
203 |
-
x0 = model.encode_first_stage(img)
|
204 |
-
x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
|
205 |
-
x0 = x0 * model.scale_factor
|
206 |
-
mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
|
207 |
-
mask = torch.nn.functional.interpolate(mask, x0.shape[2:]).float()
|
208 |
-
|
209 |
-
if seed == -1:
|
210 |
-
seed = random.randint(0, 65535)
|
211 |
-
seed_everything(seed)
|
212 |
|
213 |
-
|
214 |
-
print(scale)
|
215 |
-
if save_memory:
|
216 |
-
model.low_vram_shift(is_diffusing=False)
|
217 |
-
# uc_cross = model.get_unconditional_conditioning(num_samples)
|
218 |
-
uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
|
219 |
-
cond = {"c_concat": [full_control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
220 |
-
un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
|
221 |
-
un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
|
222 |
-
un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
|
223 |
|
224 |
-
|
|
|
225 |
|
226 |
-
|
227 |
-
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
-
|
239 |
-
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
|
240 |
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
243 |
|
244 |
|
245 |
-
def
|
246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
-
|
252 |
-
|
|
|
253 |
|
|
|
|
|
|
|
254 |
|
|
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
.
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
.
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
.
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
|
|
284 |
with gr.Row():
|
285 |
-
gr.Markdown("##
|
286 |
with gr.Row():
|
287 |
gr.HTML(
|
288 |
"""
|
289 |
-
<div
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
<div class="image">
|
301 |
-
<img src="file/assets/GIF.gif" width="400"">
|
302 |
-
</div>
|
303 |
-
</div>
|
304 |
-
""")
|
305 |
with gr.Column():
|
306 |
with gr.Row():
|
307 |
img_edit = gr.State(ImageComp('edit_app'))
|
308 |
with gr.Column():
|
309 |
-
btn1 = gr.Button("Input Image")
|
310 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
311 |
with gr.Column():
|
312 |
-
|
313 |
-
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
314 |
-
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)
|
315 |
-
|
316 |
-
# with gr.Row():
|
317 |
-
with gr.Column():
|
318 |
-
btn3 = gr.Button("Reference Image")
|
319 |
-
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
320 |
-
with gr.Column():
|
321 |
-
btn4 = gr.Button("Select Reference Object")
|
322 |
-
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
|
323 |
|
324 |
-
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)
|
325 |
-
|
326 |
with gr.Row():
|
327 |
-
prompt = gr.Textbox(label="Prompt", value='
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
with gr.Row():
|
332 |
run_button = gr.Button(label="Run")
|
|
|
333 |
|
334 |
with gr.Row():
|
335 |
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
336 |
|
337 |
with gr.Accordion("Advanced options", open=False):
|
338 |
-
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=
|
|
|
339 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
340 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
341 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
342 |
-
|
343 |
-
|
344 |
-
|
|
|
345 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
346 |
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
347 |
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
348 |
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
349 |
n_prompt = gr.Textbox(label="Negative Prompt",
|
350 |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
351 |
-
|
|
|
|
|
|
|
352 |
with gr.Column():
|
353 |
gr.Examples(
|
354 |
-
examples=[['
|
355 |
-
['
|
356 |
-
['
|
357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
outputs=None,
|
359 |
fn=None,
|
360 |
cache_examples=False,
|
361 |
)
|
362 |
-
ips = [input_mask,
|
363 |
-
scale_s, scale_f, scale_t, seed, eta,
|
|
|
|
|
364 |
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
|
|
365 |
|
366 |
|
367 |
-
|
368 |
-
def create_struct_demo():
|
369 |
with gr.Row():
|
370 |
-
gr.Markdown("##
|
371 |
-
|
372 |
-
def create_both_demo():
|
373 |
with gr.Row():
|
374 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
|
377 |
|
378 |
-
block = gr.Blocks(css=css).queue()
|
379 |
with block:
|
380 |
gr.HTML(
|
381 |
"""
|
382 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
383 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
384 |
-
PAIR Diffusion
|
385 |
</h1>
|
386 |
-
<
|
387 |
-
<a href="https://vidit98.github.io/" style="color:blue;">Vidit Goel</a><sup>1*</sup>,
|
388 |
-
<a href="https://helia95.github.io/" style="color:blue;">Elia Peruzzo</a><sup>1,2*</sup>,
|
389 |
-
<a href="https://yifanjiang19.github.io/" style="color:blue;">Yifan Jiang</a><sup>3</sup>,
|
390 |
-
<a href="https://ir1d.github.io/" style="color:blue;">Dejia Xu</a><sup>3</sup>,
|
391 |
-
<a href="http://disi.unitn.it/~sebe/" style="color:blue;">Nicu Sebe</a><sup>2</sup>, <br>
|
392 |
-
<a href=" https://people.eecs.berkeley.edu/~trevor/" style="color:blue;">Trevor Darrell</a><sup>4</sup>,
|
393 |
-
<a href="https://vita-group.github.io/" style="color:blue;">Zhangyang Wang</a><sup>1,3</sup>
|
394 |
-
and <a href="https://www.humphreyshi.com/home" style="color:blue;">Humphrey Shi</a> <sup>1,5,6</sup> <br>
|
395 |
-
[<a href="https://arxiv.org/abs/2303.17546" style="color:red;">arXiv</a>]
|
396 |
-
[<a href="https://github.com/Picsart-AI-Research/PAIR-Diffusion" style="color:red;">GitHub</a>]
|
397 |
-
</h2>
|
398 |
-
<h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
399 |
-
<sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UTrenton, <sup>3</sup>UT Austin, <sup>4</sup>UC Berkeley, <sup>5</sup>UOregon, <sup>6</sup>UIUC
|
400 |
-
</h3>
|
401 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
where we need to have consistent appearance across time in case of video or across various viewing positions in case of 3D.
|
408 |
</h2>
|
409 |
-
|
410 |
</div>
|
411 |
""")
|
412 |
|
413 |
-
gr.HTML("""
|
414 |
-
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
415 |
-
<br/>
|
416 |
-
<a href="https://huggingface.co/spaces/PAIR/PAIR-Diffusion?duplicate=true">
|
417 |
-
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
418 |
-
</p>""")
|
419 |
-
|
420 |
with gr.Tab('Edit Appearance'):
|
421 |
create_app_demo()
|
422 |
-
with gr.Tab('
|
423 |
-
|
424 |
-
with gr.Tab('
|
425 |
-
|
426 |
-
|
|
|
427 |
|
428 |
block.queue(max_size=20)
|
429 |
-
block.launch(
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from pair_diff_demo import ImageComp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
# torch.cuda.set_per_process_memory_fraction(0.6)
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
def init_input_canvas_wrapper(obj, *args):
|
7 |
+
return obj.init_input_canvas(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
def init_ref_canvas_wrapper(obj, *args):
|
10 |
+
return obj.init_ref_canvas(*args)
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def select_input_object_wrapper(obj, evt: gr.SelectData):
|
13 |
+
return obj.select_input_object(evt)
|
|
|
|
|
14 |
|
15 |
+
def select_ref_object_wrapper(obj, evt: gr.SelectData):
|
16 |
+
return obj.select_ref_object(evt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
def process_wrapper(obj, *args):
|
19 |
+
return obj.process(*args)
|
|
|
|
|
20 |
|
21 |
+
def set_multi_modal_wrapper(obj, *args):
|
22 |
+
return obj.set_multi_modal(*args)
|
|
|
|
|
23 |
|
24 |
+
def save_result_wrapper(obj, *args):
|
25 |
+
return obj.save_result(*args)
|
|
|
|
|
26 |
|
27 |
+
def return_input_img_wrapper(obj):
|
28 |
+
return obj.return_input_img()
|
29 |
|
30 |
+
def get_caption_wrapper(obj, *args):
|
31 |
+
return obj.get_caption(*args)
|
32 |
|
33 |
+
def multimodal_params(b):
|
34 |
+
if b:
|
35 |
+
return 10, 3, 6
|
36 |
+
else:
|
37 |
+
return 6, 8, 9
|
38 |
|
39 |
+
theme = gr.themes.Soft(
|
40 |
+
primary_hue="purple",
|
41 |
+
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", 'monospace'],
|
42 |
+
).set(
|
43 |
+
block_label_background_fill_dark='*neutral_800'
|
44 |
+
)
|
45 |
|
|
|
|
|
46 |
|
47 |
+
css = """
|
48 |
+
#customized_imbox {
|
49 |
+
min-height: 450px;
|
50 |
+
}
|
51 |
+
#customized_imbox>div[data-testid="image"] {
|
52 |
+
min-height: 450px;
|
53 |
+
}
|
54 |
+
#customized_imbox>div[data-testid="image"]>div {
|
55 |
+
min-height: 450px;
|
56 |
+
}
|
57 |
+
#customized_imbox>div[data-testid="image"]>iframe {
|
58 |
+
min-height: 450px;
|
59 |
+
}
|
60 |
+
#customized_imbox>div.unpadded_box {
|
61 |
+
min-height: 450px;
|
62 |
+
}
|
63 |
+
#myinst {
|
64 |
+
font-size: 0.8rem;
|
65 |
+
margin: 0rem;
|
66 |
+
color: #6B7280;
|
67 |
+
}
|
68 |
+
#maskinst {
|
69 |
+
text-align: justify;
|
70 |
+
min-width: 1200px;
|
71 |
+
}
|
72 |
+
#maskinst>img {
|
73 |
+
min-width:399px;
|
74 |
+
max-width:450px;
|
75 |
+
vertical-align: top;
|
76 |
+
display: inline-block;
|
77 |
+
}
|
78 |
+
#maskinst:after {
|
79 |
+
content: "";
|
80 |
+
width: 100%;
|
81 |
+
display: inline-block;
|
82 |
+
}
|
83 |
+
"""
|
84 |
|
85 |
+
def create_app_demo():
|
86 |
|
87 |
+
with gr.Row():
|
88 |
+
gr.Markdown("## Object Level Appearance Editing")
|
89 |
+
with gr.Row():
|
90 |
+
gr.HTML(
|
91 |
+
"""
|
92 |
+
<div style="text-align: left; max-width: 1200px;">
|
93 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
94 |
+
Instructions </h3>
|
95 |
+
<ol>
|
96 |
+
<li>Upload an Input Image.</li>
|
97 |
+
<li>Mark one of segmented objects in the <i>Select Object to Edit</i> tab.</li>
|
98 |
+
<li>Upload an Reference Image.</li>
|
99 |
+
<li>Mark one of segmented objects in the <i>Select Reference Object</i> tab, whose appearance needs to used in the selected input object.</li>
|
100 |
+
<li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
|
101 |
+
</ol>
|
102 |
+
</ol>
|
103 |
+
</div>""")
|
104 |
+
with gr.Column():
|
105 |
+
with gr.Row():
|
106 |
+
img_edit = gr.State(ImageComp('edit_app'))
|
107 |
+
with gr.Column():
|
108 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
109 |
+
with gr.Column():
|
110 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
|
111 |
+
|
112 |
+
with gr.Column():
|
113 |
+
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
114 |
+
with gr.Column():
|
115 |
+
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy")
|
116 |
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column():
|
119 |
+
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
120 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image], show_progress=True)
|
125 |
+
input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
|
126 |
|
127 |
+
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], show_progress=True)
|
128 |
+
ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
|
129 |
|
130 |
+
with gr.Column():
|
131 |
+
interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.1, maximum=1, value=1.0, step=0.1)
|
132 |
+
whole_ref = gr.Checkbox(label='Use whole reference Image for appearance (Only useful for style transfers)', visible=False)
|
133 |
+
|
134 |
+
# clear_button.click(fn=img_edit.clear_points, inputs=[], outputs=[input_mask, reference_mask])
|
135 |
|
136 |
+
with gr.Row():
|
137 |
+
run_button = gr.Button(label="Run")
|
138 |
+
save_button = gr.Button("Save")
|
139 |
+
|
140 |
+
with gr.Row():
|
141 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
142 |
+
|
143 |
+
with gr.Accordion("Advanced options", open=False):
|
144 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
|
145 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
146 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
147 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
148 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
149 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
|
150 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
|
151 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
|
152 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
153 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
154 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
155 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
156 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
157 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
158 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=0, value=0, step=0)
|
159 |
+
|
160 |
+
with gr.Column():
|
161 |
+
gr.Examples(
|
162 |
+
examples=[['assets/examples/car.jpeg','assets/examples/ian.jpeg', '', 709736989, 6, 8, 9],
|
163 |
+
['assets/examples/ian.jpeg','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
|
164 |
+
['assets/examples/car.jpeg','assets/examples/ran.webp', '', 709736989, 6, 8, 9],
|
165 |
+
['assets/examples/car.jpeg','assets/examples/car1.webp', '', 709736989, 6, 8, 9],
|
166 |
+
['assets/examples/car1.webp','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
|
167 |
+
['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', '', 1106204668, 6, 8, 9],
|
168 |
+
['assets/examples/house.jpeg','assets/examples/house2.jpeg', '', 1106204668, 6, 8, 9],
|
169 |
+
['assets/examples/house2.jpeg','assets/examples/house.jpeg', '', 1106204668, 6, 8, 9],
|
170 |
+
['assets/examples/park.webp','assets/examples/grasslands-national-park.jpeg', '', 1106204668, 6, 8, 9],
|
171 |
+
['assets/examples/door.jpeg','assets/examples/door2.jpeg', '', 709736989, 6, 8, 9]],
|
172 |
+
inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
|
173 |
+
cache_examples=False,
|
174 |
+
)
|
175 |
|
176 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
|
|
177 |
|
178 |
+
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
179 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking, whole_ref, interpolation]
|
180 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
181 |
+
scale_s, scale_f, scale_t, seed, dil, interpolation]
|
182 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
183 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
184 |
|
185 |
|
186 |
+
def create_add_obj_demo():
|
187 |
+
with gr.Row():
|
188 |
+
gr.Markdown("## Add Objects to Image")
|
189 |
+
with gr.Row():
|
190 |
+
gr.HTML(
|
191 |
+
"""
|
192 |
+
<div style="text-align: left; max-width: 1200px;">
|
193 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
194 |
+
Instructions </h3>
|
195 |
+
<ol>
|
196 |
+
<li> Upload an Input Image.</li>
|
197 |
+
<li>Draw the precise shape of object in the image where you want to add object in <i>Draw Object</i> tab.</li>
|
198 |
+
<li>Upload an Reference Image.</li>
|
199 |
+
<li>Click on the object in the Reference Image tab that you want to add in the Input Image.</li>
|
200 |
+
<li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
|
201 |
+
</ol>
|
202 |
+
</ol>
|
203 |
+
</div>""")
|
204 |
+
with gr.Column():
|
205 |
+
with gr.Row():
|
206 |
+
img_edit = gr.State(ImageComp('add_obj'))
|
207 |
+
with gr.Column():
|
208 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
209 |
+
with gr.Column():
|
210 |
+
input_mask = gr.Image(source="upload", label='Draw the desired Object', type="numpy", tool="sketch")
|
211 |
|
212 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
|
213 |
+
input_image.change(fn=return_input_img_wrapper, inputs=[img_edit], outputs=[input_mask], queue=False)
|
214 |
+
|
215 |
+
with gr.Column():
|
216 |
+
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
217 |
+
with gr.Column():
|
218 |
+
reference_mask = gr.Image(source="upload", label='Selected Object in Refernce Image', type="numpy")
|
219 |
|
220 |
+
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], queue=False)
|
221 |
+
# ref_img.upload(fn=img_edit.init_ref_canvas, inputs=[ref_img], outputs=[ref_img])
|
222 |
+
ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
|
223 |
|
224 |
+
with gr.Row():
|
225 |
+
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
226 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False, visible=False)
|
227 |
|
228 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
229 |
|
230 |
+
with gr.Row():
|
231 |
+
run_button = gr.Button(label="Run")
|
232 |
+
save_button = gr.Button("Save")
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
236 |
+
|
237 |
+
with gr.Accordion("Advanced options", open=False):
|
238 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
|
239 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
240 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
241 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
242 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
243 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
244 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
|
245 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
|
246 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
|
247 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
248 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
249 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
250 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
251 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
252 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
253 |
+
|
254 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
255 |
|
256 |
+
with gr.Column():
|
257 |
+
gr.Examples(
|
258 |
+
examples=[['assets/examples/chair.jpeg','assets/examples/carpet2.webp', 'A picture of living room with carpet', 892905419, 6, 8, 9],
|
259 |
+
['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', 'A picture of living room with a orange and white sofa', 892905419, 6, 8, 9],
|
260 |
+
['assets/examples/park.webp','assets/examples/dog.jpeg', 'A picture of dog in the park', 892905419, 6, 8, 9]],
|
261 |
+
inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
|
262 |
+
outputs=None,
|
263 |
+
fn=None,
|
264 |
+
cache_examples=False,
|
265 |
+
)
|
266 |
+
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
267 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking]
|
268 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
269 |
+
scale_s, scale_f, scale_t, seed, dil]
|
270 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
271 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
272 |
|
273 |
+
def create_obj_variation_demo():
|
274 |
with gr.Row():
|
275 |
+
gr.Markdown("## Objects Variation")
|
276 |
with gr.Row():
|
277 |
gr.HTML(
|
278 |
"""
|
279 |
+
<div style="text-align: left; max-width: 1200px;">
|
280 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
281 |
+
Instructions </h3>
|
282 |
+
<ol>
|
283 |
+
<li> Upload an Input Image.</li>
|
284 |
+
<li>Click on object to have variations</li>
|
285 |
+
<li>Press <i>Run</i> button</li>
|
286 |
+
</ol>
|
287 |
+
</ol>
|
288 |
+
</div>""")
|
289 |
+
|
|
|
|
|
|
|
|
|
|
|
290 |
with gr.Column():
|
291 |
with gr.Row():
|
292 |
img_edit = gr.State(ImageComp('edit_app'))
|
293 |
with gr.Column():
|
|
|
294 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
295 |
with gr.Column():
|
296 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
|
|
|
|
298 |
with gr.Row():
|
299 |
+
prompt = gr.Textbox(label="Prompt", value='')
|
300 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
301 |
+
|
302 |
+
|
303 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
304 |
+
|
305 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
|
306 |
+
input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
|
307 |
+
input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
|
308 |
+
input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
|
309 |
+
|
310 |
with gr.Row():
|
311 |
run_button = gr.Button(label="Run")
|
312 |
+
save_button = gr.Button("Save")
|
313 |
|
314 |
with gr.Row():
|
315 |
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
316 |
|
317 |
with gr.Accordion("Advanced options", open=False):
|
318 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
|
319 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
320 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
321 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
322 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
323 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
324 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
|
325 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
|
326 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
|
327 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
328 |
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
329 |
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
330 |
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
331 |
n_prompt = gr.Textbox(label="Negative Prompt",
|
332 |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
333 |
+
|
334 |
+
|
335 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
336 |
+
|
337 |
with gr.Column():
|
338 |
gr.Examples(
|
339 |
+
examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
|
340 |
+
['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
|
341 |
+
['assets/examples/park.webp', 892905419, 6, 8, 9],
|
342 |
+
['assets/examples/car.jpeg', 709736989, 6, 8, 9],
|
343 |
+
['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
|
344 |
+
['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
|
345 |
+
['assets/examples/door.jpeg', 709736989, 6, 8, 9],
|
346 |
+
['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
|
347 |
+
['assets/examples/house.jpeg', 709736989, 6, 8, 9],
|
348 |
+
['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
|
349 |
+
inputs=[input_image, seed, scale_t, scale_f, scale_s],
|
350 |
outputs=None,
|
351 |
fn=None,
|
352 |
cache_examples=False,
|
353 |
)
|
354 |
+
ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
355 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking]
|
356 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
357 |
+
scale_s, scale_f, scale_t, seed, dil]
|
358 |
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
359 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
360 |
|
361 |
|
362 |
+
def create_free_form_obj_variation_demo():
|
|
|
363 |
with gr.Row():
|
364 |
+
gr.Markdown("## Objects Variation")
|
|
|
|
|
365 |
with gr.Row():
|
366 |
+
gr.HTML(
|
367 |
+
"""
|
368 |
+
<div style="text-align: left; max-width: 1200px;">
|
369 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
370 |
+
Instructions </h3>
|
371 |
+
<ol>
|
372 |
+
<li> Upload an Input Image.</li>
|
373 |
+
<li>Mask the region that you want to have variation</li>
|
374 |
+
<li>Press <i>Run</i> button</li>
|
375 |
+
</ol>
|
376 |
+
</ol>
|
377 |
+
</div>""")
|
378 |
|
379 |
+
with gr.Column():
|
380 |
+
with gr.Row():
|
381 |
+
img_edit = gr.State(ImageComp('edit_app'))
|
382 |
+
with gr.Column():
|
383 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy", )
|
384 |
+
with gr.Column():
|
385 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
386 |
+
|
387 |
+
with gr.Row():
|
388 |
+
prompt = gr.Textbox(label="Prompt", value='')
|
389 |
+
ignore_structure = gr.Checkbox(label='Ignore Structure (Please provide a good caption)', visible=False)
|
390 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
391 |
+
|
392 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
393 |
+
|
394 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask])
|
395 |
+
input_mask.edit(fn=get_caption_wrapper, inputs=[img_edit, input_mask], outputs=[prompt])
|
396 |
+
input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
|
397 |
+
# input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
|
398 |
+
|
399 |
+
# input_image.edit(fn=img_edit.vis_mask, inputs=[input_image], outputs=[input_mask])
|
400 |
+
|
401 |
+
with gr.Row():
|
402 |
+
run_button = gr.Button(label="Run")
|
403 |
+
save_button = gr.Button("Save")
|
404 |
+
|
405 |
+
with gr.Row():
|
406 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
407 |
+
|
408 |
+
with gr.Accordion("Advanced options", open=False):
|
409 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
|
410 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
411 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
412 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
413 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
414 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
415 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
|
416 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
|
417 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
|
418 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
419 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
420 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
421 |
+
free_form_obj_var = gr.Checkbox(label='', value=True)
|
422 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
423 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
424 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
425 |
+
interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.0, maximum=0.1, step=0.1)
|
426 |
+
|
427 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
428 |
+
|
429 |
+
with gr.Column():
|
430 |
+
gr.Examples(
|
431 |
+
examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
|
432 |
+
['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
|
433 |
+
['assets/examples/park.webp', 892905419, 6, 8, 9],
|
434 |
+
['assets/examples/car.jpeg', 709736989, 6, 8, 9],
|
435 |
+
['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
|
436 |
+
['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
|
437 |
+
['assets/examples/door.jpeg', 709736989, 6, 8, 9],
|
438 |
+
['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
|
439 |
+
['assets/examples/house.jpeg', 709736989, 6, 8, 9],
|
440 |
+
['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
|
441 |
+
inputs=[input_image, seed, scale_t, scale_f, scale_s],
|
442 |
+
outputs=None,
|
443 |
+
fn=None,
|
444 |
+
cache_examples=False,
|
445 |
+
)
|
446 |
+
ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
447 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking, free_form_obj_var, dil, free_form_obj_var, ignore_structure]
|
448 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
449 |
+
scale_s, scale_f, scale_t, seed, dil, interpolation, free_form_obj_var]
|
450 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
451 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
452 |
|
453 |
|
454 |
+
block = gr.Blocks(css=css, theme=theme).queue()
|
455 |
with block:
|
456 |
gr.HTML(
|
457 |
"""
|
458 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
459 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
460 |
+
PAIR Diffusion: A Comprehensive Multimodal Object-Level Image Editor
|
461 |
</h1>
|
462 |
+
<h3 style="margin-top: 0.6rem; margin-bottom: 1rem">Picsart AI Research</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
464 |
+
PAIR diffusion provides comprehensive multi-modal editing capabilities to edit real images without the need of inverting them. The current suite contains
|
465 |
+
<span style="color: #01feee;">Object Variation</span>, <span style="color: #4f82d9;">Edit Appearance of any object using a reference image and text</span>,
|
466 |
+
<span style="color: #d402bf;">Add any object from a reference image in the input image</span>. This operations can be mixed with each other to
|
467 |
+
develop new editing operations in future.
|
468 |
+
</ul>
|
|
|
469 |
</h2>
|
|
|
470 |
</div>
|
471 |
""")
|
472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
with gr.Tab('Edit Appearance'):
|
474 |
create_app_demo()
|
475 |
+
with gr.Tab('Object Variation Free Form Mask'):
|
476 |
+
create_free_form_obj_variation_demo()
|
477 |
+
with gr.Tab('Object Variation'):
|
478 |
+
create_obj_variation_demo()
|
479 |
+
with gr.Tab('Add Objects'):
|
480 |
+
create_add_obj_demo()
|
481 |
|
482 |
block.queue(max_size=20)
|
483 |
+
block.launch(share=True)
|
484 |
+
|
assets/GIF.gif
CHANGED
Git LFS Details
|
assets/Teaser_Small.png
ADDED
Git LFS Details
|
assets/examples/Lancia.webp
ADDED
Git LFS Details
|
assets/examples/car.jpeg
ADDED
Git LFS Details
|
assets/examples/car1.webp
ADDED
Git LFS Details
|
assets/examples/carpet2.webp
ADDED
Git LFS Details
|
assets/examples/chair.jpeg
ADDED
Git LFS Details
|
assets/examples/chair1.jpeg
ADDED
Git LFS Details
|
assets/examples/dog.jpeg
ADDED
Git LFS Details
|
assets/examples/door.jpeg
ADDED
Git LFS Details
|
assets/examples/door2.jpeg
ADDED
Git LFS Details
|
assets/examples/grasslands-national-park.jpeg
ADDED
Git LFS Details
|
assets/examples/house.jpeg
ADDED
Git LFS Details
|
assets/examples/house2.jpeg
ADDED
Git LFS Details
|
assets/examples/ian.jpeg
ADDED
Git LFS Details
|
assets/examples/park.webp
ADDED
Git LFS Details
|
assets/examples/ran.webp
ADDED
Git LFS Details
|
assets/hulk.jpeg
CHANGED
Git LFS Details
|
assets/ironman.webp
CHANGED
Git LFS Details
|
assets/lava.jpg
CHANGED
Git LFS Details
|
assets/ski.jpg
CHANGED
Git LFS Details
|
assets/truck.png
CHANGED
Git LFS Details
|
assets/truck2.jpeg
CHANGED
Git LFS Details
|
cldm/appearance_networks.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Neighborhood Attention Transformer.
|
3 |
+
https://arxiv.org/abs/2204.07143
|
4 |
+
|
5 |
+
This source code is licensed under the license found in the
|
6 |
+
LICENSE file in the root directory of this source tree.
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
import torch.nn as nn
|
11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
12 |
+
from timm.models.registry import register_model
|
13 |
+
|
14 |
+
|
15 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
16 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
17 |
+
|
18 |
+
class VGGPerceptualLoss(torch.nn.Module):
|
19 |
+
def __init__(self, resize=True):
|
20 |
+
super(VGGPerceptualLoss, self).__init__()
|
21 |
+
blocks = []
|
22 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
|
23 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
|
24 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
|
25 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
|
26 |
+
for bl in blocks:
|
27 |
+
for p in bl.parameters():
|
28 |
+
p.requires_grad = False
|
29 |
+
self.blocks = torch.nn.ModuleList(blocks)
|
30 |
+
self.transform = torch.nn.functional.interpolate
|
31 |
+
self.resize = resize
|
32 |
+
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
33 |
+
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
34 |
+
|
35 |
+
def forward(self, input, appearance_layers=[0,1,2,3]):
|
36 |
+
if input.shape[1] != 3:
|
37 |
+
input = input.repeat(1, 3, 1, 1)
|
38 |
+
target = target.repeat(1, 3, 1, 1)
|
39 |
+
input = (input-self.mean) / self.std
|
40 |
+
if self.resize:
|
41 |
+
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
42 |
+
x = input
|
43 |
+
feats = []
|
44 |
+
for i, block in enumerate(self.blocks):
|
45 |
+
x = block(x)
|
46 |
+
if i in appearance_layers:
|
47 |
+
feats.append(x)
|
48 |
+
|
49 |
+
return feats
|
50 |
+
|
51 |
+
|
52 |
+
class DINOv2(torch.nn.Module):
|
53 |
+
def __init__(self, resize=True, size=224, model_type='dinov2_vitl14'):
|
54 |
+
super(DINOv2, self).__init__()
|
55 |
+
self.size=size
|
56 |
+
self.resize = resize
|
57 |
+
self.transform = torch.nn.functional.interpolate
|
58 |
+
self.model = torch.hub.load('facebookresearch/dinov2', model_type)
|
59 |
+
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
60 |
+
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
61 |
+
|
62 |
+
def forward(self, input, appearance_layers=[1,2]):
|
63 |
+
if input.shape[1] != 3:
|
64 |
+
input = input.repeat(1, 3, 1, 1)
|
65 |
+
target = target.repeat(1, 3, 1, 1)
|
66 |
+
|
67 |
+
if self.resize:
|
68 |
+
input = self.transform(input, mode='bicubic', size=(self.size, self.size), align_corners=False)
|
69 |
+
# mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(input.device)
|
70 |
+
# std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(input.device)
|
71 |
+
input = (input-self.mean) / self.std
|
72 |
+
feats = self.model.get_intermediate_layers(input, self.model.n_blocks, reshape=True)
|
73 |
+
feats = [f.detach() for f in feats]
|
74 |
+
|
75 |
+
return feats
|
cldm/cldm.py
CHANGED
@@ -10,7 +10,6 @@ from ldm.modules.diffusionmodules.util import (
|
|
10 |
zero_module,
|
11 |
timestep_embedding,
|
12 |
)
|
13 |
-
import torchvision
|
14 |
from einops import rearrange, repeat
|
15 |
from torchvision.utils import make_grid
|
16 |
from ldm.modules.attention import SpatialTransformer
|
@@ -18,46 +17,9 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSeq
|
|
18 |
from ldm.models.diffusion.ddpm import LatentDiffusion
|
19 |
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
20 |
from ldm.models.diffusion.ddim import DDIMSampler
|
|
|
21 |
|
22 |
|
23 |
-
class VGGPerceptualLoss(torch.nn.Module):
|
24 |
-
def __init__(self, resize=True):
|
25 |
-
super(VGGPerceptualLoss, self).__init__()
|
26 |
-
blocks = []
|
27 |
-
vgg_model = torchvision.models.vgg16(pretrained=True)
|
28 |
-
print('Loaded VGG weights')
|
29 |
-
blocks.append(vgg_model.features[:4].eval())
|
30 |
-
blocks.append(vgg_model.features[4:9].eval())
|
31 |
-
blocks.append(vgg_model.features[9:16].eval())
|
32 |
-
blocks.append(vgg_model.features[16:23].eval())
|
33 |
-
|
34 |
-
for bl in blocks:
|
35 |
-
for p in bl.parameters():
|
36 |
-
p.requires_grad = False
|
37 |
-
self.blocks = torch.nn.ModuleList(blocks)
|
38 |
-
self.transform = torch.nn.functional.interpolate
|
39 |
-
self.resize = resize
|
40 |
-
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
41 |
-
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
42 |
-
print('Initialized VGG model')
|
43 |
-
|
44 |
-
def forward(self, input, feature_layers=[0, 1, 2, 3], style_layers=[1,]):
|
45 |
-
if input.shape[1] != 3:
|
46 |
-
input = input.repeat(1, 3, 1, 1)
|
47 |
-
target = target.repeat(1, 3, 1, 1)
|
48 |
-
input = (input-self.mean) / self.std
|
49 |
-
if self.resize:
|
50 |
-
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
51 |
-
x = input
|
52 |
-
gram_matrices_all = []
|
53 |
-
feats = []
|
54 |
-
for i, block in enumerate(self.blocks):
|
55 |
-
x = block(x)
|
56 |
-
if i in style_layers:
|
57 |
-
feats.append(x)
|
58 |
-
|
59 |
-
return feats
|
60 |
-
|
61 |
|
62 |
|
63 |
class ControlledUnetModel(UNetModel):
|
@@ -325,6 +287,7 @@ class ControlNet(nn.Module):
|
|
325 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
326 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
327 |
emb = self.time_embed(t_emb)
|
|
|
328 |
guided_hint = self.input_hint_block(hint, emb, context, x.shape)
|
329 |
|
330 |
outs = []
|
@@ -343,57 +306,6 @@ class ControlNet(nn.Module):
|
|
343 |
outs.append(self.middle_block_out(h, emb, context))
|
344 |
|
345 |
return outs
|
346 |
-
|
347 |
-
class Interpolate(nn.Module):
|
348 |
-
def __init__(self, size, mode):
|
349 |
-
super(Interpolate, self).__init__()
|
350 |
-
self.interp = torch.nn.functional.interpolate
|
351 |
-
self.size = size
|
352 |
-
self.mode = mode
|
353 |
-
self.factor = 8
|
354 |
-
|
355 |
-
def forward(self, x):
|
356 |
-
h,w = x.shape[2]//self.factor, x.shape[3]//self.factor
|
357 |
-
x = self.interp(x, size=(h,w), mode=self.mode)
|
358 |
-
return x
|
359 |
-
|
360 |
-
class ControlNetSAP(ControlNet):
|
361 |
-
def __init__(
|
362 |
-
self,
|
363 |
-
hint_channels,
|
364 |
-
model_channels,
|
365 |
-
input_hint_block='fixed',
|
366 |
-
size = 64,
|
367 |
-
mode='nearest',
|
368 |
-
*args,
|
369 |
-
**kwargs
|
370 |
-
):
|
371 |
-
super().__init__( hint_channels=hint_channels, model_channels=model_channels, *args, **kwargs)
|
372 |
-
#hint channels are atleast 128 dims
|
373 |
-
|
374 |
-
if input_hint_block == 'learnable':
|
375 |
-
ch = 2 ** (int(math.log2(hint_channels)))
|
376 |
-
self.input_hint_block = TimestepEmbedSequential(
|
377 |
-
conv_nd(self.dims, hint_channels, hint_channels, 3, padding=1),
|
378 |
-
nn.SiLU(),
|
379 |
-
conv_nd(self.dims, hint_channels, 2*ch, 3, padding=1, stride=2),
|
380 |
-
nn.SiLU(),
|
381 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
|
382 |
-
nn.SiLU(),
|
383 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1, stride=2),
|
384 |
-
nn.SiLU(),
|
385 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
|
386 |
-
nn.SiLU(),
|
387 |
-
conv_nd(self.dims, 2*ch, model_channels, 3, padding=1, stride=2),
|
388 |
-
nn.SiLU(),
|
389 |
-
zero_module(conv_nd(self.dims, model_channels, model_channels, 3, padding=1))
|
390 |
-
)
|
391 |
-
else:
|
392 |
-
print("Only interpolation")
|
393 |
-
self.input_hint_block = TimestepEmbedSequential(
|
394 |
-
Interpolate(size, mode),
|
395 |
-
zero_module(conv_nd(self.dims, hint_channels, model_channels, 3, padding=1)))
|
396 |
-
|
397 |
|
398 |
class ControlLDM(LatentDiffusion):
|
399 |
|
@@ -420,11 +332,11 @@ class ControlLDM(LatentDiffusion):
|
|
420 |
diffusion_model = self.model.diffusion_model
|
421 |
|
422 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
423 |
-
|
424 |
if cond['c_concat'] is None:
|
425 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
426 |
else:
|
427 |
-
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
|
|
|
428 |
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
429 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
430 |
|
@@ -443,7 +355,7 @@ class ControlLDM(LatentDiffusion):
|
|
443 |
use_ddim = ddim_steps is not None
|
444 |
|
445 |
log = dict()
|
446 |
-
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
447 |
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
448 |
N = min(z.shape[0], N)
|
449 |
n_row = min(z.shape[0], n_row)
|
@@ -498,8 +410,9 @@ class ControlLDM(LatentDiffusion):
|
|
498 |
@torch.no_grad()
|
499 |
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
500 |
ddim_sampler = DDIMSampler(self)
|
501 |
-
b, c, h, w = cond["c_concat"][0].shape
|
502 |
-
shape = (self.channels, h // 8, w // 8)
|
|
|
503 |
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
504 |
return samples, intermediates
|
505 |
|
@@ -525,24 +438,54 @@ class ControlLDM(LatentDiffusion):
|
|
525 |
self.cond_stage_model = self.cond_stage_model.cuda()
|
526 |
|
527 |
|
528 |
-
|
|
|
529 |
@torch.no_grad()
|
530 |
-
def __init__(self,control_stage_config, control_key, only_mid_control,
|
|
|
531 |
super().__init__(control_stage_config=control_stage_config,
|
532 |
control_key=control_key,
|
533 |
only_mid_control=only_mid_control,
|
534 |
*args, **kwargs)
|
535 |
-
self.appearance_net = VGGPerceptualLoss().to(self.device)
|
536 |
-
print("Loaded VGG model")
|
537 |
|
538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
img = (img + 1) * 0.5
|
540 |
-
feat =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
|
542 |
|
543 |
|
544 |
empty_appearance = torch.zeros(feat.shape).to(self.device)
|
545 |
-
mask = torch.nn.functional.interpolate(mask.float(), (feat.shape[2
|
546 |
one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
|
547 |
|
548 |
feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
|
@@ -552,32 +495,68 @@ class SAP(ControlLDM):
|
|
552 |
mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
|
553 |
|
554 |
splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
|
555 |
-
splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
|
556 |
splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
|
557 |
|
558 |
if return_all:
|
559 |
return splatted_feat, mean_feat, one_hot, empty_mask_flag
|
560 |
-
|
561 |
return splatted_feat
|
562 |
-
|
|
|
563 |
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
564 |
z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
|
565 |
structure = batch['seg'].unsqueeze(1)
|
566 |
mask = batch['mask'].unsqueeze(1).to(self.device)
|
567 |
-
|
|
|
|
|
|
|
568 |
if bs is not None:
|
569 |
structure = structure[:bs]
|
570 |
-
appearance = appearance[:bs]
|
571 |
-
|
572 |
structure = structure.to(self.device)
|
573 |
-
appearance = appearance.to(self.device)
|
574 |
structure = structure.to(memory_format=torch.contiguous_format).float()
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
581 |
@torch.no_grad()
|
582 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
583 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
|
@@ -588,11 +567,14 @@ class SAP(ControlLDM):
|
|
588 |
|
589 |
log = dict()
|
590 |
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
591 |
-
c_cat, c = c["c_concat"][0]
|
592 |
N = min(z.shape[0], N)
|
593 |
n_row = min(z.shape[0], n_row)
|
594 |
log["reconstruction"] = self.decode_first_stage(z)
|
595 |
-
log["control"] =
|
|
|
|
|
|
|
596 |
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
597 |
|
598 |
if plot_diffusion_rows:
|
@@ -634,7 +616,7 @@ class SAP(ControlLDM):
|
|
634 |
|
635 |
if unconditional_guidance_scale > 1.0:
|
636 |
uc_cross = self.get_unconditional_conditioning(N)
|
637 |
-
uc_cat = c_cat # torch.zeros_like(c_cat)
|
638 |
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
639 |
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
640 |
batch_size=N, ddim=use_ddim,
|
@@ -646,3 +628,18 @@ class SAP(ControlLDM):
|
|
646 |
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
647 |
|
648 |
return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
zero_module,
|
11 |
timestep_embedding,
|
12 |
)
|
|
|
13 |
from einops import rearrange, repeat
|
14 |
from torchvision.utils import make_grid
|
15 |
from ldm.modules.attention import SpatialTransformer
|
|
|
17 |
from ldm.models.diffusion.ddpm import LatentDiffusion
|
18 |
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
19 |
from ldm.models.diffusion.ddim import DDIMSampler
|
20 |
+
from cldm.appearance_networks import VGGPerceptualLoss, DINOv2
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
class ControlledUnetModel(UNetModel):
|
|
|
287 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
288 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
289 |
emb = self.time_embed(t_emb)
|
290 |
+
# hint = hint[:,:-1]
|
291 |
guided_hint = self.input_hint_block(hint, emb, context, x.shape)
|
292 |
|
293 |
outs = []
|
|
|
306 |
outs.append(self.middle_block_out(h, emb, context))
|
307 |
|
308 |
return outs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
class ControlLDM(LatentDiffusion):
|
311 |
|
|
|
332 |
diffusion_model = self.model.diffusion_model
|
333 |
|
334 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
|
|
335 |
if cond['c_concat'] is None:
|
336 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
337 |
else:
|
338 |
+
# control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
|
339 |
+
control = self.control_model(x=x_noisy, hint=cond['c_concat'][0], timesteps=t, context=cond_txt)
|
340 |
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
341 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
342 |
|
|
|
355 |
use_ddim = ddim_steps is not None
|
356 |
|
357 |
log = dict()
|
358 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N, logging=True)
|
359 |
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
360 |
N = min(z.shape[0], N)
|
361 |
n_row = min(z.shape[0], n_row)
|
|
|
410 |
@torch.no_grad()
|
411 |
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
412 |
ddim_sampler = DDIMSampler(self)
|
413 |
+
b, c, h, w = cond["c_concat"][0][0].shape if isinstance(cond["c_concat"][0], list) else cond["c_concat"][0].shape
|
414 |
+
# shape = (self.channels, h // 8, w // 8)
|
415 |
+
shape = (self.channels, h, w)
|
416 |
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
417 |
return samples, intermediates
|
418 |
|
|
|
438 |
self.cond_stage_model = self.cond_stage_model.cuda()
|
439 |
|
440 |
|
441 |
+
|
442 |
+
class PAIRDiffusion(ControlLDM):
|
443 |
@torch.no_grad()
|
444 |
+
def __init__(self,control_stage_config, control_key, only_mid_control, app_net='vgg', app_layer_conc=(1,), app_layer_ca=(6,6,18,18),
|
445 |
+
appearance_net_locked=True, concat_multi_app=False, train_structure_variation_only=False, instruct=False, *args, **kwargs):
|
446 |
super().__init__(control_stage_config=control_stage_config,
|
447 |
control_key=control_key,
|
448 |
only_mid_control=only_mid_control,
|
449 |
*args, **kwargs)
|
|
|
|
|
450 |
|
451 |
+
self.appearance_net_conc = VGGPerceptualLoss().to(self.device)
|
452 |
+
self.appearance_net_ca = DINOv2().to(self.device)
|
453 |
+
self.appearance_net = VGGPerceptualLoss().to(self.device) #need to be removed no use
|
454 |
+
self.app_layer_conc = app_layer_conc
|
455 |
+
self.app_layer_ca = app_layer_ca
|
456 |
+
|
457 |
+
|
458 |
+
def get_appearance(self, net, layer, img, mask, return_all=False):
|
459 |
img = (img + 1) * 0.5
|
460 |
+
feat = net(img)
|
461 |
+
splatted_feat = []
|
462 |
+
mean_feat = []
|
463 |
+
for fe_i in layer:
|
464 |
+
v = self.get_appearance_single(feat[fe_i], mask, return_all=return_all)
|
465 |
+
if return_all:
|
466 |
+
spl, me_f, one_hot, empty_mask = v
|
467 |
+
splatted_feat.append(spl)
|
468 |
+
mean_feat.append(me_f)
|
469 |
+
else:
|
470 |
+
splatted_feat.append(v)
|
471 |
+
|
472 |
+
if len(layer) == 1:
|
473 |
+
splatted_feat = splatted_feat[0]
|
474 |
+
# mean_feat = mean_feat[0]
|
475 |
+
|
476 |
+
del feat
|
477 |
+
|
478 |
+
if return_all:
|
479 |
+
return splatted_feat, mean_feat, one_hot, empty_mask
|
480 |
+
|
481 |
+
return splatted_feat
|
482 |
+
|
483 |
+
def get_appearance_single(self, feat, mask, return_all):
|
484 |
empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
|
485 |
|
486 |
|
487 |
empty_appearance = torch.zeros(feat.shape).to(self.device)
|
488 |
+
mask = torch.nn.functional.interpolate(mask.float(), size=(feat.shape[2], feat.shape[3])).long()
|
489 |
one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
|
490 |
|
491 |
feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
|
|
|
495 |
mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
|
496 |
|
497 |
splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
|
498 |
+
splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
|
499 |
splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
|
500 |
|
501 |
if return_all:
|
502 |
return splatted_feat, mean_feat, one_hot, empty_mask_flag
|
|
|
503 |
return splatted_feat
|
504 |
+
|
505 |
+
|
506 |
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
507 |
z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
|
508 |
structure = batch['seg'].unsqueeze(1)
|
509 |
mask = batch['mask'].unsqueeze(1).to(self.device)
|
510 |
+
|
511 |
+
appearance_conc = self.get_appearance(self.appearance_net_conc, self.app_layer_conc, x_orig, mask)
|
512 |
+
appearance_ca = self.get_appearance(self.appearance_net_ca, self.app_layer_ca, x_orig, mask)
|
513 |
+
|
514 |
if bs is not None:
|
515 |
structure = structure[:bs]
|
|
|
|
|
516 |
structure = structure.to(self.device)
|
|
|
517 |
structure = structure.to(memory_format=torch.contiguous_format).float()
|
518 |
+
structure = torch.nn.functional.interpolate(structure, z.shape[2:])
|
519 |
+
|
520 |
+
mask = torch.nn.functional.interpolate(mask.float(), z.shape[2:])
|
521 |
+
|
522 |
+
def format_appearance(appearance):
|
523 |
+
if isinstance(appearance, list):
|
524 |
+
if bs is not None:
|
525 |
+
appearance = [ap[:bs] for ap in appearance]
|
526 |
+
appearance = [ap.to(self.device) for ap in appearance]
|
527 |
+
appearance = [ap.to(memory_format=torch.contiguous_format).float() for ap in appearance]
|
528 |
+
appearance = [torch.nn.functional.interpolate(ap, z.shape[2:]) for ap in appearance]
|
529 |
+
|
530 |
+
else:
|
531 |
+
if bs is not None:
|
532 |
+
appearance = appearance[:bs]
|
533 |
+
appearance = appearance.to(self.device)
|
534 |
+
appearance = appearance.to(memory_format=torch.contiguous_format).float()
|
535 |
+
appearance = torch.nn.functional.interpolate(appearance, z.shape[2:])
|
536 |
+
|
537 |
+
return appearance
|
538 |
+
|
539 |
+
appearance_conc = format_appearance(appearance_conc)
|
540 |
+
appearance_ca = format_appearance(appearance_ca)
|
541 |
+
|
542 |
+
if isinstance(appearance_conc, list):
|
543 |
+
concat_control = torch.cat(appearance_conc, dim=1)
|
544 |
+
concat_control = torch.cat([structure, concat_control, mask], dim=1)
|
545 |
+
else:
|
546 |
+
concat_control = torch.cat([structure, appearance_conc, mask], dim=1)
|
547 |
+
|
548 |
+
|
549 |
+
if isinstance(appearance_ca, list):
|
550 |
+
control = []
|
551 |
+
for ap in appearance_ca:
|
552 |
+
control.append(torch.cat([structure, ap, mask], dim=1))
|
553 |
+
control.append(concat_control)
|
554 |
+
return z, dict(c_crossattn=[c], c_concat=[control])
|
555 |
+
else:
|
556 |
+
control = torch.cat([structure, appearance_ca, mask], dim=1)
|
557 |
+
control.append(concat_control)
|
558 |
+
return z, dict(c_crossattn=[c], c_concat=[control])
|
559 |
+
|
560 |
@torch.no_grad()
|
561 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
562 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
|
|
|
567 |
|
568 |
log = dict()
|
569 |
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
570 |
+
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
|
571 |
N = min(z.shape[0], N)
|
572 |
n_row = min(z.shape[0], n_row)
|
573 |
log["reconstruction"] = self.decode_first_stage(z)
|
574 |
+
log["control"] = batch['mask'].unsqueeze(1)
|
575 |
+
if 'aug_mask' in batch:
|
576 |
+
log['aug_mask'] = batch['aug_mask'].unsqueeze(1)
|
577 |
+
|
578 |
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
579 |
|
580 |
if plot_diffusion_rows:
|
|
|
616 |
|
617 |
if unconditional_guidance_scale > 1.0:
|
618 |
uc_cross = self.get_unconditional_conditioning(N)
|
619 |
+
uc_cat = list(c_cat) # torch.zeros_like(c_cat)
|
620 |
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
621 |
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
622 |
batch_size=N, ddim=use_ddim,
|
|
|
628 |
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
629 |
|
630 |
return log
|
631 |
+
|
632 |
+
|
633 |
+
def configure_optimizers(self):
|
634 |
+
lr = self.learning_rate
|
635 |
+
|
636 |
+
params = list(self.control_model.parameters())
|
637 |
+
if not self.sd_locked:
|
638 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
639 |
+
params += list(self.model.diffusion_model.out.parameters())
|
640 |
+
|
641 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
642 |
+
return opt
|
643 |
+
|
644 |
+
|
645 |
+
|
cldm/controlnet.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch as th
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from ldm.modules.diffusionmodules.util import (
|
6 |
+
conv_nd,
|
7 |
+
linear,
|
8 |
+
zero_module,
|
9 |
+
timestep_embedding,
|
10 |
+
)
|
11 |
+
|
12 |
+
from ldm.modules.attention import SpatialTransformer
|
13 |
+
from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
14 |
+
from ldm.util import exists
|
15 |
+
|
16 |
+
torch.autograd.set_detect_anomaly(True)
|
17 |
+
|
18 |
+
class Interpolate(nn.Module):
|
19 |
+
def __init__(self, mode):
|
20 |
+
super(Interpolate, self).__init__()
|
21 |
+
self.interp = torch.nn.functional.interpolate
|
22 |
+
self.mode = mode
|
23 |
+
self.factor = 8
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return x
|
27 |
+
|
28 |
+
class ControlNetPAIR(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
image_size,
|
32 |
+
in_channels,
|
33 |
+
model_channels,
|
34 |
+
hint_channels,
|
35 |
+
concat_indices,
|
36 |
+
num_res_blocks,
|
37 |
+
attention_resolutions,
|
38 |
+
concat_channels=130,
|
39 |
+
dropout=0,
|
40 |
+
channel_mult=(1, 2, 4, 8),
|
41 |
+
mode='nearest',
|
42 |
+
conv_resample=True,
|
43 |
+
dims=2,
|
44 |
+
use_checkpoint=False,
|
45 |
+
use_fp16=False,
|
46 |
+
num_heads=-1,
|
47 |
+
num_head_channels=-1,
|
48 |
+
num_heads_upsample=-1,
|
49 |
+
use_scale_shift_norm=False,
|
50 |
+
resblock_updown=False,
|
51 |
+
use_new_attention_order=False,
|
52 |
+
use_spatial_transformer=False, # custom transformer support
|
53 |
+
transformer_depth=1, # custom transformer support
|
54 |
+
context_dim=None, # custom transformer support
|
55 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
56 |
+
legacy=True,
|
57 |
+
disable_self_attentions=None,
|
58 |
+
num_attention_blocks=None,
|
59 |
+
disable_middle_self_attn=False,
|
60 |
+
use_linear_in_transformer=False,
|
61 |
+
attn_class=['softmax', 'softmax', 'softmax', 'softmax'],
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
if use_spatial_transformer:
|
65 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
66 |
+
|
67 |
+
if context_dim is not None:
|
68 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
69 |
+
from omegaconf.listconfig import ListConfig
|
70 |
+
if type(context_dim) == ListConfig:
|
71 |
+
context_dim = list(context_dim)
|
72 |
+
|
73 |
+
if num_heads_upsample == -1:
|
74 |
+
num_heads_upsample = num_heads
|
75 |
+
|
76 |
+
if num_heads == -1:
|
77 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
78 |
+
|
79 |
+
if num_head_channels == -1:
|
80 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
81 |
+
|
82 |
+
self.dims = dims
|
83 |
+
self.image_size = image_size
|
84 |
+
self.in_channels = in_channels
|
85 |
+
self.model_channels = model_channels
|
86 |
+
if isinstance(num_res_blocks, int):
|
87 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
88 |
+
else:
|
89 |
+
if len(num_res_blocks) != len(channel_mult):
|
90 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
91 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
92 |
+
self.num_res_blocks = num_res_blocks
|
93 |
+
if disable_self_attentions is not None:
|
94 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
95 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
96 |
+
if num_attention_blocks is not None:
|
97 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
98 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
99 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
100 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
101 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
102 |
+
f"attention will still not be set.")
|
103 |
+
|
104 |
+
self.attention_resolutions = attention_resolutions
|
105 |
+
self.dropout = dropout
|
106 |
+
self.channel_mult = channel_mult
|
107 |
+
self.conv_resample = conv_resample
|
108 |
+
self.use_checkpoint = use_checkpoint
|
109 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
110 |
+
self.num_heads = num_heads
|
111 |
+
self.num_head_channels = num_head_channels
|
112 |
+
self.num_heads_upsample = num_heads_upsample
|
113 |
+
self.predict_codebook_ids = n_embed is not None
|
114 |
+
|
115 |
+
time_embed_dim = model_channels * 4
|
116 |
+
self.time_embed = nn.Sequential(
|
117 |
+
linear(model_channels, time_embed_dim),
|
118 |
+
nn.SiLU(),
|
119 |
+
linear(time_embed_dim, time_embed_dim),
|
120 |
+
)
|
121 |
+
|
122 |
+
self.input_blocks = nn.ModuleList(
|
123 |
+
[
|
124 |
+
TimestepEmbedSequential(
|
125 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
126 |
+
)
|
127 |
+
]
|
128 |
+
)
|
129 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
130 |
+
self.concat_indices = concat_indices
|
131 |
+
self.hint_channels = hint_channels
|
132 |
+
h_ch = sum([hint_channels[i] for i in concat_indices ])
|
133 |
+
|
134 |
+
self.input_hint_block = TimestepEmbedSequential(
|
135 |
+
Interpolate('nearest'),
|
136 |
+
conv_nd(self.dims, concat_channels, self.model_channels, 3, padding=1),
|
137 |
+
nn.SiLU(),
|
138 |
+
zero_module(conv_nd(self.dims, self.model_channels, self.model_channels, 3, padding=1)))
|
139 |
+
|
140 |
+
self._feature_size = model_channels
|
141 |
+
input_block_chans = [model_channels]
|
142 |
+
ch = model_channels
|
143 |
+
ds = 1
|
144 |
+
for level, mult in enumerate(channel_mult):
|
145 |
+
for nr in range(self.num_res_blocks[level]):
|
146 |
+
layers = [
|
147 |
+
ResBlock(
|
148 |
+
ch,
|
149 |
+
time_embed_dim,
|
150 |
+
dropout,
|
151 |
+
out_channels=mult * model_channels,
|
152 |
+
dims=dims,
|
153 |
+
use_checkpoint=use_checkpoint,
|
154 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
155 |
+
)
|
156 |
+
]
|
157 |
+
ch = mult * model_channels
|
158 |
+
if ds in attention_resolutions:
|
159 |
+
if num_head_channels == -1:
|
160 |
+
dim_head = ch // num_heads
|
161 |
+
else:
|
162 |
+
num_heads = ch // num_head_channels
|
163 |
+
dim_head = num_head_channels
|
164 |
+
if legacy:
|
165 |
+
# num_heads = 1
|
166 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
167 |
+
if exists(disable_self_attentions):
|
168 |
+
disabled_sa = disable_self_attentions[level]
|
169 |
+
else:
|
170 |
+
disabled_sa = False
|
171 |
+
|
172 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
173 |
+
layers.append(
|
174 |
+
AttentionBlock(
|
175 |
+
ch,
|
176 |
+
use_checkpoint=use_checkpoint,
|
177 |
+
num_heads=num_heads,
|
178 |
+
num_head_channels=dim_head,
|
179 |
+
use_new_attention_order=use_new_attention_order,
|
180 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
181 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
182 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
183 |
+
use_checkpoint=use_checkpoint, attn1_mode=attn_class[level], obj_feat_dim=hint_channels[level]
|
184 |
+
)
|
185 |
+
)
|
186 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
187 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
188 |
+
self._feature_size += ch
|
189 |
+
input_block_chans.append(ch)
|
190 |
+
if level != len(channel_mult) - 1:
|
191 |
+
out_ch = ch
|
192 |
+
self.input_blocks.append(
|
193 |
+
TimestepEmbedSequential(
|
194 |
+
ResBlock(
|
195 |
+
ch,
|
196 |
+
time_embed_dim,
|
197 |
+
dropout,
|
198 |
+
out_channels=out_ch,
|
199 |
+
dims=dims,
|
200 |
+
use_checkpoint=use_checkpoint,
|
201 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
202 |
+
down=True,
|
203 |
+
)
|
204 |
+
if resblock_updown
|
205 |
+
else Downsample(
|
206 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
207 |
+
)
|
208 |
+
)
|
209 |
+
)
|
210 |
+
ch = out_ch
|
211 |
+
input_block_chans.append(ch)
|
212 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
213 |
+
ds *= 2
|
214 |
+
self._feature_size += ch
|
215 |
+
|
216 |
+
if num_head_channels == -1:
|
217 |
+
dim_head = ch // num_heads
|
218 |
+
else:
|
219 |
+
num_heads = ch // num_head_channels
|
220 |
+
dim_head = num_head_channels
|
221 |
+
if legacy:
|
222 |
+
# num_heads = 1
|
223 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
224 |
+
self.middle_block = TimestepEmbedSequential(
|
225 |
+
ResBlock(
|
226 |
+
ch,
|
227 |
+
time_embed_dim,
|
228 |
+
# hint_channels[-1],
|
229 |
+
dropout,
|
230 |
+
dims=dims,
|
231 |
+
use_checkpoint=use_checkpoint,
|
232 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
233 |
+
),
|
234 |
+
AttentionBlock(
|
235 |
+
ch,
|
236 |
+
use_checkpoint=use_checkpoint,
|
237 |
+
num_heads=num_heads,
|
238 |
+
num_head_channels=dim_head,
|
239 |
+
use_new_attention_order=use_new_attention_order,
|
240 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
241 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
242 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
243 |
+
use_checkpoint=use_checkpoint
|
244 |
+
),
|
245 |
+
ResBlock(
|
246 |
+
ch,
|
247 |
+
time_embed_dim,
|
248 |
+
# hint_channels[-1],
|
249 |
+
dropout,
|
250 |
+
dims=dims,
|
251 |
+
use_checkpoint=use_checkpoint,
|
252 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
253 |
+
),
|
254 |
+
)
|
255 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
256 |
+
self._feature_size += ch
|
257 |
+
|
258 |
+
def make_zero_conv(self, channels):
|
259 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
260 |
+
|
261 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
262 |
+
hint_list = []
|
263 |
+
concat_hint = hint[-1]
|
264 |
+
hint_c = hint[:-1]
|
265 |
+
|
266 |
+
if not isinstance(hint_c, list):
|
267 |
+
for _ in range(len(self.channel_mult)):
|
268 |
+
hint_list.append(hint_c)
|
269 |
+
else:
|
270 |
+
hint_list = hint_c
|
271 |
+
while len(hint_list) < 4:
|
272 |
+
hint_list.append(hint_c[-1])
|
273 |
+
|
274 |
+
mask = hint_c[0][:,-1].unsqueeze(1) #panoptic
|
275 |
+
|
276 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
277 |
+
emb = self.time_embed(t_emb)
|
278 |
+
|
279 |
+
guided_hint = self.input_hint_block(concat_hint, emb, context, x.shape)
|
280 |
+
outs = []
|
281 |
+
|
282 |
+
h = x.type(self.dtype)
|
283 |
+
|
284 |
+
cnt = self.num_res_blocks[0] + 1
|
285 |
+
i = 0
|
286 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
287 |
+
if guided_hint is not None:
|
288 |
+
h = module(h, emb, context, hint_list[i], mask)
|
289 |
+
h += guided_hint
|
290 |
+
guided_hint = None
|
291 |
+
else:
|
292 |
+
h = module(h, emb, context, hint_list[i], mask)
|
293 |
+
outs.append(zero_conv(h, emb, context))
|
294 |
+
|
295 |
+
cnt -= 1
|
296 |
+
if cnt == 0:
|
297 |
+
if i<len(self.num_res_blocks):
|
298 |
+
cnt = self.num_res_blocks[i] + 1
|
299 |
+
else:
|
300 |
+
if (i+1)<len(self.num_res_blocks):
|
301 |
+
i += 1
|
302 |
+
|
303 |
+
h = self.middle_block(h, emb, context, hint_list[-1], mask)
|
304 |
+
outs.append(self.middle_block_out(h, emb, context))
|
305 |
+
|
306 |
+
return outs
|
cldm/ddim_hacked.py
CHANGED
@@ -316,7 +316,6 @@ class DDIMSampler(object):
|
|
316 |
return x_dec
|
317 |
|
318 |
|
319 |
-
|
320 |
class DDIMSamplerSpaCFG(DDIMSampler):
|
321 |
@torch.no_grad()
|
322 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
@@ -332,8 +331,8 @@ class DDIMSamplerSpaCFG(DDIMSampler):
|
|
332 |
model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
|
333 |
model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
|
334 |
model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
|
335 |
-
|
336 |
-
model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t -
|
337 |
|
338 |
if self.model.parameterization == "v":
|
339 |
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
|
|
316 |
return x_dec
|
317 |
|
318 |
|
|
|
319 |
class DDIMSamplerSpaCFG(DDIMSampler):
|
320 |
@torch.no_grad()
|
321 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
|
331 |
model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
|
332 |
model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
|
333 |
model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
|
334 |
+
sS, sF, sT = unconditional_guidance_scale
|
335 |
+
model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t - model_uncond)
|
336 |
|
337 |
if self.model.parameterization == "v":
|
338 |
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
cldm/logger.py
CHANGED
@@ -114,16 +114,16 @@ class SetupCallback(Callback):
|
|
114 |
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
115 |
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
116 |
|
117 |
-
else:
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
|
128 |
|
129 |
class ImageLogger(Callback):
|
|
|
114 |
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
115 |
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
116 |
|
117 |
+
# else:
|
118 |
+
# # ModelCheckpoint callback created log directory --- remove it
|
119 |
+
# if not self.resume and os.path.exists(self.logdir):
|
120 |
+
# dst, name = os.path.split(self.logdir)
|
121 |
+
# dst = os.path.join(dst, "child_runs", name)
|
122 |
+
# os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
123 |
+
# try:
|
124 |
+
# os.rename(self.logdir, dst)
|
125 |
+
# except FileNotFoundError:
|
126 |
+
# pass
|
127 |
|
128 |
|
129 |
class ImageLogger(Callback):
|
configs/{sap_fixed_hintnet_v15.yaml → pair_diff.yaml}
RENAMED
@@ -1,9 +1,9 @@
|
|
1 |
model:
|
2 |
-
target: cldm.cldm.
|
3 |
learning_rate: 1.5e-05
|
4 |
sd_locked: True
|
5 |
only_mid_control: False
|
6 |
-
init_ckpt: './models/
|
7 |
params:
|
8 |
linear_start: 0.00085
|
9 |
linear_end: 0.0120
|
@@ -21,14 +21,17 @@ model:
|
|
21 |
scale_factor: 0.18215
|
22 |
use_ema: False
|
23 |
only_mid_control: False
|
|
|
|
|
24 |
|
25 |
control_stage_config:
|
26 |
-
target: cldm.
|
27 |
params:
|
28 |
-
input_hint_block: 'fixed'
|
29 |
image_size: 32 # unused
|
30 |
in_channels: 4
|
31 |
-
|
|
|
|
|
32 |
model_channels: 320
|
33 |
attention_resolutions: [ 4, 2, 1 ]
|
34 |
num_res_blocks: 2
|
@@ -39,6 +42,7 @@ model:
|
|
39 |
context_dim: 768
|
40 |
use_checkpoint: True
|
41 |
legacy: False
|
|
|
42 |
|
43 |
unet_config:
|
44 |
target: cldm.cldm.ControlledUnetModel
|
@@ -87,16 +91,25 @@ model:
|
|
87 |
data:
|
88 |
target: cldm.data.DataModuleFromConfig
|
89 |
params:
|
90 |
-
batch_size:
|
91 |
wrap: True
|
|
|
92 |
train:
|
93 |
target: dataset.txtseg.COCOTrain
|
94 |
params:
|
|
|
|
|
|
|
|
|
95 |
size: 512
|
96 |
validation:
|
97 |
target: dataset.txtseg.COCOValidation
|
98 |
params:
|
99 |
size: 512
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
lightning:
|
@@ -111,4 +124,4 @@ lightning:
|
|
111 |
|
112 |
trainer:
|
113 |
benchmark: True
|
114 |
-
accumulate_grad_batches:
|
|
|
1 |
model:
|
2 |
+
target: cldm.cldm.PAIRDiffusion
|
3 |
learning_rate: 1.5e-05
|
4 |
sd_locked: True
|
5 |
only_mid_control: False
|
6 |
+
init_ckpt: './models/pair_diff_init.ckpt'
|
7 |
params:
|
8 |
linear_start: 0.00085
|
9 |
linear_end: 0.0120
|
|
|
21 |
scale_factor: 0.18215
|
22 |
use_ema: False
|
23 |
only_mid_control: False
|
24 |
+
appearance_net_locked: True
|
25 |
+
app_net: 'DINO'
|
26 |
|
27 |
control_stage_config:
|
28 |
+
target: cldm.controlnet.ControlNetPAIR
|
29 |
params:
|
|
|
30 |
image_size: 32 # unused
|
31 |
in_channels: 4
|
32 |
+
concat_indices: [0,1]
|
33 |
+
concat_channels: 130
|
34 |
+
hint_channels: [1026, 1026, -1, -1] #(1024 + 2)
|
35 |
model_channels: 320
|
36 |
attention_resolutions: [ 4, 2, 1 ]
|
37 |
num_res_blocks: 2
|
|
|
42 |
context_dim: 768
|
43 |
use_checkpoint: True
|
44 |
legacy: False
|
45 |
+
attn_class: ['maskguided', 'maskguided', 'softmax', 'softmax']
|
46 |
|
47 |
unet_config:
|
48 |
target: cldm.cldm.ControlledUnetModel
|
|
|
91 |
data:
|
92 |
target: cldm.data.DataModuleFromConfig
|
93 |
params:
|
94 |
+
batch_size: 2
|
95 |
wrap: True
|
96 |
+
num_workers: 4
|
97 |
train:
|
98 |
target: dataset.txtseg.COCOTrain
|
99 |
params:
|
100 |
+
image_dir:
|
101 |
+
caption_file:
|
102 |
+
panoptic_mask_dir:
|
103 |
+
seg_dir:
|
104 |
size: 512
|
105 |
validation:
|
106 |
target: dataset.txtseg.COCOValidation
|
107 |
params:
|
108 |
size: 512
|
109 |
+
image_dir:
|
110 |
+
caption_file:
|
111 |
+
panoptic_mask_dir:
|
112 |
+
seg_dir:
|
113 |
|
114 |
|
115 |
lightning:
|
|
|
124 |
|
125 |
trainer:
|
126 |
benchmark: True
|
127 |
+
accumulate_grad_batches: 2
|
ldm/ldm/util.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import optim
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inspect import isfunction
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
|
11 |
+
def log_txt_as_img(wh, xc, size=10):
|
12 |
+
# wh a tuple of (width, height)
|
13 |
+
# xc a list of captions to plot
|
14 |
+
b = len(xc)
|
15 |
+
txts = list()
|
16 |
+
for bi in range(b):
|
17 |
+
txt = Image.new("RGB", wh, color="white")
|
18 |
+
draw = ImageDraw.Draw(txt)
|
19 |
+
font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
|
20 |
+
nc = int(40 * (wh[0] / 256))
|
21 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
22 |
+
|
23 |
+
try:
|
24 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
25 |
+
except UnicodeEncodeError:
|
26 |
+
print("Cant encode string for logging. Skipping.")
|
27 |
+
|
28 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
29 |
+
txts.append(txt)
|
30 |
+
txts = np.stack(txts)
|
31 |
+
txts = torch.tensor(txts)
|
32 |
+
return txts
|
33 |
+
|
34 |
+
|
35 |
+
def ismap(x):
|
36 |
+
if not isinstance(x, torch.Tensor):
|
37 |
+
return False
|
38 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
39 |
+
|
40 |
+
|
41 |
+
def isimage(x):
|
42 |
+
if not isinstance(x,torch.Tensor):
|
43 |
+
return False
|
44 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
45 |
+
|
46 |
+
|
47 |
+
def exists(x):
|
48 |
+
return x is not None
|
49 |
+
|
50 |
+
|
51 |
+
def default(val, d):
|
52 |
+
if exists(val):
|
53 |
+
return val
|
54 |
+
return d() if isfunction(d) else d
|
55 |
+
|
56 |
+
|
57 |
+
def mean_flat(tensor):
|
58 |
+
"""
|
59 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
60 |
+
Take the mean over all non-batch dimensions.
|
61 |
+
"""
|
62 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
63 |
+
|
64 |
+
|
65 |
+
def count_params(model, verbose=False):
|
66 |
+
total_params = sum(p.numel() for p in model.parameters())
|
67 |
+
if verbose:
|
68 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
69 |
+
return total_params
|
70 |
+
|
71 |
+
|
72 |
+
def instantiate_from_config(config):
|
73 |
+
if not "target" in config:
|
74 |
+
if config == '__is_first_stage__':
|
75 |
+
return None
|
76 |
+
elif config == "__is_unconditional__":
|
77 |
+
return None
|
78 |
+
raise KeyError("Expected key `target` to instantiate.")
|
79 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
80 |
+
|
81 |
+
|
82 |
+
def get_obj_from_str(string, reload=False):
|
83 |
+
module, cls = string.rsplit(".", 1)
|
84 |
+
if reload:
|
85 |
+
module_imp = importlib.import_module(module)
|
86 |
+
importlib.reload(module_imp)
|
87 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
88 |
+
|
89 |
+
|
90 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
91 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
92 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
93 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
94 |
+
ema_power=1., param_names=()):
|
95 |
+
"""AdamW that saves EMA versions of the parameters."""
|
96 |
+
if not 0.0 <= lr:
|
97 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
98 |
+
if not 0.0 <= eps:
|
99 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
100 |
+
if not 0.0 <= betas[0] < 1.0:
|
101 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
102 |
+
if not 0.0 <= betas[1] < 1.0:
|
103 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
104 |
+
if not 0.0 <= weight_decay:
|
105 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
106 |
+
if not 0.0 <= ema_decay <= 1.0:
|
107 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
108 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
109 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
110 |
+
ema_power=ema_power, param_names=param_names)
|
111 |
+
super().__init__(params, defaults)
|
112 |
+
|
113 |
+
def __setstate__(self, state):
|
114 |
+
super().__setstate__(state)
|
115 |
+
for group in self.param_groups:
|
116 |
+
group.setdefault('amsgrad', False)
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
def step(self, closure=None):
|
120 |
+
"""Performs a single optimization step.
|
121 |
+
Args:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
with torch.enable_grad():
|
128 |
+
loss = closure()
|
129 |
+
|
130 |
+
for group in self.param_groups:
|
131 |
+
params_with_grad = []
|
132 |
+
grads = []
|
133 |
+
exp_avgs = []
|
134 |
+
exp_avg_sqs = []
|
135 |
+
ema_params_with_grad = []
|
136 |
+
state_sums = []
|
137 |
+
max_exp_avg_sqs = []
|
138 |
+
state_steps = []
|
139 |
+
amsgrad = group['amsgrad']
|
140 |
+
beta1, beta2 = group['betas']
|
141 |
+
ema_decay = group['ema_decay']
|
142 |
+
ema_power = group['ema_power']
|
143 |
+
|
144 |
+
for p in group['params']:
|
145 |
+
if p.grad is None:
|
146 |
+
continue
|
147 |
+
params_with_grad.append(p)
|
148 |
+
if p.grad.is_sparse:
|
149 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
150 |
+
grads.append(p.grad)
|
151 |
+
|
152 |
+
state = self.state[p]
|
153 |
+
|
154 |
+
# State initialization
|
155 |
+
if len(state) == 0:
|
156 |
+
state['step'] = 0
|
157 |
+
# Exponential moving average of gradient values
|
158 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
159 |
+
# Exponential moving average of squared gradient values
|
160 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
161 |
+
if amsgrad:
|
162 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
163 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
164 |
+
# Exponential moving average of parameter values
|
165 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
166 |
+
|
167 |
+
exp_avgs.append(state['exp_avg'])
|
168 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
169 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
170 |
+
|
171 |
+
if amsgrad:
|
172 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
173 |
+
|
174 |
+
# update the steps for each param group update
|
175 |
+
state['step'] += 1
|
176 |
+
# record the step after step update
|
177 |
+
state_steps.append(state['step'])
|
178 |
+
|
179 |
+
optim._functional.adamw(params_with_grad,
|
180 |
+
grads,
|
181 |
+
exp_avgs,
|
182 |
+
exp_avg_sqs,
|
183 |
+
max_exp_avg_sqs,
|
184 |
+
state_steps,
|
185 |
+
amsgrad=amsgrad,
|
186 |
+
beta1=beta1,
|
187 |
+
beta2=beta2,
|
188 |
+
lr=group['lr'],
|
189 |
+
weight_decay=group['weight_decay'],
|
190 |
+
eps=group['eps'],
|
191 |
+
maximize=False)
|
192 |
+
|
193 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
194 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
195 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
196 |
+
|
197 |
+
return loss
|
ldm/models/diffusion/ddim.py
CHANGED
@@ -194,9 +194,19 @@ class DDIMSampler(object):
|
|
194 |
c_in = dict()
|
195 |
for k in c:
|
196 |
if isinstance(c[k], list):
|
197 |
-
c_in[k] = [
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
else:
|
201 |
c_in[k] = torch.cat([
|
202 |
unconditional_conditioning[k],
|
@@ -333,4 +343,5 @@ class DDIMSampler(object):
|
|
333 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
334 |
unconditional_conditioning=unconditional_conditioning)
|
335 |
if callback: callback(i)
|
336 |
-
return x_dec
|
|
|
|
194 |
c_in = dict()
|
195 |
for k in c:
|
196 |
if isinstance(c[k], list):
|
197 |
+
c_in[k] = []
|
198 |
+
if isinstance(c[k][0], list):
|
199 |
+
for i in range(len(c[k])):
|
200 |
+
c_ = []
|
201 |
+
for j in range(len(c[k][i])):
|
202 |
+
c_.append(torch.cat([
|
203 |
+
unconditional_conditioning[k][i][j],
|
204 |
+
c[k][i][j]]) )
|
205 |
+
c_in[k].append(c_)
|
206 |
+
else:
|
207 |
+
c_in[k] = [torch.cat([
|
208 |
+
unconditional_conditioning[k][i],
|
209 |
+
c[k][i]]) for i in range(len(c[k]))]
|
210 |
else:
|
211 |
c_in[k] = torch.cat([
|
212 |
unconditional_conditioning[k],
|
|
|
343 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
344 |
unconditional_conditioning=unconditional_conditioning)
|
345 |
if callback: callback(i)
|
346 |
+
return x_dec
|
347 |
+
|
ldm/modules/attention.py
CHANGED
@@ -42,7 +42,7 @@ def init_(tensor):
|
|
42 |
dim = tensor.shape[-1]
|
43 |
std = 1 / math.sqrt(dim)
|
44 |
tensor.uniform_(-std, std)
|
45 |
-
return tensor
|
46 |
|
47 |
|
48 |
# feedforward
|
@@ -143,7 +143,7 @@ class SpatialSelfAttention(nn.Module):
|
|
143 |
|
144 |
|
145 |
class CrossAttention(nn.Module):
|
146 |
-
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0
|
147 |
super().__init__()
|
148 |
inner_dim = dim_head * heads
|
149 |
context_dim = default(context_dim, query_dim)
|
@@ -160,7 +160,7 @@ class CrossAttention(nn.Module):
|
|
160 |
nn.Dropout(dropout)
|
161 |
)
|
162 |
|
163 |
-
def forward(self, x, context=None, mask=None):
|
164 |
h = self.heads
|
165 |
|
166 |
q = self.to_q(x)
|
@@ -194,6 +194,34 @@ class CrossAttention(nn.Module):
|
|
194 |
return self.to_out(out)
|
195 |
|
196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
class MemoryEfficientCrossAttention(nn.Module):
|
198 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
199 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
@@ -246,17 +274,19 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
246 |
class BasicTransformerBlock(nn.Module):
|
247 |
ATTENTION_MODES = {
|
248 |
"softmax": CrossAttention, # vanilla attention
|
249 |
-
"softmax-xformers": MemoryEfficientCrossAttention
|
|
|
250 |
}
|
251 |
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
252 |
-
disable_self_attn=False):
|
253 |
super().__init__()
|
254 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
255 |
assert attn_mode in self.ATTENTION_MODES
|
256 |
attn_cls = self.ATTENTION_MODES[attn_mode]
|
|
|
257 |
self.disable_self_attn = disable_self_attn
|
258 |
-
self.attn1 =
|
259 |
-
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
260 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
261 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
262 |
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
@@ -265,11 +295,17 @@ class BasicTransformerBlock(nn.Module):
|
|
265 |
self.norm3 = nn.LayerNorm(dim)
|
266 |
self.checkpoint = checkpoint
|
267 |
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
-
def _forward(self, x, context=None):
|
272 |
-
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None
|
|
|
273 |
x = self.attn2(self.norm2(x), context=context) + x
|
274 |
x = self.ff(self.norm3(x)) + x
|
275 |
return x
|
@@ -287,7 +323,7 @@ class SpatialTransformer(nn.Module):
|
|
287 |
def __init__(self, in_channels, n_heads, d_head,
|
288 |
depth=1, dropout=0., context_dim=None,
|
289 |
disable_self_attn=False, use_linear=False,
|
290 |
-
use_checkpoint=True):
|
291 |
super().__init__()
|
292 |
if exists(context_dim) and not isinstance(context_dim, list):
|
293 |
context_dim = [context_dim]
|
@@ -305,7 +341,8 @@ class SpatialTransformer(nn.Module):
|
|
305 |
|
306 |
self.transformer_blocks = nn.ModuleList(
|
307 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
308 |
-
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint
|
|
|
309 |
for d in range(depth)]
|
310 |
)
|
311 |
if not use_linear:
|
@@ -318,11 +355,20 @@ class SpatialTransformer(nn.Module):
|
|
318 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
319 |
self.use_linear = use_linear
|
320 |
|
321 |
-
def forward(self, x, context=None):
|
322 |
# note: if no context is given, cross-attention defaults to self-attention
|
323 |
if not isinstance(context, list):
|
324 |
context = [context]
|
|
|
|
|
|
|
|
|
|
|
325 |
b, c, h, w = x.shape
|
|
|
|
|
|
|
|
|
326 |
x_in = x
|
327 |
x = self.norm(x)
|
328 |
if not self.use_linear:
|
@@ -331,7 +377,7 @@ class SpatialTransformer(nn.Module):
|
|
331 |
if self.use_linear:
|
332 |
x = self.proj_in(x)
|
333 |
for i, block in enumerate(self.transformer_blocks):
|
334 |
-
x = block(x, context=context[i])
|
335 |
if self.use_linear:
|
336 |
x = self.proj_out(x)
|
337 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
|
|
42 |
dim = tensor.shape[-1]
|
43 |
std = 1 / math.sqrt(dim)
|
44 |
tensor.uniform_(-std, std)
|
45 |
+
return tensor
|
46 |
|
47 |
|
48 |
# feedforward
|
|
|
143 |
|
144 |
|
145 |
class CrossAttention(nn.Module):
|
146 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kargs):
|
147 |
super().__init__()
|
148 |
inner_dim = dim_head * heads
|
149 |
context_dim = default(context_dim, query_dim)
|
|
|
160 |
nn.Dropout(dropout)
|
161 |
)
|
162 |
|
163 |
+
def forward(self, x, context=None, mask=None, **kargs):
|
164 |
h = self.heads
|
165 |
|
166 |
q = self.to_q(x)
|
|
|
194 |
return self.to_out(out)
|
195 |
|
196 |
|
197 |
+
class MaskGuidedSelfAttention(nn.Module):
|
198 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., obj_feat_dim=1024):
|
199 |
+
super().__init__()
|
200 |
+
#here context dim is for object features coming from image encoder
|
201 |
+
inner_dim = dim_head * heads
|
202 |
+
self.heads = heads
|
203 |
+
|
204 |
+
self.obj_feats_map = nn.Linear(obj_feat_dim, inner_dim)
|
205 |
+
self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
|
206 |
+
|
207 |
+
self.to_out = nn.Sequential(
|
208 |
+
nn.Linear(inner_dim, query_dim),
|
209 |
+
nn.Dropout(dropout)
|
210 |
+
)
|
211 |
+
|
212 |
+
self.scale = dim_head ** -0.5
|
213 |
+
|
214 |
+
def forward(self, x, context=None, mask=None, obj_mask=None, obj_feat=None):
|
215 |
+
_, _, ht, wd = obj_feat.shape
|
216 |
+
obj_feat = rearrange(obj_feat, 'b c h w -> b (h w) c').contiguous()
|
217 |
+
obj_feat = self.obj_feats_map(obj_feat)
|
218 |
+
v = self.to_v(obj_feat)
|
219 |
+
return self.to_out(v)
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
class MemoryEfficientCrossAttention(nn.Module):
|
226 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
227 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
|
|
274 |
class BasicTransformerBlock(nn.Module):
|
275 |
ATTENTION_MODES = {
|
276 |
"softmax": CrossAttention, # vanilla attention
|
277 |
+
"softmax-xformers": MemoryEfficientCrossAttention,
|
278 |
+
"maskguided": MaskGuidedSelfAttention
|
279 |
}
|
280 |
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
281 |
+
disable_self_attn=False, attn1_mode="softmax", obj_feat_dim=1024):
|
282 |
super().__init__()
|
283 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
284 |
assert attn_mode in self.ATTENTION_MODES
|
285 |
attn_cls = self.ATTENTION_MODES[attn_mode]
|
286 |
+
attn1_cls = self.ATTENTION_MODES[attn1_mode]
|
287 |
self.disable_self_attn = disable_self_attn
|
288 |
+
self.attn1 = attn1_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
289 |
+
context_dim=context_dim if self.disable_self_attn else None, obj_feat_dim=obj_feat_dim) # is a self-attention if not self.disable_self_attn
|
290 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
291 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
292 |
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
|
|
295 |
self.norm3 = nn.LayerNorm(dim)
|
296 |
self.checkpoint = checkpoint
|
297 |
|
298 |
+
# self.ff_text_obj_feat = FeedForward(context_dim, dim_out=dim, mult=1, dropout=dropout, glu=gated_ff)
|
299 |
+
|
300 |
+
def forward(self, x, context=None, obj_mask=None, obj_feat=None):
|
301 |
+
if obj_mask is None:
|
302 |
+
# return self._forward(x, context, obj_mask, obj_feat)
|
303 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
304 |
+
return checkpoint(self._forward, (x, context, obj_mask, obj_feat), self.parameters(), self.checkpoint)
|
305 |
|
306 |
+
def _forward(self, x, context=None, obj_mask=None, obj_feat=None):
|
307 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,
|
308 |
+
obj_mask=obj_mask, obj_feat=obj_feat) + x
|
309 |
x = self.attn2(self.norm2(x), context=context) + x
|
310 |
x = self.ff(self.norm3(x)) + x
|
311 |
return x
|
|
|
323 |
def __init__(self, in_channels, n_heads, d_head,
|
324 |
depth=1, dropout=0., context_dim=None,
|
325 |
disable_self_attn=False, use_linear=False,
|
326 |
+
use_checkpoint=True,attn1_mode='softmax',obj_feat_dim=None):
|
327 |
super().__init__()
|
328 |
if exists(context_dim) and not isinstance(context_dim, list):
|
329 |
context_dim = [context_dim]
|
|
|
341 |
|
342 |
self.transformer_blocks = nn.ModuleList(
|
343 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
344 |
+
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn1_mode=attn1_mode,
|
345 |
+
obj_feat_dim=obj_feat_dim)
|
346 |
for d in range(depth)]
|
347 |
)
|
348 |
if not use_linear:
|
|
|
355 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
356 |
self.use_linear = use_linear
|
357 |
|
358 |
+
def forward(self, x, context=None, obj_masks=None, obj_feats=None):
|
359 |
# note: if no context is given, cross-attention defaults to self-attention
|
360 |
if not isinstance(context, list):
|
361 |
context = [context]
|
362 |
+
if not isinstance(obj_masks, list):
|
363 |
+
obj_masks = [obj_masks]
|
364 |
+
if not isinstance(obj_feats, list):
|
365 |
+
obj_feats = [obj_feats]
|
366 |
+
|
367 |
b, c, h, w = x.shape
|
368 |
+
if obj_feats[0] is not None:
|
369 |
+
obj_feats = [torch.nn.functional.interpolate(ofe, [h,w]) for ofe in obj_feats]
|
370 |
+
obj_masks = [torch.nn.functional.interpolate(om, [h,w]) for om in obj_masks]
|
371 |
+
|
372 |
x_in = x
|
373 |
x = self.norm(x)
|
374 |
if not self.use_linear:
|
|
|
377 |
if self.use_linear:
|
378 |
x = self.proj_in(x)
|
379 |
for i, block in enumerate(self.transformer_blocks):
|
380 |
+
x = block(x, context=context[i], obj_mask=obj_masks[i], obj_feat=obj_feats[i])
|
381 |
if self.use_linear:
|
382 |
x = self.proj_out(x)
|
383 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
ldm/modules/diffusionmodules/openaimodel.py
CHANGED
@@ -69,19 +69,31 @@ class TimestepBlock(nn.Module):
|
|
69 |
Apply the module to `x` given `emb` timestep embeddings.
|
70 |
"""
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
74 |
"""
|
75 |
A sequential module that passes timestep embeddings to the children that
|
76 |
support it as an extra input.
|
77 |
"""
|
78 |
|
79 |
-
def forward(self, x, emb, context=None, *args):
|
80 |
for layer in self:
|
81 |
if isinstance(layer, TimestepBlock):
|
82 |
x = layer(x, emb)
|
83 |
elif isinstance(layer, SpatialTransformer):
|
84 |
-
x = layer(x, context)
|
|
|
|
|
85 |
else:
|
86 |
x = layer(x)
|
87 |
return x
|
@@ -783,4 +795,4 @@ class UNetModel(nn.Module):
|
|
783 |
if self.predict_codebook_ids:
|
784 |
return self.id_predictor(h)
|
785 |
else:
|
786 |
-
return self.out(h)
|
|
|
69 |
Apply the module to `x` given `emb` timestep embeddings.
|
70 |
"""
|
71 |
|
72 |
+
class TimestepBlockSpa(nn.Module):
|
73 |
+
"""
|
74 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
75 |
+
"""
|
76 |
+
|
77 |
+
@abstractmethod
|
78 |
+
def forward(self, x, emb, obj_feat):
|
79 |
+
"""
|
80 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
81 |
+
"""
|
82 |
|
83 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock, TimestepBlockSpa):
|
84 |
"""
|
85 |
A sequential module that passes timestep embeddings to the children that
|
86 |
support it as an extra input.
|
87 |
"""
|
88 |
|
89 |
+
def forward(self, x, emb, context=None, obj_feat=None,obj_masks=None, *args):
|
90 |
for layer in self:
|
91 |
if isinstance(layer, TimestepBlock):
|
92 |
x = layer(x, emb)
|
93 |
elif isinstance(layer, SpatialTransformer):
|
94 |
+
x = layer(x, context, obj_masks=obj_masks, obj_feats=obj_feat)
|
95 |
+
elif isinstance(layer, TimestepBlockSpa):
|
96 |
+
x = layer(x, emb, obj_feat)
|
97 |
else:
|
98 |
x = layer(x)
|
99 |
return x
|
|
|
795 |
if self.predict_codebook_ids:
|
796 |
return self.id_predictor(h)
|
797 |
else:
|
798 |
+
return self.out(h)
|
ldm/modules/diffusionmodules/util.py
CHANGED
@@ -215,9 +215,10 @@ class SiLU(nn.Module):
|
|
215 |
|
216 |
|
217 |
class GroupNorm32(nn.GroupNorm):
|
218 |
-
def forward(self, x):
|
219 |
return super().forward(x.float()).type(x.dtype)
|
220 |
|
|
|
221 |
def conv_nd(dims, *args, **kwargs):
|
222 |
"""
|
223 |
Create a 1D, 2D, or 3D convolution module.
|
|
|
215 |
|
216 |
|
217 |
class GroupNorm32(nn.GroupNorm):
|
218 |
+
def forward(self, x, *args):
|
219 |
return super().forward(x.float()).type(x.dtype)
|
220 |
|
221 |
+
|
222 |
def conv_nd(dims, *args, **kwargs):
|
223 |
"""
|
224 |
Create a 1D, 2D, or 3D convolution module.
|
ldm/modules/encoders/modules.py
CHANGED
@@ -114,14 +114,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|
114 |
for param in self.parameters():
|
115 |
param.requires_grad = False
|
116 |
|
117 |
-
def forward(self, text):
|
118 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
119 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
120 |
tokens = batch_encoding["input_ids"].to(self.device)
|
121 |
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
122 |
-
if
|
123 |
z = outputs.last_hidden_state
|
124 |
-
elif
|
125 |
z = outputs.pooler_output[:, None, :]
|
126 |
else:
|
127 |
z = outputs.hidden_states[self.layer_idx]
|
|
|
114 |
for param in self.parameters():
|
115 |
param.requires_grad = False
|
116 |
|
117 |
+
def forward(self, text, layer='last'):
|
118 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
119 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
120 |
tokens = batch_encoding["input_ids"].to(self.device)
|
121 |
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
122 |
+
if layer == "last":
|
123 |
z = outputs.last_hidden_state
|
124 |
+
elif layer == "pooled":
|
125 |
z = outputs.pooler_output[:, None, :]
|
126 |
else:
|
127 |
z = outputs.hidden_states[self.layer_idx]
|
pair_diff_demo.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import einops
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import random
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
import datetime
|
10 |
+
from huggingface_hub import hf_hub_url, hf_hub_download
|
11 |
+
|
12 |
+
from pytorch_lightning import seed_everything
|
13 |
+
from annotator.util import resize_image, HWC3
|
14 |
+
from annotator.OneFormer import OneformerSegmenter
|
15 |
+
from cldm.model import create_model, load_state_dict
|
16 |
+
from cldm.ddim_hacked import DDIMSamplerSpaCFG
|
17 |
+
from ldm.models.autoencoder import DiagonalGaussianDistribution
|
18 |
+
|
19 |
+
|
20 |
+
SEGMENT_MODEL_DICT = {
|
21 |
+
'Oneformer': OneformerSegmenter,
|
22 |
+
}
|
23 |
+
|
24 |
+
MASK_MODEL_DICT = {
|
25 |
+
'Oneformer': OneformerSegmenter,
|
26 |
+
}
|
27 |
+
|
28 |
+
urls = {
|
29 |
+
'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
|
30 |
+
'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['model_e91.ckpt']
|
31 |
+
}
|
32 |
+
|
33 |
+
WTS_DICT = {
|
34 |
+
|
35 |
+
}
|
36 |
+
|
37 |
+
if os.path.exists('checkpoints') == False:
|
38 |
+
os.mkdir('checkpoints')
|
39 |
+
for repo in urls:
|
40 |
+
files = urls[repo]
|
41 |
+
for file in files:
|
42 |
+
url = hf_hub_url(repo, file)
|
43 |
+
name_ckp = url.split('/')[-1]
|
44 |
+
|
45 |
+
WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file)
|
46 |
+
|
47 |
+
|
48 |
+
#main model
|
49 |
+
model = create_model('configs/pair_diff.yaml').cpu()
|
50 |
+
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
|
51 |
+
|
52 |
+
save_dir = 'results/'
|
53 |
+
|
54 |
+
model = model.cuda()
|
55 |
+
ddim_sampler = DDIMSamplerSpaCFG(model)
|
56 |
+
save_memory = False
|
57 |
+
|
58 |
+
|
59 |
+
class ImageComp:
|
60 |
+
def __init__(self, edit_operation):
|
61 |
+
self.input_img = None
|
62 |
+
self.input_pmask = None
|
63 |
+
self.input_segmask = None
|
64 |
+
self.input_mask = None
|
65 |
+
self.input_points = []
|
66 |
+
self.input_scale = 1
|
67 |
+
|
68 |
+
self.ref_img = None
|
69 |
+
self.ref_pmask = None
|
70 |
+
self.ref_segmask = None
|
71 |
+
self.ref_mask = None
|
72 |
+
self.ref_points = []
|
73 |
+
self.ref_scale = 1
|
74 |
+
|
75 |
+
self.multi_modal = False
|
76 |
+
|
77 |
+
self.H = None
|
78 |
+
self.W = None
|
79 |
+
self.kernel = np.ones((5, 5), np.uint8)
|
80 |
+
self.edit_operation = edit_operation
|
81 |
+
self.init_segmentation_model()
|
82 |
+
os.makedirs(save_dir, exist_ok=True)
|
83 |
+
|
84 |
+
self.base_prompt = 'A picture of {}'
|
85 |
+
|
86 |
+
def init_segmentation_model(self, mask_model='Oneformer', segment_model='Oneformer'):
|
87 |
+
self.segment_model_name = segment_model
|
88 |
+
self.mask_model_name = mask_model
|
89 |
+
|
90 |
+
self.segment_model = SEGMENT_MODEL_DICT[segment_model](WTS_DICT['shi-labs/oneformer_coco_swin_large'])
|
91 |
+
|
92 |
+
if mask_model == 'Oneformer' and segment_model == 'Oneformer':
|
93 |
+
self.mask_model_inp = self.segment_model
|
94 |
+
self.mask_model_ref = self.segment_model
|
95 |
+
else:
|
96 |
+
self.mask_model_inp = MASK_MODEL_DICT[mask_model]()
|
97 |
+
self.mask_model_ref = MASK_MODEL_DICT[mask_model]()
|
98 |
+
|
99 |
+
print(f"Segmentation Models initialized with {mask_model} as mask and {segment_model} as segment")
|
100 |
+
|
101 |
+
def init_input_canvas(self, img):
|
102 |
+
|
103 |
+
img = HWC3(img)
|
104 |
+
img = resize_image(img, 512)
|
105 |
+
if self.segment_model_name == 'Oneformer':
|
106 |
+
detected_seg = self.segment_model(img, 'semantic')
|
107 |
+
elif self.segment_model_name == 'SAM':
|
108 |
+
raise NotImplementedError
|
109 |
+
|
110 |
+
if self.mask_model_name == 'Oneformer':
|
111 |
+
detected_mask = self.mask_model_inp(img, 'panoptic')[0]
|
112 |
+
elif self.mask_model_name == 'SAM':
|
113 |
+
detected_mask = self.mask_model_inp(img)
|
114 |
+
|
115 |
+
self.input_points = []
|
116 |
+
self.input_img = img
|
117 |
+
self.input_pmask = detected_mask
|
118 |
+
self.input_segmask = detected_seg
|
119 |
+
self.H = img.shape[0]
|
120 |
+
self.W = img.shape[1]
|
121 |
+
|
122 |
+
return img
|
123 |
+
|
124 |
+
def init_ref_canvas(self, img):
|
125 |
+
|
126 |
+
img = HWC3(img)
|
127 |
+
img = resize_image(img, 512)
|
128 |
+
if self.segment_model_name == 'Oneformer':
|
129 |
+
detected_seg = self.segment_model(img, 'semantic')
|
130 |
+
elif self.segment_model_name == 'SAM':
|
131 |
+
raise NotImplementedError
|
132 |
+
|
133 |
+
if self.mask_model_name == 'Oneformer':
|
134 |
+
detected_mask = self.mask_model_ref(img, 'panoptic')[0]
|
135 |
+
elif self.mask_model_name == 'SAM':
|
136 |
+
detected_mask = self.mask_model_ref(img)
|
137 |
+
|
138 |
+
self.ref_points = []
|
139 |
+
print("Initialized ref", img.shape)
|
140 |
+
self.ref_img = img
|
141 |
+
self.ref_pmask = detected_mask
|
142 |
+
self.ref_segmask = detected_seg
|
143 |
+
|
144 |
+
return img
|
145 |
+
|
146 |
+
def select_input_object(self, evt: gr.SelectData):
|
147 |
+
idx = list(np.array(evt.index) * self.input_scale)
|
148 |
+
self.input_points.append(idx)
|
149 |
+
if self.mask_model_name == 'Oneformer':
|
150 |
+
mask = self._get_mask_from_panoptic(np.array(self.input_points), self.input_pmask)
|
151 |
+
else:
|
152 |
+
mask = self.mask_model_inp(self.input_img, self.input_points)
|
153 |
+
|
154 |
+
c_ids = self.input_segmask[np.array(self.input_points)[:,1], np.array(self.input_points)[:,0]]
|
155 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
156 |
+
c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
|
157 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
158 |
+
# print(self.segment_model.metadata.stuff_classes)
|
159 |
+
|
160 |
+
self.input_mask = mask
|
161 |
+
mask = mask.cpu().numpy()
|
162 |
+
output = mask[:,:,None] * self.input_img + (1 - mask[:,:,None]) * self.input_img * 0.2
|
163 |
+
return output.astype(np.uint8), self.base_prompt.format(category)
|
164 |
+
|
165 |
+
def select_ref_object(self, evt: gr.SelectData):
|
166 |
+
idx = list(np.array(evt.index) * self.ref_scale)
|
167 |
+
self.ref_points.append(idx)
|
168 |
+
if self.mask_model_name == 'Oneformer':
|
169 |
+
mask = self._get_mask_from_panoptic(np.array(self.ref_points), self.ref_pmask)
|
170 |
+
else:
|
171 |
+
mask = self.mask_model_ref(self.ref_img, self.ref_points)
|
172 |
+
c_ids = self.ref_segmask[np.array(self.ref_points)[:,1], np.array(self.ref_points)[:,0]]
|
173 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
174 |
+
c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
|
175 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
176 |
+
print("Category of reference object is:", category)
|
177 |
+
|
178 |
+
self.ref_mask = mask
|
179 |
+
mask = mask.cpu().numpy()
|
180 |
+
output = mask[:,:,None] * self.ref_img + (1 - mask[:,:,None]) * self.ref_img * 0.2
|
181 |
+
return output.astype(np.uint8)
|
182 |
+
|
183 |
+
def clear_points(self):
|
184 |
+
self.input_points = []
|
185 |
+
self.ref_points = []
|
186 |
+
zeros_inp = np.zeros(self.input_img.shape)
|
187 |
+
zeros_ref = np.zeros(self.ref_img.shape)
|
188 |
+
return zeros_inp, zeros_ref
|
189 |
+
|
190 |
+
def return_input_img(self):
|
191 |
+
return self.input_img
|
192 |
+
|
193 |
+
|
194 |
+
def _get_mask_from_panoptic(self, points, panoptic_mask):
|
195 |
+
panoptic_mask_ = panoptic_mask + 1
|
196 |
+
ids = panoptic_mask_[points[:,1], points[:,0]]
|
197 |
+
unique_ids, counts = torch.unique(ids, return_counts=True)
|
198 |
+
mask_id = unique_ids[torch.argmax(counts)]
|
199 |
+
final_mask = torch.zeros(panoptic_mask.shape).cuda()
|
200 |
+
final_mask[panoptic_mask_ == mask_id] = 1
|
201 |
+
|
202 |
+
return final_mask
|
203 |
+
|
204 |
+
|
205 |
+
def _process_mask(self, mask, panoptic_mask, segmask):
|
206 |
+
obj_class = mask * (segmask + 1)
|
207 |
+
unique_ids, counts = torch.unique(obj_class, return_counts=True)
|
208 |
+
obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
|
209 |
+
return mask, obj_class
|
210 |
+
|
211 |
+
|
212 |
+
def _edit_app(self, whole_ref):
|
213 |
+
"""
|
214 |
+
Manipulates the panoptic mask of input image to change appearance
|
215 |
+
"""
|
216 |
+
input_pmask = self.input_pmask
|
217 |
+
input_segmask = self.input_segmask
|
218 |
+
|
219 |
+
if whole_ref:
|
220 |
+
reference_mask = torch.ones(self.ref_pmask.shape).cuda()
|
221 |
+
else:
|
222 |
+
reference_mask, _ = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
|
223 |
+
|
224 |
+
edit_mask, _ = self._process_mask(self.input_mask, self.input_pmask, self.input_segmask)
|
225 |
+
# tmp = cv2.dilate(edit_mask.squeeze().cpu().numpy(), self.kernel, iterations = 2)
|
226 |
+
# region_mask = torch.tensor(tmp).cuda()
|
227 |
+
region_mask = edit_mask
|
228 |
+
ma = torch.max(input_pmask)
|
229 |
+
|
230 |
+
input_pmask[edit_mask == 1] = ma + 1
|
231 |
+
return reference_mask, input_pmask, input_segmask, region_mask, ma
|
232 |
+
|
233 |
+
def _add_object(self, input_mask, dilation_fac):
|
234 |
+
"""
|
235 |
+
Manipulates the panooptic mask of input image for adding objects
|
236 |
+
Args:
|
237 |
+
input_mask (numpy array): Region where new objects needs to be added
|
238 |
+
dilation factor (float): Controls edge merging region for adding objects
|
239 |
+
|
240 |
+
"""
|
241 |
+
input_pmask = self.input_pmask
|
242 |
+
input_segmask = self.input_segmask
|
243 |
+
reference_mask, obj_class = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
|
244 |
+
|
245 |
+
tmp = cv2.dilate(input_mask['mask'][:, :, 0], self.kernel, iterations = int(dilation_fac))
|
246 |
+
region = torch.tensor(tmp)
|
247 |
+
region_mask = torch.zeros_like(region).cuda()
|
248 |
+
region_mask[region > 127] = 1
|
249 |
+
|
250 |
+
mask_ = torch.tensor(input_mask['mask'][:, :, 0])
|
251 |
+
edit_mask = torch.zeros_like(mask_).cuda()
|
252 |
+
edit_mask[mask_ > 127] = 1
|
253 |
+
ma = torch.max(input_pmask)
|
254 |
+
input_pmask[edit_mask == 1] = ma + 1
|
255 |
+
print(obj_class)
|
256 |
+
input_segmask[edit_mask == 1] = obj_class.long()
|
257 |
+
|
258 |
+
return reference_mask, input_pmask, input_segmask, region_mask, ma
|
259 |
+
|
260 |
+
def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
|
261 |
+
"""
|
262 |
+
Entry point for all the appearance editing and add objects operations. The function manipulates the
|
263 |
+
appearance vectors and structure based on user input
|
264 |
+
Args:
|
265 |
+
input mask (numpy array): Region in input image which needs to be edited
|
266 |
+
dilation factor (float): Controls edge merging region for adding objects
|
267 |
+
whole_ref (bool): Flag for specifying if complete reference image should be used
|
268 |
+
inter (float): Interpolation of appearance between the reference appearance and the input appearance.
|
269 |
+
"""
|
270 |
+
input_img = (self.input_img/127.5 - 1)
|
271 |
+
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
272 |
+
|
273 |
+
reference_img = (self.ref_img/127.5 - 1)
|
274 |
+
reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
275 |
+
|
276 |
+
if self.edit_operation == 'add_obj':
|
277 |
+
reference_mask, input_pmask, input_segmask, region_mask, ma = self._add_object(input_mask, dilation_fac)
|
278 |
+
elif self.edit_operation == 'edit_app':
|
279 |
+
reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(whole_ref)
|
280 |
+
|
281 |
+
#concat featurees
|
282 |
+
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
283 |
+
_, mean_feat_inpt_conc, one_hot_inpt_conc, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
|
284 |
+
|
285 |
+
reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
|
286 |
+
_, mean_feat_ref_conc, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, reference_img, reference_mask, return_all=True)
|
287 |
+
|
288 |
+
# if mean_feat_ref.shape[1] > 1:
|
289 |
+
if isinstance(mean_feat_inpt_conc, list):
|
290 |
+
appearance_conc = []
|
291 |
+
for i in range(len(mean_feat_inpt_conc)):
|
292 |
+
mean_feat_inpt_conc[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[i][:, ma + 1] + inter*mean_feat_ref_conc[i][:, 1]
|
293 |
+
splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc[i], one_hot_inpt_conc)
|
294 |
+
splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
|
295 |
+
splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H//8, self.W//8))
|
296 |
+
appearance_conc.append(splatted_feat_conc)
|
297 |
+
appearance_conc = torch.cat(appearance_conc, dim=1)
|
298 |
+
else:
|
299 |
+
print("manipulating")
|
300 |
+
mean_feat_inpt_conc[:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[:, ma + 1] + inter*mean_feat_ref_conc[:, 1]
|
301 |
+
|
302 |
+
splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc, one_hot_inpt_conc)
|
303 |
+
appearance_conc = torch.nn.functional.normalize(splatted_feat_conc) #l2 normaliz
|
304 |
+
appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
|
305 |
+
|
306 |
+
#cross attention features
|
307 |
+
_, mean_feat_inpt_ca, one_hot_inpt_ca, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, input_pmask, return_all=True)
|
308 |
+
|
309 |
+
_, mean_feat_ref_ca, _, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, reference_img, reference_mask, return_all=True)
|
310 |
+
|
311 |
+
# if mean_feat_ref.shape[1] > 1:
|
312 |
+
if isinstance(mean_feat_inpt_ca, list):
|
313 |
+
appearance_ca = []
|
314 |
+
for i in range(len(mean_feat_inpt_ca)):
|
315 |
+
mean_feat_inpt_ca[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[i][:, ma + 1] + inter*mean_feat_ref_ca[i][:, 1]
|
316 |
+
splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca[i], one_hot_inpt_ca)
|
317 |
+
splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
|
318 |
+
splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H//8, self.W//8))
|
319 |
+
appearance_ca.append(splatted_feat_ca)
|
320 |
+
else:
|
321 |
+
print("manipulating")
|
322 |
+
mean_feat_inpt_ca[:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[:, ma + 1] + inter*mean_feat_ref_ca[:, 1]
|
323 |
+
|
324 |
+
splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca, one_hot_inpt_ca)
|
325 |
+
appearance_ca = torch.nn.functional.normalize(splatted_feat_ca) #l2 normaliz
|
326 |
+
appearance_ca = torch.nn.functional.interpolate(appearance_ca, (self.H//8, self.W//8))
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
|
331 |
+
structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
|
332 |
+
|
333 |
+
|
334 |
+
return structure, appearance_conc, appearance_ca, region_mask, input_img
|
335 |
+
|
336 |
+
def _edit_obj_var(self, input_mask, ignore_structure):
|
337 |
+
input_img = (self.input_img/127.5 - 1)
|
338 |
+
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
339 |
+
|
340 |
+
|
341 |
+
input_pmask = self.input_pmask
|
342 |
+
input_segmask = self.input_segmask
|
343 |
+
|
344 |
+
ma = torch.max(input_pmask)
|
345 |
+
mask_ = torch.tensor(input_mask['mask'][:, :, 0])
|
346 |
+
edit_mask = torch.zeros_like(mask_).cuda()
|
347 |
+
edit_mask[mask_ > 127] = 1
|
348 |
+
tmp = edit_mask * (input_pmask + ma + 1)
|
349 |
+
if ignore_structure:
|
350 |
+
tmp = edit_mask
|
351 |
+
|
352 |
+
input_pmask = tmp * edit_mask + (1 - edit_mask) * input_pmask
|
353 |
+
|
354 |
+
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
355 |
+
|
356 |
+
mask_ca_feat = self.input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) if ignore_structure else input_pmask
|
357 |
+
print(torch.unique(mask_ca_feat))
|
358 |
+
|
359 |
+
appearance_conc,_,_,_ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
|
360 |
+
appearance_ca = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, mask_ca_feat)
|
361 |
+
|
362 |
+
appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
|
363 |
+
appearance_ca = [torch.nn.functional.interpolate(ap, (self.H//8, self.W//8)) for ap in appearance_ca]
|
364 |
+
|
365 |
+
input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
|
366 |
+
structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
|
367 |
+
|
368 |
+
|
369 |
+
tmp = input_mask['mask'][:, :, 0]
|
370 |
+
region = torch.tensor(tmp)
|
371 |
+
mask = torch.zeros_like(region).cuda()
|
372 |
+
mask[region > 127] = 1
|
373 |
+
|
374 |
+
return structure, appearance_conc, appearance_ca, mask, input_img
|
375 |
+
|
376 |
+
def get_caption(self, mask):
|
377 |
+
"""
|
378 |
+
Generates the captions based on a set template
|
379 |
+
Args:
|
380 |
+
mask (numpy array): Region of image based on which caption needs to be generated
|
381 |
+
"""
|
382 |
+
mask = mask['mask'][:, :, 0]
|
383 |
+
region = torch.tensor(mask).cuda()
|
384 |
+
mask = torch.zeros_like(region)
|
385 |
+
mask[region > 127] = 1
|
386 |
+
|
387 |
+
if torch.sum(mask) == 0:
|
388 |
+
return ""
|
389 |
+
|
390 |
+
c_ids = self.input_segmask * mask
|
391 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
392 |
+
c_id = int(unique_ids[torch.argmax(counts[1:]) + 1].cpu().detach().numpy())
|
393 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
394 |
+
|
395 |
+
return self.base_prompt.format(category)
|
396 |
+
|
397 |
+
def save_result(self, input_mask, prompt, a_prompt, n_prompt,
|
398 |
+
ddim_steps, scale_s, scale_f, scale_t, seed, dilation_fac=1,inter=1,
|
399 |
+
free_form_obj_var=False, ignore_structure=False):
|
400 |
+
"""
|
401 |
+
Saves the current results with all the meta data
|
402 |
+
"""
|
403 |
+
|
404 |
+
meta_data = {}
|
405 |
+
meta_data['prompt'] = prompt
|
406 |
+
meta_data['a_prompt'] = a_prompt
|
407 |
+
meta_data['n_prompt'] = n_prompt
|
408 |
+
meta_data['seed'] = seed
|
409 |
+
meta_data['ddim_steps'] = ddim_steps
|
410 |
+
meta_data['scale_s'] = scale_s
|
411 |
+
meta_data['scale_f'] = scale_f
|
412 |
+
meta_data['scale_t'] = scale_t
|
413 |
+
meta_data['inter'] = inter
|
414 |
+
meta_data['dilation_fac'] = dilation_fac
|
415 |
+
meta_data['edit_operation'] = self.edit_operation
|
416 |
+
|
417 |
+
uuid = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
418 |
+
os.makedirs(f'{save_dir}/{uuid}')
|
419 |
+
|
420 |
+
with open(f'{save_dir}/{uuid}/meta.json', "w") as outfile:
|
421 |
+
json.dump(meta_data, outfile)
|
422 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input.png', self.input_img[:,:,::-1])
|
423 |
+
cv2.imwrite(f'{save_dir}/{uuid}/ref.png', self.ref_img[:,:,::-1])
|
424 |
+
if self.ref_mask is not None:
|
425 |
+
cv2.imwrite(f'{save_dir}/{uuid}/ref_mask.png', self.ref_mask.cpu().squeeze().numpy() * 200)
|
426 |
+
for i in range(len(self.results)):
|
427 |
+
cv2.imwrite(f'{save_dir}/{uuid}/edit{i}.png', self.results[i][:,:,::-1])
|
428 |
+
|
429 |
+
if self.edit_operation == 'add_obj' or free_form_obj_var:
|
430 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', input_mask['mask'] * 200)
|
431 |
+
else:
|
432 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', self.input_mask.cpu().squeeze().numpy() * 200)
|
433 |
+
|
434 |
+
print("Saved results at", f'{save_dir}/{uuid}')
|
435 |
+
|
436 |
+
|
437 |
+
def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
|
438 |
+
num_samples, ddim_steps, guess_mode, strength,
|
439 |
+
scale_s, scale_f, scale_t, seed, eta, dilation_fac=1,masking=True,whole_ref=False,inter=1,
|
440 |
+
free_form_obj_var=False, ignore_structure=False):
|
441 |
+
|
442 |
+
print(prompt)
|
443 |
+
if free_form_obj_var:
|
444 |
+
print("Free form")
|
445 |
+
structure, appearance_conc, appearance_ca, mask, img = self._edit_obj_var(input_mask, ignore_structure)
|
446 |
+
else:
|
447 |
+
structure, appearance_conc, appearance_ca, mask, img = self._edit(input_mask, ref_mask, dilation_fac=dilation_fac,
|
448 |
+
whole_ref=whole_ref, inter=inter)
|
449 |
+
|
450 |
+
input_pmask = torch.nn.functional.interpolate(self.input_pmask.cuda().unsqueeze(0).unsqueeze(1).float(), (self.H//8, self.W//8))
|
451 |
+
input_pmask = input_pmask.to(memory_format=torch.contiguous_format)
|
452 |
+
|
453 |
+
|
454 |
+
if isinstance(appearance_ca, list):
|
455 |
+
null_appearance_ca = [torch.zeros(a.shape).cuda() for a in appearance_ca]
|
456 |
+
null_appearance_conc = torch.zeros(appearance_conc.shape).cuda()
|
457 |
+
null_structure = torch.zeros(structure.shape).cuda() - 1
|
458 |
+
|
459 |
+
null_control = [torch.cat([null_structure, napp, input_pmask * 0], dim=1) for napp in null_appearance_ca]
|
460 |
+
structure_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in null_appearance_ca]
|
461 |
+
full_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in appearance_ca]
|
462 |
+
|
463 |
+
null_control.append(torch.cat([null_structure, null_appearance_conc, null_structure * 0], dim=1))
|
464 |
+
structure_control.append(torch.cat([structure, null_appearance_conc, null_structure], dim=1))
|
465 |
+
full_control.append(torch.cat([structure, appearance_conc, input_pmask], dim=1))
|
466 |
+
|
467 |
+
null_control = [torch.cat([nc for _ in range(num_samples)], dim=0) for nc in null_control]
|
468 |
+
structure_control = [torch.cat([sc for _ in range(num_samples)], dim=0) for sc in structure_control]
|
469 |
+
full_control = [torch.cat([fc for _ in range(num_samples)], dim=0) for fc in full_control]
|
470 |
+
|
471 |
+
#Masking for local edit
|
472 |
+
if not masking:
|
473 |
+
mask, x0 = None, None
|
474 |
+
else:
|
475 |
+
x0 = model.encode_first_stage(img)
|
476 |
+
x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
|
477 |
+
x0 = x0 * model.scale_factor
|
478 |
+
mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
|
479 |
+
mask = torch.nn.functional.interpolate(mask.float(), x0.shape[2:]).float()
|
480 |
+
|
481 |
+
if seed == -1:
|
482 |
+
seed = random.randint(0, 65535)
|
483 |
+
seed_everything(seed)
|
484 |
+
|
485 |
+
scale = [scale_s, scale_f, scale_t]
|
486 |
+
print(scale)
|
487 |
+
if save_memory:
|
488 |
+
model.low_vram_shift(is_diffusing=False)
|
489 |
+
|
490 |
+
uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
|
491 |
+
c_cross = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)
|
492 |
+
cond = {"c_concat": [null_control], "c_crossattn": [c_cross]}
|
493 |
+
un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
|
494 |
+
un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
|
495 |
+
un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
|
496 |
+
|
497 |
+
shape = (4, self.H // 8, self.W // 8)
|
498 |
+
|
499 |
+
if save_memory:
|
500 |
+
model.low_vram_shift(is_diffusing=True)
|
501 |
+
|
502 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
503 |
+
samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
|
504 |
+
shape, cond, verbose=False, eta=eta,
|
505 |
+
unconditional_guidance_scale=scale, mask=mask, x0=x0,
|
506 |
+
unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app ])
|
507 |
+
|
508 |
+
if save_memory:
|
509 |
+
model.low_vram_shift(is_diffusing=False)
|
510 |
+
|
511 |
+
x_samples = (model.decode_first_stage(samples) + 1) * 127.5
|
512 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
|
513 |
+
|
514 |
+
results = [x_samples[i] for i in range(num_samples)]
|
515 |
+
self.results = results
|
516 |
+
return [] + results
|
requirements.txt
CHANGED
@@ -9,6 +9,7 @@ omegaconf==2.3.0
|
|
9 |
open-clip-torch==2.0.2
|
10 |
opencv-contrib-python==4.3.0.36
|
11 |
opencv-python-headless==4.7.0.72
|
|
|
12 |
prettytable==3.6.0
|
13 |
pytorch-lightning==1.5.0
|
14 |
safetensors==0.2.7
|
@@ -44,4 +45,4 @@ diffdist
|
|
44 |
gdown
|
45 |
huggingface_hub
|
46 |
tqdm
|
47 |
-
wget
|
|
|
9 |
open-clip-torch==2.0.2
|
10 |
opencv-contrib-python==4.3.0.36
|
11 |
opencv-python-headless==4.7.0.72
|
12 |
+
pillow==9.4.0
|
13 |
prettytable==3.6.0
|
14 |
pytorch-lightning==1.5.0
|
15 |
safetensors==0.2.7
|
|
|
45 |
gdown
|
46 |
huggingface_hub
|
47 |
tqdm
|
48 |
+
wget
|