silentchen committed
Commit 09ce3f4
1 Parent(s): 3ab28ab

Upload 13 files
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
  import torch
3
  from omegaconf import OmegaConf
4
- # from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
5
-
 
 
6
  import json
7
  import numpy as np
8
  from PIL import Image, ImageDraw, ImageFont
@@ -10,7 +12,7 @@ from functools import partial
10
  from collections import Counter
11
  import math
12
  import gc
13
-
14
  from gradio import processing_utils
15
  from typing import Optional
16
 
@@ -184,77 +186,6 @@ class Blocks(gr.Blocks):
184
  '''
185
  inference model
186
  '''
187
-
188
-
189
- @torch.no_grad()
190
- def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image,
191
- alpha_sample, guidance_scale, batch_size,
192
- fix_seed, rand_seed, actual_mask, style_image,
193
- *args, **kwargs):
194
- grounding_instruction = json.loads(grounding_instruction)
195
- phrase_list, location_list = [], []
196
- for k, v in grounding_instruction.items():
197
- phrase_list.append(k)
198
- location_list.append(v)
199
-
200
- placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
201
- image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
202
-
203
- batch_size = int(batch_size)
204
- if not 1 <= batch_size <= 4:
205
- batch_size = 2
206
-
207
- if style_image == None:
208
- has_text_mask = 1
209
- has_image_mask = 0 # then we hack above 'image_list'
210
- else:
211
- valid_phrase_len = len(phrase_list)
212
-
213
- phrase_list += ['placeholder']
214
- has_text_mask = [1] * valid_phrase_len + [0]
215
-
216
- image_list = [placeholder_image] * valid_phrase_len + [style_image]
217
- has_image_mask = [0] * valid_phrase_len + [1]
218
-
219
- location_list += [[0.0, 0.0, 1, 0.01]] # style image grounding location
220
-
221
- if task == 'Grounded Inpainting':
222
- alpha_sample = 1.0
223
-
224
- instruction = dict(
225
- prompt=language_instruction,
226
- phrases=phrase_list,
227
- images=image_list,
228
- locations=location_list,
229
- alpha_type=[alpha_sample, 0, 1.0 - alpha_sample],
230
- has_text_mask=has_text_mask,
231
- has_image_mask=has_image_mask,
232
- save_folder_name=language_instruction,
233
- guidance_scale=guidance_scale,
234
- batch_size=batch_size,
235
- fix_seed=bool(fix_seed),
236
- rand_seed=int(rand_seed),
237
- actual_mask=actual_mask,
238
- inpainting_boxes_nodrop=inpainting_boxes_nodrop,
239
- )
240
-
241
- get_model = partial(instance.get_model,
242
- batch_size=batch_size,
243
- instruction=language_instruction,
244
- phrase_list=phrase_list)
245
-
246
- with torch.autocast(device_type='cuda', dtype=torch.float16):
247
- if task == 'Grounded Generation':
248
- if style_image == None:
249
- return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
250
- else:
251
- return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
252
- elif task == 'Grounded Inpainting':
253
- assert image is not None
254
- instruction['input_image'] = image.convert("RGB")
255
- return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
256
-
257
-
258
  def draw_box(boxes=[], texts=[], img=None):
259
  if len(boxes) == 0 and img is None:
260
  return None
@@ -275,6 +206,106 @@ def draw_box(boxes=[], texts=[], img=None):
275
  fill=(255, 255, 255))
276
  return img
277
 
278
 
279
  def get_concat(ims):
280
  if len(ims) == 1:
@@ -297,13 +328,11 @@ def auto_append_grounding(language_instruction, grounding_texts):
297
  return language_instruction
298
 
299
 
300
- def generate(task, language_instruction, grounding_texts, sketch_pad,
301
- alpha_sample, guidance_scale, batch_size,
302
- fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
303
  state):
304
  if 'boxes' not in state:
305
  state['boxes'] = []
306
-
307
  boxes = state['boxes']
308
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
309
  # assert len(boxes) == len(grounding_texts)
@@ -315,44 +344,19 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun
315
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
316
 
317
  boxes = (np.asarray(boxes) / 512).tolist()
 
318
  grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
319
-
320
- image = None
321
- actual_mask = None
322
- if task == 'Grounded Inpainting':
323
- image = state.get('original_image', sketch_pad['image']).copy()
324
- image = center_crop(image)
325
- image = Image.fromarray(image)
326
-
327
- if use_actual_mask:
328
- actual_mask = sketch_pad['mask'].copy()
329
- if actual_mask.ndim == 3:
330
- actual_mask = actual_mask[..., 0]
331
- actual_mask = center_crop(actual_mask, tgt_size=(64, 64))
332
- actual_mask = torch.from_numpy(actual_mask == 0).float()
333
-
334
- if state.get('inpaint_hw', None):
335
- boxes = np.asarray(boxes) * 0.9 + 0.05
336
- boxes = boxes.tolist()
337
- grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes) if obj != 'auto'})
338
-
339
- if append_grounding:
340
- language_instruction = auto_append_grounding(language_instruction, grounding_texts)
341
-
342
- gen_images, gen_overlays = inference(
343
- task, language_instruction, grounding_instruction, boxes, image,
344
- alpha_sample, guidance_scale, batch_size,
345
- fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
346
- )
347
-
348
- for idx, gen_image in enumerate(gen_images):
349
-
350
- if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
351
- hw = min(*state['original_image'].shape[:2])
352
- gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
353
- gen_image = Image.fromarray(gen_image)
354
-
355
- gen_images[idx] = gen_image
356
 
357
  blank_samples = batch_size % 2 if batch_size > 1 else 0
358
  gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
@@ -401,35 +405,18 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):
401
  return np.array(img)
402
 
403
 
404
- def draw(task, input, grounding_texts, new_image_trigger, state):
 
405
  if type(input) == dict:
406
  image = input['image']
407
  mask = input['mask']
408
  else:
409
  mask = input
410
-
411
  if mask.ndim == 3:
412
  mask = mask[..., 0]
413
 
414
  image_scale = 1.0
415
 
416
- # resize trigger
417
- if task == "Grounded Inpainting":
418
- mask_cond = mask.sum() == 0
419
- # size_cond = mask.shape != (512, 512)
420
- if mask_cond and 'original_image' not in state:
421
- image = Image.fromarray(image)
422
- width, height = image.size
423
- scale = 600 / min(width, height)
424
- image = image.resize((int(width * scale), int(height * scale)))
425
- state['original_image'] = np.array(image).copy()
426
- image_scale = float(height / width)
427
- return [None, new_image_trigger + 1, image_scale, state]
428
- else:
429
- original_image = state['original_image']
430
- H, W = original_image.shape[:2]
431
- image_scale = float(H / W)
432
-
433
  mask = binarize(mask)
434
  if mask.shape != (512, 512):
435
  # assert False, "should not receive any non- 512x512 masks."
@@ -444,13 +431,10 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
444
  if type(mask) != np.ndarray:
445
  mask = np.array(mask)
446
 
447
- if mask.sum() == 0 and task != "Grounded Inpainting":
448
  state = {}
449
 
450
- if task != 'Grounded Inpainting':
451
- image = None
452
- else:
453
- image = Image.fromarray(image)
454
 
455
  if 'boxes' not in state:
456
  state['boxes'] = []
@@ -488,7 +472,6 @@ def draw(task, input, grounding_texts, new_image_trigger, state):
488
  box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
489
  original_image = state['original_image'].copy()
490
  box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
491
- print(box_image, new_image_trigger, image_scale, state)
492
  return [box_image, new_image_trigger, image_scale, state]
493
 
494
 
@@ -518,6 +501,37 @@ css = """
518
  cursor: pointer;
519
  text-decoration: none;
520
521
  """
522
 
523
  rescale_js = """
@@ -536,71 +550,84 @@ function(x) {
536
  with Blocks(
537
  css=css,
538
  analytics_enabled=False,
539
- title="GLIGen demo",
540
  ) as main:
541
  description = """<p style="text-align: center; font-weight: bold;">
542
  <span style="font-size: 28px">Layout Guidance</span>
543
  <br>
544
  <span style="font-size: 18px" id="paper-info">
545
- [<a href="https://gligen.github.io" target="_blank">Project Page</a>]
546
- [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
547
- [<a href="https://github.com/gligen/GLIGEN" target="_blank">GitHub</a>]
548
- [<a href="https://huggingface.co/spaces/gligen/demo_legacy" target="_blank">Mirror</a>]
549
  </span>
550
  </p>
551
  """
552
  gr.HTML(description)
553
-
554
- with gr.Row():
555
- with gr.Column(scale=4):
556
- sketch_pad_trigger = gr.Number(value=0, visible=False)
557
- sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
558
- init_white_trigger = gr.Number(value=0, visible=False)
559
- image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
560
- new_image_trigger = gr.Number(value=0, visible=False)
561
-
562
- # task = gr.Radio(
563
- # choices=["Grounded Generation", 'Grounded Inpainting'],
564
- # type="value",
565
- # value="Grounded Generation",
566
- # label="Task",
567
- # )
568
- language_instruction = gr.Textbox(
569
- label="Text Caption",
570
- )
571
- grounding_instruction = gr.Textbox(
572
- label="Grounding instruction (Separated by semicolon)",
573
- )
574
- with gr.Row():
575
- sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
576
- out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
577
- with gr.Row():
578
- clear_btn = gr.Button(value='Clear')
579
- gen_btn = gr.Button(value='Generate')
580
- with gr.Accordion("Advanced Options", open=False):
581
- with gr.Column():
582
- Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,
583
- label="Loss Scale Factor")
584
- guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
585
- batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
586
- max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
587
- loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
588
- max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
589
-
590
- # append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
591
- # use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
592
- with gr.Row():
593
- fix_seed = gr.Checkbox(value=True, label="Fixed seed")
594
- rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
595
-
596
- with gr.Column(scale=4):
597
- gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
598
- with gr.Row():
599
- out_gen_1 = gr.Image(type="pil", visible=True, show_label=False, label="Generated Image")
600
- out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
601
- with gr.Row():
602
- out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
603
- out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
604
 
605
  state = gr.State({})
606
 
@@ -658,28 +685,22 @@ with Blocks(
658
  queue=False)
659
  sketch_pad.edit(
660
  draw,
661
- inputs=[sketch_pad, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
662
  outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
663
  queue=False,
664
  )
665
  grounding_instruction.change(
666
  draw,
667
- inputs=[sketch_pad, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
668
  outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
669
  queue=False,
670
  )
671
  clear_btn.click(
672
  clear,
673
  inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
674
- outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3,
675
- out_gen_4, state],
676
  queue=False)
677
- # task.change(
678
- # partial(clear, switch_task=True),
679
- # inputs=[task, sketch_pad_trigger, batch_size, state],
680
- # outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3,
681
- # out_gen_4, state],
682
- # queue=False)
683
  sketch_pad_trigger.change(
684
  controller.init_white,
685
  inputs=[init_white_trigger],
@@ -690,29 +711,28 @@ with Blocks(
690
  inputs=[state],
691
  outputs=[sketch_pad, state],
692
  queue=False)
693
- batch_size.change(
694
- controller.change_n_samples,
695
- inputs=[batch_size],
696
- outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
697
- queue=False)
698
 
699
- batch_size.change(
700
- controller.change_n_samples,
701
- inputs=[batch_size],
702
- outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
703
- queue=False)
704
 
705
  gen_btn.click(
706
  generate,
707
  inputs=[
708
- language_instruction, language_instruction, grounding_instruction, sketch_pad,
709
- loss_threshold, guidance_scale, batch_size,
710
- fix_seed, rand_seed,
711
  max_step,
712
  Loss_scale, max_iter,
713
  state,
714
  ],
715
- outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
716
  queue=True
717
  )
718
  sketch_pad_resize_trigger.change(
@@ -732,13 +752,13 @@ with Blocks(
732
  gr.Examples(
733
  examples=[
734
  [
735
- "images/input.png",
736
  "A hello kitty toy is playing with a purple ball.",
737
  "hello kitty;ball",
738
  "images/hello_kitty_results.png"
739
  ],
740
  ],
741
- inputs=[sketch_pad, language_instruction, grounding_instruction, out_gen_1],
742
  outputs=None,
743
  fn=None,
744
  cache_examples=False,
@@ -746,3 +766,4 @@ with Blocks(
746
 
747
  main.queue(concurrency_count=1, api_open=False)
748
  main.launch(share=False, show_api=False, show_error=True)
 
 
1
  import gradio as gr
2
  import torch
3
  from omegaconf import OmegaConf
4
+ # from layout_guidance.inference import inference
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler
7
+ from my_model import unet_2d_condition
8
  import json
9
  import numpy as np
10
  from PIL import Image, ImageDraw, ImageFont
 
12
  from collections import Counter
13
  import math
14
  import gc
15
+ from utils import compute_ca_loss
16
  from gradio import processing_utils
17
  from typing import Optional
18
 
 
186
  '''
187
  inference model
188
  '''
 
189
  def draw_box(boxes=[], texts=[], img=None):
190
  if len(boxes) == 0 and img is None:
191
  return None
 
206
  fill=(255, 255, 255))
207
  return img
208
 
209
+ with open('./conf/unet/config.json') as f:
210
+ unet_config = json.load(f)
211
+
212
+ unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
213
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
214
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
215
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
216
+ attn_map = None
217
+ cfg = OmegaConf.load('./conf/net_conf.yaml')
218
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
219
+ unet.to(device)
220
+ text_encoder.to(device)
221
+ vae.to(device)
222
+ def inference(device, unet, vae, tokenizer, text_encoder, prompt, cfg, attn_map, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
223
+ uncond_input = tokenizer(
224
+ [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
225
+ )
226
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
227
+
228
+ input_ids = tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ truncation=True,
232
+ max_length=tokenizer.model_max_length,
233
+ return_tensors="pt",
234
+ ).input_ids[0].unsqueeze(0).to(device)
235
+ # text_embeddings = text_encoder(input_ids)[0]
236
+ text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
237
+ # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
238
+ generator = torch.manual_seed(rand_seed) # Seed generator to create the initial latent noise
239
+
240
+ latents = torch.randn(
241
+ (batch_size, 4, 64, 64),
242
+ generator=generator,
243
+ ).to(device)
244
+
245
+ noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
246
+
247
+ # generator = torch.Generator("cuda").manual_seed(1024)
248
+ noise_scheduler.set_timesteps(51)
249
+
250
+ latents = latents * noise_scheduler.init_noise_sigma
251
+
252
+ loss = torch.tensor(10000)
253
+
254
+ for index, t in enumerate(noise_scheduler.timesteps):
255
+ iteration = 0
256
+
257
+ while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
258
+ latents = latents.requires_grad_(True)
259
+
260
+ # latent_model_input = torch.cat([latents] * 2)
261
+ latent_model_input = latents
262
+
263
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
264
+ noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down, _, _, _ = \
265
+ unet(latent_model_input, t, index, encoder_hidden_states=text_encoder(input_ids)[0], attn_map=attn_map,
266
+ cfg=cfg)
267
+
268
+ # update latents with backward guidance from the cross-attention layout loss
269
+
270
+ loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
271
+ object_positions=object_positions) * loss_scale
272
+
273
+ print(loss.item() / loss_scale)
274
+
275
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
276
+
277
+ latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
278
+ iteration += 1
279
+ torch.cuda.empty_cache()
280
+ torch.cuda.empty_cache()
281
+
282
+
283
+ with torch.no_grad():
284
+
285
+ latent_model_input = torch.cat([latents] * 2)
286
+
287
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
288
+ noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down, _, _, _ = \
289
+ unet(latent_model_input, t, index, encoder_hidden_states=text_embeddings, attn_map=attn_map, cfg=cfg)
290
+
291
+ noise_pred = noise_pred.sample
292
+
293
+ # perform guidance
294
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
295
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
296
+
297
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
298
+ torch.cuda.empty_cache()
299
+
300
+ with torch.no_grad():
301
+ print("decode image")
302
+ latents = 1 / 0.18215 * latents
303
+ image = vae.decode(latents).sample
304
+ image = (image / 2 + 0.5).clamp(0, 1)
305
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
306
+ images = (image * 255).round().astype("uint8")
307
+ pil_images = [Image.fromarray(image) for image in images]
308
+ return pil_images
309
 
310
  def get_concat(ims):
311
  if len(ims) == 1:
 
328
  return language_instruction
329
 
330
 
331
+ def generate(language_instruction, grounding_texts, sketch_pad,
332
+ loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
 
333
  state):
334
  if 'boxes' not in state:
335
  state['boxes'] = []
 
336
  boxes = state['boxes']
337
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
338
  # assert len(boxes) == len(grounding_texts)
 
344
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
345
 
346
  boxes = (np.asarray(boxes) / 512).tolist()
347
+ boxes = [[box] for box in boxes]
348
  grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
349
+ language_instruction_list = language_instruction.strip('.').split(' ')
350
+ object_positions = []
351
+ for obj in grounding_texts:
352
+ obj_position = []
353
+ for word in obj.split(' '):
354
+ obj_first_index = language_instruction_list.index(word) + 1
355
+ obj_position.append(obj_first_index)
356
+ object_positions.append(obj_position)
357
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
358
+
359
+ gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, cfg, attn_map, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
 
360
 
361
  blank_samples = batch_size % 2 if batch_size > 1 else 0
362
  gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
 
405
  return np.array(img)
406
 
407
 
408
+ def draw(input, grounding_texts, new_image_trigger, state):
409
+
410
  if type(input) == dict:
411
  image = input['image']
412
  mask = input['mask']
413
  else:
414
  mask = input
 
415
  if mask.ndim == 3:
416
  mask = mask[..., 0]
417
 
418
  image_scale = 1.0
419
 
420
  mask = binarize(mask)
421
  if mask.shape != (512, 512):
422
  # assert False, "should not receive any non- 512x512 masks."
 
431
  if type(mask) != np.ndarray:
432
  mask = np.array(mask)
433
 
434
+ if mask.sum() == 0:
435
  state = {}
436
 
437
+ image = None
438
 
439
  if 'boxes' not in state:
440
  state['boxes'] = []
 
472
  box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
473
  original_image = state['original_image'].copy()
474
  box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
 
475
  return [box_image, new_image_trigger, image_scale, state]
476
 
477
 
 
501
  cursor: pointer;
502
  text-decoration: none;
503
  }
504
+
505
+ .tooltip {
506
+ color: #555;
507
+ position: relative;
508
+ display: inline-block;
509
+ cursor: pointer;
510
+ }
511
+
512
+ .tooltip .tooltiptext {
513
+ visibility: hidden;
514
+ width: 400px;
515
+ background-color: #555;
516
+ color: #fff;
517
+ text-align: center;
518
+ padding: 5px;
519
+ border-radius: 5px;
520
+ position: absolute;
521
+ z-index: 1; /* Set z-index to 1 */
522
+ left: 10px;
523
+ top: 100%;
524
+ opacity: 0;
525
+ transition: opacity 0.3s;
526
+ }
527
+
528
+ .tooltip:hover .tooltiptext {
529
+ visibility: visible;
530
+ opacity: 1;
531
+ z-index: 9999; /* Set a high z-index value when hovering */
532
+ }
533
+
534
+
535
  """
536
 
537
  rescale_js = """
 
550
  with Blocks(
551
  css=css,
552
  analytics_enabled=False,
553
+ title="Layout-Guidance demo",
554
  ) as main:
555
  description = """<p style="text-align: center; font-weight: bold;">
556
  <span style="font-size: 28px">Layout Guidance</span>
557
  <br>
558
  <span style="font-size: 18px" id="paper-info">
559
+ [<a href=" " target="_blank">Project Page</a>]
560
+ [<a href=" " target="_blank">Paper</a>]
561
+ [<a href=" " target="_blank">GitHub</a>]
 
562
  </span>
563
  </p>
564
  """
565
  gr.HTML(description)
566
+ with gr.Column():
567
+ language_instruction = gr.Textbox(
568
+ label="Text Prompt",
569
+ )
570
+ grounding_instruction = gr.Textbox(
571
+ label="Grounding instruction (Separated by semicolon)",
572
+ )
573
+ sketch_pad_trigger = gr.Number(value=0, visible=False)
574
+ sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
575
+ init_white_trigger = gr.Number(value=0, visible=False)
576
+ image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
577
+ new_image_trigger = gr.Number(value=0, visible=False)
578
+
579
+
580
+
581
+ with gr.Row():
582
+ sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
583
+ out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
584
+ out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
585
+ # out_gen_2 = gr.Image(type="pil", visible=True, label="Generated Image")
586
+ # out_gen_3 = gr.Image(type="pil", visible=True, show_label=False)
587
+ # out_gen_4 = gr.Image(type="pil", visible=True, show_label=False)
588
+
589
+ with gr.Row():
590
+ clear_btn = gr.Button(value='Clear')
591
+ gen_btn = gr.Button(value='Generate')
592
+ # clear_btn = gr.Button(value='Clear')
593
+ # clear_btn = gr.Button(value='Clear')
594
+
595
+ with gr.Accordion("Advanced Options", open=False):
596
+ with gr.Column():
597
+ description = """<div class="tooltip">Loss Scale Factor &#9432
598
+ <span class="tooltiptext">The scale factor of the backward-guidance loss. Larger values give stronger layout control but can reduce image fidelity.</span>
599
+ </div>
600
+ <div class="tooltip">Guidance Scale &#9432
601
+ <span class="tooltiptext">The scale factor of classifier-free guidance. </span>
602
+ </div>
603
+ <div class="tooltip" >Max Iteration per Step &#9432
604
+ <span class="tooltiptext">The maximum number of backward-guidance iterations at each diffusion step.</span>
605
+ </div>
606
+ <div class="tooltip" >Loss Threshold &#9432
607
+ <span class="tooltiptext">The loss threshold. If the loss computed from the cross-attention maps falls below this threshold, backward guidance stops.</span>
608
+ </div>
609
+ <div class="tooltip" >Max Step of Backward Guidance &#9432
610
+ <span class="tooltiptext">The maximum number of diffusion steps during which backward guidance is applied.</span>
611
+ </div>
612
+ """
613
+ gr.HTML(description)
614
+ Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,label="Loss Scale Factor")
615
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
616
+ batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
617
+ max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
618
+ loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
619
+ max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
620
+ # fix_seed = gr.Checkbox(value=True, label="Fixed seed")
621
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
622
+
623
+ # with gr.Column(scale=4):
624
+ # gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
625
+ # with gr.Row():
626
+ # out_gen_1 = gr.Image(type="pil", visible=True, show_label=False, label="Generated Image")
627
+ # out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
628
+ # with gr.Row():
629
+ # out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
630
+ # out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
631
 
632
  state = gr.State({})
633
 
 
685
  queue=False)
686
  sketch_pad.edit(
687
  draw,
688
+ inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
689
  outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
690
  queue=False,
691
  )
692
  grounding_instruction.change(
693
  draw,
694
+ inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
695
  outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
696
  queue=False,
697
  )
698
  clear_btn.click(
699
  clear,
700
  inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
701
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
 
702
  queue=False)
703
+
704
  sketch_pad_trigger.change(
705
  controller.init_white,
706
  inputs=[init_white_trigger],
 
711
  inputs=[state],
712
  outputs=[sketch_pad, state],
713
  queue=False)
714
+ # batch_size.change(
715
+ # controller.change_n_samples,
716
+ # inputs=[batch_size],
717
+ # outputs=[out_gen_1, out_gen_2],
718
+ # queue=False)
719
 
720
+ # batch_size.change(
721
+ # controller.change_n_samples,
722
+ # inputs=[batch_size],
723
+ # outputs=[out_gen_1, out_gen_2],
724
+ # queue=False)
725
 
726
  gen_btn.click(
727
  generate,
728
  inputs=[
729
+ language_instruction, grounding_instruction, sketch_pad,
730
+ loss_threshold, guidance_scale, batch_size, rand_seed,
 
731
  max_step,
732
  Loss_scale, max_iter,
733
  state,
734
  ],
735
+ outputs=[out_gen_1, state],
736
  queue=True
737
  )
738
  sketch_pad_resize_trigger.change(
 
752
  gr.Examples(
753
  examples=[
754
  [
755
+ # "images/input.png",
756
  "A hello kitty toy is playing with a purple ball.",
757
  "hello kitty;ball",
758
  "images/hello_kitty_results.png"
759
  ],
760
  ],
761
+ inputs=[language_instruction, grounding_instruction, out_gen_1],
762
  outputs=None,
763
  fn=None,
764
  cache_examples=False,
 
766
 
767
  main.queue(concurrency_count=1, api_open=False)
768
  main.launch(share=False, show_api=False, show_error=True)
769
+
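The Advanced Options tooltips above describe the backward-guidance loop inside inference(); a minimal sketch of a single guidance update follows, with an illustrative helper name and signature that are not part of this repository:

import torch

def backward_guidance_step(latents, layout_loss, sigma, loss_scale):
    # One iteration of the while-loop in inference(): scale the cross-attention
    # layout loss, differentiate it w.r.t. the latents, and step the latents
    # against that gradient, weighted by sigma**2 for the current timestep.
    # Assumes `latents` has requires_grad=True and `layout_loss` was computed from it.
    grad = torch.autograd.grad(layout_loss * loss_scale, [latents])[0]
    return latents - grad * sigma ** 2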
conf/net_conf.yaml ADDED
@@ -0,0 +1,35 @@
1
+ training:
2
+ use_ema: True
3
+ batch_size: 8
4
+ adam_beta1: 0.9
5
+ adam_beta2: 0.999
6
+ adam_weight_decay: 1e-2
7
+ adam_epsilon: 1e-08
8
+ lr_scheduler: constant
9
+ lr_warmup_steps: 0
10
+ max_train_steps: 50000
11
+ text_finetune_step: 50
12
+ unet_finetune_step: 50
13
+ alpha: 0.1
14
+ min_lr: 1e-6
15
+ warmup_epochs: 0
16
+ num_train_epochs: 300
17
+ accumulate_step: 1
18
+ lr: 1e-6
19
+ resume: ' '
20
+ down_attn_shift: -1
21
+ down_attn_map: -1
22
+ mid_attn_shift: -1
23
+ mid_attn_map: -1
24
+ up_attn_shift: -1
25
+ up_attn_map: -1
26
+
27
+ inference:
28
+ loss_scale: 30
29
+ batch_size: 1
30
+ loss_threshold: 0.2
31
+ max_iter: 5
32
+ index_step: 10
33
+ start_pair: 800
34
+ iteration_interval: 400
35
+ infer_iter: 0
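A short sketch of how the inference defaults above are read, assuming the file is saved at ./conf/net_conf.yaml (the path app.py loads with OmegaConf):

from omegaconf import OmegaConf

cfg = OmegaConf.load('./conf/net_conf.yaml')  # DictConfig with `training` and `inference` groups

# Backward-guidance defaults defined above, accessed by attribute.
print(cfg.inference.loss_scale)      # 30
print(cfg.inference.loss_threshold)  # 0.2
print(cfg.inference.max_iter)        # 5
print(cfg.inference.index_step)      # 10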
conf/unet/config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
images/.DS_Store ADDED
Binary file (6.15 kB).
 
layout_guidance/__init__.py ADDED
File without changes
layout_guidance/inference.py ADDED
@@ -0,0 +1,488 @@
1
+ # !pip install diffusers["torch"] transformers
2
+ import hydra
3
+ import torch
4
+ import yaml
5
+ from diffusers import StableDiffusionPipeline
6
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
7
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
8
+ import torch.nn.functional as F
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import torch.nn as nn
12
+ import time
13
+ from accelerate import Accelerator
14
+ import torchvision.transforms as transforms
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from omegaconf import DictConfig, OmegaConf
17
+ from datetime import datetime
18
+ import logging
19
+ import itertools
20
+ from torch.utils.data import DataLoader
21
+ from tqdm import tqdm
22
+ from diffusers import LMSDiscreteScheduler
23
+ from diffusers.optimization import get_scheduler
24
+ from torch import autocast
25
+ from torch.cuda.amp import GradScaler
26
+ import pdb
27
+ import math
28
+ from my_model import unet_2d_condition
29
+ from typing import Iterable, Optional
30
+ import os
31
+ import json
32
+ import numpy as np
33
+ import scipy
34
+
35
+ def freeze_params(params):
36
+ for param in params:
37
+ param.requires_grad = False
38
+ def unfreeze_params(params):
39
+ for param in params:
40
+ param.requires_grad = True
41
+
42
+
43
+ class EMAModel:
44
+ """
45
+ Exponential Moving Average of model weights
46
+ """
47
+
48
+ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
49
+ parameters = list(parameters)
50
+ print("list parameters")
51
+ self.shadow_params = [p.clone().detach() for p in parameters]
52
+ print("finish clone parameters")
53
+
54
+ self.decay = decay
55
+ self.optimization_step = 0
56
+
57
+ def get_decay(self, optimization_step):
58
+ """
59
+ Compute the decay factor for the exponential moving average.
60
+ """
61
+ value = (1 + optimization_step) / (10 + optimization_step)
62
+ return 1 - min(self.decay, value)
63
+
64
+ @torch.no_grad()
65
+ def step(self, parameters):
66
+ parameters = list(parameters)
67
+
68
+ self.optimization_step += 1
69
+ self.decay = self.get_decay(self.optimization_step)
70
+
71
+ for s_param, param in zip(self.shadow_params, parameters):
72
+ if param.requires_grad:
73
+ tmp = self.decay * (s_param - param)
74
+ s_param.sub_(tmp)
75
+ else:
76
+ s_param.copy_(param)
77
+
78
+ torch.cuda.empty_cache()
79
+
80
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
81
+ """
82
+ Copy current averaged parameters into given collection of parameters.
83
+ Args:
84
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
85
+ updated with the stored moving averages. If `None`, the
86
+ parameters with which this `ExponentialMovingAverage` was
87
+ initialized will be used.
88
+ """
89
+ parameters = list(parameters)
90
+ for s_param, param in zip(self.shadow_params, parameters):
91
+ param.data.copy_(s_param.data)
92
+
93
+ def to(self, device=None, dtype=None) -> None:
94
+ r"""c"""
95
+ # .to() on the tensors handles None correctly
96
+ self.shadow_params = [
97
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
98
+ for p in self.shadow_params
99
+ ]
100
+
101
+ def compute_visor_loss(attn_maps_mid, attn_maps_up, obj_a_positions, obj_b_positions, relationship):
102
+ loss = 0
103
+ for attn_map_integrated in attn_maps_mid:
104
+ attn_map = attn_map_integrated.chunk(2)[1]
105
+
106
+ #
107
+ b, i, j = attn_map.shape
108
+ H = W = int(math.sqrt(i))
109
+ weight_matrix_x = torch.zeros(size=(H, W)).cuda()
110
+ weight_matrix_y = torch.zeros(size=(H, W)).cuda()
111
+ for x_indx in range(W):
112
+ weight_matrix_x[:, x_indx] = x_indx
113
+ for y_indx in range(H):
114
+ weight_matrix_y[y_indx, :] = y_indx
115
+
116
+ # for obj_idx in range(object_number):
117
+ #
118
+ # bbox = bboxes[obj_idx]
119
+ obj_a_avg_x_total = 0
120
+ obj_a_avg_y_total = 0
121
+ for obj_a_position in obj_a_positions:
122
+ ca_map_obj = attn_map[:, :, obj_a_position].reshape(b, H, W)
123
+ # pdb.set_trace()
124
+
125
+ obj_a_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1)/ca_map_obj.reshape(b,-1).sum(-1)
126
+ obj_a_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1)/ca_map_obj.reshape(b,-1).sum(-1)
127
+ obj_a_avg_x_total += obj_a_avg_x
128
+ obj_a_avg_y_total += obj_a_avg_y
129
+ obj_a_avg_x_total = (obj_a_avg_x_total/len(obj_a_positions)).mean() / W
130
+ obj_a_avg_y_total = (obj_a_avg_y_total/len(obj_a_positions)).mean() / H
131
+ print('mid: obj_a_avg_x_total', obj_a_avg_x_total)
132
+
133
+ obj_b_avg_x_total = 0
134
+ obj_b_avg_y_total = 0
135
+ for obj_b_position in obj_b_positions:
136
+ ca_map_obj = attn_map[:, :, obj_b_position].reshape(b, H, W)
137
+ obj_b_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1)/ca_map_obj.reshape(b,-1).sum(-1)
138
+ obj_b_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1)/ca_map_obj.reshape(b,-1).sum(-1)
139
+ obj_b_avg_x_total += obj_b_avg_x
140
+ obj_b_avg_y_total += obj_b_avg_y
141
+
142
+ obj_b_avg_x_total = (obj_b_avg_x_total/len(obj_b_positions)).mean() / W
143
+ obj_b_avg_y_total = (obj_b_avg_y_total/len(obj_b_positions)).mean() / H
144
+ print('mid: obj_b_avg_x_total', obj_b_avg_x_total)
145
+
146
+ if relationship == 0:
147
+ loss += (obj_b_avg_x_total - obj_a_avg_x_total)
148
+ elif relationship == 1:
149
+ loss += (obj_a_avg_x_total - obj_b_avg_x_total)
150
+ elif relationship == 2:
151
+ loss += (obj_b_avg_y_total - obj_a_avg_y_total)
152
+ elif relationship == 3:
153
+ loss += (obj_a_avg_y_total - obj_b_avg_y_total)
154
+
155
+
156
+ for attn_map_integrated in attn_maps_up[0]:
157
+ attn_map = attn_map_integrated.chunk(2)[1]
158
+
159
+ b, i, j = attn_map.shape
160
+ H = W = int(math.sqrt(i))
161
+ weight_matrix_x = torch.zeros(size=(H, W)).cuda()
162
+ weight_matrix_y = torch.zeros(size=(H, W)).cuda()
163
+ for x_indx in range(W):
164
+ weight_matrix_x[:, x_indx] = x_indx
165
+ for y_indx in range(H):
166
+ weight_matrix_y[y_indx, :] = y_indx
167
+
168
+ # for obj_idx in range(object_number):
169
+ #
170
+ # bbox = bboxes[obj_idx]
171
+ obj_a_avg_x_total = 0
172
+ obj_a_avg_y_total = 0
173
+ for obj_a_position in obj_a_positions:
174
+ ca_map_obj = attn_map[:, :, obj_a_position].reshape(b, H, W)
175
+ obj_a_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
176
+ obj_a_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
177
+ obj_a_avg_x_total += obj_a_avg_x
178
+ obj_a_avg_y_total += obj_a_avg_y
179
+ obj_a_avg_x_total = (obj_a_avg_x_total / len(obj_a_positions)).mean() / W
180
+ obj_a_avg_y_total = (obj_a_avg_y_total / len(obj_a_positions)).mean() / H
181
+ print('up: obj_a_avg_x_total', obj_a_avg_x_total)
182
+
183
+ obj_b_avg_x_total = 0
184
+ obj_b_avg_y_total = 0
185
+
186
+ for obj_b_position in obj_b_positions:
187
+ ca_map_obj = attn_map[:, :, obj_b_position].reshape(b, H, W)
188
+ obj_b_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
189
+ obj_b_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
190
+ obj_b_avg_x_total += obj_b_avg_x
191
+ obj_b_avg_y_total += obj_b_avg_y
192
+
193
+ obj_b_avg_x_total = (obj_b_avg_x_total / len(obj_b_positions)).mean() / W
194
+ obj_b_avg_y_total = (obj_b_avg_y_total / len(obj_b_positions)).mean() / H
195
+ print('up: obj_b_avg_x_total', obj_b_avg_x_total)
196
+
197
+ if relationship == 0:
198
+ loss += (obj_a_avg_x_total - obj_b_avg_x_total)
199
+ elif relationship == 1:
200
+ loss += (obj_b_avg_x_total - obj_a_avg_x_total)
201
+ elif relationship == 2:
202
+ loss += (obj_a_avg_y_total - obj_b_avg_y_total)
203
+ elif relationship == 3:
204
+ loss += (obj_b_avg_y_total - obj_a_avg_y_total)
205
+
206
+
207
+ loss = loss / (len(attn_maps_up[0]) + len(attn_maps_mid))
208
+ return loss
209
+
210
+
211
+ @hydra.main(version_base=None, config_path="conf", config_name="config_visor_box")
212
+ def train(cfg: DictConfig):
213
+ # fix the randomness of torch
214
+
215
+ print(cfg)
216
+ with open('./conf/unet/origin_config.json') as f:
217
+ unet_config = json.load(f)
218
+ unet = unet_2d_condition.UNet2DConditionModel(**unet_config)
219
+ # ckp = torch.load('/Users/shil5883/Downloads/diffusion_pytorch_model.bin', map_location='cpu')
220
+ # prev_attn_map = torch.load('./attn_map.ckp', map_location='cpu')
221
+
222
+ ckp = torch.load('/work/minghao/chess_gen/diffusion_pytorch_model.bin', map_location='cpu')
223
+ prev_attn_map = torch.load('/work/minghao/chess_gen/visual_attn/2023-02-02/15-05-51/epoch_100_sche_constant_lr_1e-06_ac_1/attn_map.ckp', map_location='cpu')
224
+
225
+ # prev_attn_map = torch.load('/work/minghao/chess_gen/visual_attn/2023-01-16/18-58-12/epoch_100_sche_constant_lr_1e-06_ac_1/attn_map.ckp', map_location='cpu')
226
+ unet.load_state_dict(ckp)
227
+ unet_original = UNet2DConditionModel(**unet_config)
228
+ unet_original.load_state_dict(ckp)
229
+ date_now, time_now = datetime.now().strftime("%Y-%m-%d,%H-%M-%S").split(',')
230
+
231
+ # cfg.general.save_path = os.path.join(cfg.general.save_path, date_now, time_now)
232
+ # if not os.path.exists(cfg.general.save_path ):
233
+ # os.makedirs(cfg.general.save_path)
234
+ # cfg.general.save_path
235
+ mixed_precision = 'fp16' if torch.cuda.is_available() else 'no'
236
+ accelerator = Accelerator(
237
+ gradient_accumulation_steps=cfg.training.accumulate_step,
238
+ mixed_precision=mixed_precision,
239
+ log_with="tensorboard",
240
+ logging_dir='./',
241
+ )
242
+ # initialize dataset and dataloader
243
+ if accelerator.is_main_process:
244
+ print("Loading the dataset!!!!!")
245
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
246
+ # train_dataset = ICLEVERDataset(cfg.data.data_path, tokenizer, cfg, prefix='train')
247
+ # val_dataset = ICLEVERDataset(cfg.data.data_path, tokenizer, cfg, prefix='val')
248
+ # train_loader = DataLoader(train_dataset, batch_size=cfg.training.batch_size, shuffle=True, num_workers=2, pin_memory=False)
249
+ # val_loader = DataLoader(val_dataset, batch_size=cfg.training.batch_size * 2, shuffle=True, num_workers=2, pin_memory=False)
250
+
251
+ if accelerator.is_main_process:
252
+ print("Complete loading the dataset!!!!!")
253
+
254
+ if accelerator.is_main_process:
255
+ print("Complete load the noise scheduler!!!!!")
256
+ with open("config.yaml", "w") as f:
257
+ OmegaConf.save(cfg, f)
258
+ if not os.path.exists(cfg.general.save_path) and accelerator.is_main_process:
259
+ os.makedirs(cfg.general.save_path)
260
+ if accelerator.is_main_process:
261
+ print("saved load the noise scheduler!!!!!")
262
+
263
+ # Move unet to device
264
+ device = "cuda" if torch.cuda.is_available() else "cpu"
265
+ # load pretrained models and scheduler
266
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
267
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
268
+
269
+ # boards_embedder.to(device)
270
+ if accelerator.is_main_process:
271
+ print("move the model to device!!!!!")
272
+ # Keep vae and unet in eval model as we don't train these
273
+
274
+ # Initialize the optimizer
275
+ cfg.training.lr = (
276
+ cfg.training.lr * cfg.training.accumulate_step * cfg.training.batch_size * accelerator.num_processes
277
+ )
278
+ # Move vae and unet to device
279
+ vae.to(device)
280
+ unet.to(device)
281
+ text_encoder.to(device)
282
+ # prev_attn_map.to(device)
283
+ unet_original.to(device)
284
+ vae.eval()
285
+ unet.eval()
286
+ text_encoder.eval()
287
+ unet_original.eval()
288
+ # tokenizer.to(device)
289
+ # if accelerator.is_main_process:
290
+ print("prepare the accelerator module at process: {}!!!!!".format(accelerator.process_index))
291
+ # unet = accelerator.prepare(unet)
292
+
293
+ print("done the accelerator module at process: {}!!!!!".format(accelerator.process_index))
294
+
295
+ # Create EMA for the unet.
296
+ # if cfg.training.use_ema:
297
+ # ema_unet = EMAModel(unet.parameters())
298
+ # ema_encoder = EMAModel(boards_embedder.parameters())
299
+ ema_unet = None
300
+ # print(start_ema)
301
+ if cfg.training.use_ema:
302
+ if accelerator.is_main_process:
303
+ print("Using the EMA model!!!!!")
304
+ print("start EMA at process: {}!!!!!".format(accelerator.process_index))
305
+
306
+ ema_unet = EMAModel(unet.parameters())
307
+ # ema_encoder = EMAModel(boards_embedder.parameters())
308
+
309
+ # prompt = 'A traffic light below a sink'
310
+ templates = ['{} to the left of {}', '{} to the right of {}', '{} above {}', '{} below {}']
311
+ bboxes_template = [[0.0, 0.0, 0.5, 1.0], [0.0, 0.0, 1.0, 0.5], [0.5, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0]]
312
+ bboxes_template_list = [[0, 2], [2, 0], [1, 3], [3, 1]]
313
+ iteration_start = cfg.inference.start_pair
314
+ iteration_now = iteration_start
315
+ iteration_interval = cfg.inference.iteration_interval
316
+ with open('./coco_paris.txt', 'r') as f:
317
+ image_pairs = f.readlines()
318
+ for image_pair in tqdm(image_pairs[iteration_start: iteration_start + iteration_interval]):
319
+ obj_a, obj_b = image_pair.strip().split(',')[0], image_pair.strip().split(',')[1]
320
+ obj_a = 'A {}'.format(obj_a) if obj_a[0] not in ['a', 'e', 'i', 'o', 'u'] else 'An {}'.format(obj_a)
321
+ obj_b = 'a {}'.format(obj_b) if obj_b[0] not in ['a', 'e', 'i', 'o', 'u'] else 'an {}'.format(obj_b)
322
+ for idx, template in enumerate(templates):
323
+ prompt = template.format(obj_a, obj_b)
324
+ obj_a_len = len(obj_a.split(' ')) - 1
325
+ obj_a_position = [2] if obj_a_len == 1 else [2, 3]
326
+ obj_b_position = [obj_a_len + 1 + len(template.split(' ')) + i for i in range(len(obj_b.split(' '))-1)]
327
+ obj_positions = [obj_a_position, obj_b_position]
328
+ obj_a_boxes = [bboxes_template[bboxes_template_list[idx][0]].copy() for _ in range(len(obj_a.split(' ')) - 1)]
329
+ obj_b_boxes = [bboxes_template[bboxes_template_list[idx][1]].copy() for _ in range(len(obj_b.split(' ')) - 1)]
330
+ obj_boxes = [obj_a_boxes, obj_b_boxes]
331
+ print(prompt, obj_positions, obj_boxes)
332
+ # for infer_iter in range(1):
333
+ inference(device, unet, unet_original, vae, tokenizer, text_encoder, prompt, cfg, prev_attn_map, bboxes=obj_boxes, object_positions=obj_positions, infer_iter=cfg.inference.infer_iter, pair_id=iteration_now)
334
+
335
+
336
+ obj_b, obj_a = image_pair.strip().split(',')[0], image_pair.strip().split(',')[1]
337
+ obj_a = 'A {}'.format(obj_a) if obj_a[0] not in ['a', 'e', 'i', 'o', 'u'] else 'An {}'.format(obj_a)
338
+ obj_b = 'a {}'.format(obj_b) if obj_b[0] not in ['a', 'e', 'i', 'o', 'u'] else 'an {}'.format(obj_b)
339
+ for idx, template in enumerate(templates):
340
+ prompt = template.format(obj_a, obj_b)
341
+ obj_a_len = len(obj_a.split(' ')) - 1
342
+ obj_a_position = [2] if obj_a_len == 1 else [2, 3]
343
+ obj_b_position = [obj_a_len + 1 + len(template.split(' ')) + i for i in range(len(obj_b.split(' '))-1)]
344
+ obj_positions = [obj_a_position, obj_b_position]
345
+ obj_a_boxes = [bboxes_template[bboxes_template_list[idx][0]].copy() for _ in range(len(obj_a.split(' ')) - 1)]
346
+ obj_b_boxes = [bboxes_template[bboxes_template_list[idx][1]].copy() for _ in range(len(obj_b.split(' ')) - 1)]
347
+ obj_boxes = [obj_a_boxes, obj_b_boxes]
348
+ print(prompt, obj_positions, obj_boxes)
349
+ inference(device, unet, unet_original, vae, tokenizer, text_encoder, prompt, cfg, prev_attn_map, bboxes=obj_boxes, object_positions=obj_positions, infer_iter=cfg.inference.infer_iter, pair_id=iteration_now)
350
+ iteration_now += 1
351
+ def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
352
+ loss = 0
353
+ object_number = len(bboxes)
354
+ if object_number == 0:
355
+ return torch.tensor(0).float().cuda()
356
+ for attn_map_integrated in attn_maps_mid:
357
+ attn_map = attn_map_integrated.chunk(2)[1]
358
+
359
+ #
360
+ b, i, j = attn_map.shape
361
+ H = W = int(math.sqrt(i))
362
+ # pdb.set_trace()
363
+ for obj_idx in range(object_number):
364
+ obj_loss = 0
365
+ mask = torch.zeros(size=(H, W)).cuda()
366
+ for obj_box in bboxes[obj_idx]:
367
+
368
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
369
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
370
+ mask[y_min: y_max, x_min: x_max] = 1
371
+
372
+ for obj_position in object_positions[obj_idx]:
373
+ ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
374
+ # ca_map_obj = attn_map[:, :, object_positions[obj_position]].reshape(b, H, W)
375
+
376
+ activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1)
377
+
378
+ obj_loss += torch.mean((1 - activation_value) ** 2)
379
+ loss += (obj_loss/len(object_positions[obj_idx]))
380
+ # print("??", obj_idx, obj_loss/len(object_positions[obj_idx]))
381
+
382
+ # compute loss on padding tokens
383
+ # activation_value = torch.zeros(size=(b, )).cuda()
384
+ # for obj_idx in range(object_number):
385
+ # bbox = bboxes[obj_idx]
386
+ # ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
387
+ # activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
388
+ # int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
389
+ #
390
+ # loss += torch.mean((1 - activation_value) ** 2)
391
+
392
+
393
+ for attn_map_integrated in attn_maps_up[0]:
394
+ attn_map = attn_map_integrated.chunk(2)[1]
395
+ #
396
+ b, i, j = attn_map.shape
397
+ H = W = int(math.sqrt(i))
398
+
399
+ for obj_idx in range(object_number):
400
+ obj_loss = 0
401
+ mask = torch.zeros(size=(H, W)).cuda()
402
+ for obj_box in bboxes[obj_idx]:
403
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
404
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
405
+ mask[y_min: y_max, x_min: x_max] = 1
406
+
407
+ for obj_position in object_positions[obj_idx]:
408
+ ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
409
+ # ca_map_obj = attn_map[:, :, object_positions[obj_position]].reshape(b, H, W)
410
+
411
+ activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(
412
+ dim=-1)
413
+
414
+ obj_loss += torch.mean((1 - activation_value) ** 2)
415
+ loss += (obj_loss / len(object_positions[obj_idx]))
416
+
417
+ # compute loss on padding tokens
418
+ # activation_value = torch.zeros(size=(b, )).cuda()
419
+ # for obj_idx in range(object_number):
420
+ # bbox = bboxes[obj_idx]
421
+ # ca_map_obj = attn_map[:, :,padding_start:].reshape(b, H, W, -1)
422
+ # activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
423
+ # int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
424
+ #
425
+ # loss += torch.mean((1 - activation_value) ** 2)
426
+ loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
427
+ return loss
428
+ def plt_all_attn_map_in_one(attn_map_integrated_list_down, attn_map_integrated_list_mid, attn_map_integrated_list_up, image, prompt, cfg, t, prefix='all'):
429
+
430
+ prompt_split = prompt.split(' ')
431
+ prompt_len = len(prompt_split) + 4
432
+
433
+ total_layers = len(attn_map_integrated_list_down) + len(attn_map_integrated_list_mid) + len(attn_map_integrated_list_up)
434
+ fig, axs = plt.subplots(nrows=total_layers+1, ncols=prompt_len, figsize=(4 * prompt_len, 4 * total_layers))
435
+ fig.suptitle(prompt, fontsize=32)
436
+ fig.tight_layout()
437
+ cnt = 1
438
+ ax = axs[0][0]
439
+ ax.imshow(image)
440
+ for prompt_idx in range(prompt_len):
441
+ ax = axs[0][prompt_idx]
442
+ ax.set_axis_off()
443
+ for layer, attn_map_integrated in enumerate(attn_map_integrated_list_down):
444
+ attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
445
+ grid_size = int(math.sqrt(attn_map.shape[1]))
446
+ for prompt_idx in range(prompt_len):
447
+ ax = axs[cnt][prompt_idx]
448
+ if prompt_idx == 0:
449
+ ax.set_ylabel('down {}'.format(layer), rotation=0, size='large')
450
+ mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
451
+ im = ax.imshow(mask, cmap='YlGn')
452
+ ax.set_axis_off()
453
+ cnt += 1
454
+
455
+ for layer, attn_map_integrated in enumerate(attn_map_integrated_list_mid):
456
+ attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
457
+ grid_size = int(math.sqrt(attn_map.shape[1]))
458
+ for prompt_idx in range(prompt_len):
459
+ ax = axs[cnt][prompt_idx]
460
+ if prompt_idx ==0:
461
+ ax.set_ylabel('mid {}'.format(layer), rotation=0, size='large')
462
+ mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
463
+ im = ax.imshow(mask, cmap='YlGn')
464
+ ax.set_axis_off()
465
+ cnt += 1
466
+
467
+ for layer, attn_map_integrated in enumerate(attn_map_integrated_list_up):
468
+ attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
469
+ grid_size = int(math.sqrt(attn_map.shape[1]))
470
+ for prompt_idx in range(prompt_len):
471
+ ax = axs[cnt][prompt_idx]
472
+ if prompt_idx ==0:
473
+ ax.set_ylabel('up {}'.format(layer), rotation=0, size='large')
474
+ mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
475
+ im = ax.imshow(mask, cmap='YlGn')
476
+ ax.set_axis_off()
477
+ cnt += 1
478
+
479
+ if not os.path.exists(cfg.general.save_path + "/{}".format(prefix)):
480
+ os.makedirs(cfg.general.save_path + "/{}".format(prefix))
481
+ plt.savefig(cfg.general.save_path + "/{}/step_{}.png".format(prefix, str(int(t)).zfill(4)))
482
+ # generate_video()
483
+ plt.close()
484
+
485
+ if __name__=="__main__":
486
+ train()
487
+
488
+
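compute_ca_loss above expects attention maps whose first dimension stacks the unconditional and conditional halves (it keeps the second half via chunk(2)[1]), per-object lists of normalized [x_min, y_min, x_max, y_max] boxes, and per-object token indices; its masks are allocated with .cuda(), so a GPU is required. A quick smoke test with random tensors, assuming the repository's dependencies are installed and the script is run from the repo root:

import torch
from layout_guidance.inference import compute_ca_loss

# Fake 16x16 cross-attention maps over 77 tokens, stacked as an uncond/cond pair.
attn = torch.rand(2, 16 * 16, 77, device='cuda').softmax(dim=-1)
bboxes = [[[0.1, 0.1, 0.5, 0.5]]]   # one object, one normalized box
object_positions = [[2]]            # token index of that object in the prompt

loss = compute_ca_loss(attn_maps_mid=[attn], attn_maps_up=[[attn]],
                       bboxes=bboxes, object_positions=object_positions)
print(loss.item())  # lower values mean more of the object's attention falls inside its box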
my_model/__init__.py ADDED
File without changes
my_model/attention.py ADDED
@@ -0,0 +1,929 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
25
+ from diffusers.utils import BaseOutput
26
+ from diffusers.utils.import_utils import is_xformers_available
27
+ import scipy
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
35
+ for the unnoised latent pixels.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ if is_xformers_available():
42
+ import xformers
43
+ import xformers.ops
44
+ else:
45
+ xformers = None
46
+
47
+
48
+ class Transformer2DModel(ModelMixin, ConfigMixin):
49
+ """
50
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
51
+ embeddings) inputs.
52
+
53
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
54
+ transformer action. Finally, reshape to image.
55
+
56
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
57
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
58
+ classes of unnoised image.
59
+
60
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
61
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
62
+
63
+ Parameters:
64
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
65
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
66
+ in_channels (`int`, *optional*):
67
+ Pass if the input is continuous. The number of channels in the input and output.
68
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
69
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
70
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
71
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
72
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
73
+ `ImagePositionalEmbeddings`.
74
+ num_vector_embeds (`int`, *optional*):
75
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
76
+ Includes the class for the masked latent pixel.
77
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
78
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
79
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
80
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
81
+ up to, but not more than, `num_embeds_ada_norm` steps.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ num_layers: int = 1,
93
+ dropout: float = 0.0,
94
+ norm_num_groups: int = 32,
95
+ cross_attention_dim: Optional[int] = None,
96
+ attention_bias: bool = False,
97
+ sample_size: Optional[int] = None,
98
+ num_vector_embeds: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ ):
102
+ super().__init__()
103
+ self.num_attention_heads = num_attention_heads
104
+ self.attention_head_dim = attention_head_dim
105
+ inner_dim = num_attention_heads * attention_head_dim
106
+
107
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
108
+ # Define whether input is continuous or discrete depending on configuration
109
+ self.is_input_continuous = in_channels is not None
110
+ self.is_input_vectorized = num_vector_embeds is not None
111
+
112
+ if self.is_input_continuous and self.is_input_vectorized:
113
+ raise ValueError(
114
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
115
+ " sure that either `in_channels` or `num_vector_embeds` is None."
116
+ )
117
+ elif not self.is_input_continuous and not self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
121
+ )
122
+
123
+ # 2. Define input layers
124
+ if self.is_input_continuous:
125
+ self.in_channels = in_channels
126
+
127
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
128
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
129
+ elif self.is_input_vectorized:
130
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
131
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
132
+
133
+ self.height = sample_size
134
+ self.width = sample_size
135
+ self.num_vector_embeds = num_vector_embeds
136
+ self.num_latent_pixels = self.height * self.width
137
+
138
+ self.latent_image_embedding = ImagePositionalEmbeddings(
139
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
140
+ )
141
+
142
+ # 3. Define transformers blocks
143
+ self.transformer_blocks = nn.ModuleList(
144
+ [
145
+ BasicTransformerBlock(
146
+ inner_dim,
147
+ num_attention_heads,
148
+ attention_head_dim,
149
+ dropout=dropout,
150
+ cross_attention_dim=cross_attention_dim,
151
+ activation_fn=activation_fn,
152
+ num_embeds_ada_norm=num_embeds_ada_norm,
153
+ attention_bias=attention_bias,
154
+ )
155
+ for d in range(num_layers)
156
+ ]
157
+ )
158
+
159
+ # 4. Define output layers
160
+ if self.is_input_continuous:
161
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
162
+ elif self.is_input_vectorized:
163
+ self.norm_out = nn.LayerNorm(inner_dim)
164
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
165
+
166
+ def _set_attention_slice(self, slice_size):
167
+ for block in self.transformer_blocks:
168
+ block._set_attention_slice(slice_size)
169
+
170
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attn_map=None, attn_shift=False, obj_ids=None, relationship=None, return_dict: bool = True):
171
+ """
172
+ Args:
173
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
174
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
175
+ hidden_states
176
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
177
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
178
+ self-attention.
179
+ timestep ( `torch.long`, *optional*):
180
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
181
+ return_dict (`bool`, *optional*, defaults to `True`):
182
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
183
+
184
+ Returns:
185
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
186
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
187
+ tensor.
188
+ """
189
+ # 1. Input
190
+ if self.is_input_continuous:
191
+ batch, channel, height, weight = hidden_states.shape
192
+ residual = hidden_states
193
+ hidden_states = self.norm(hidden_states)
194
+ hidden_states = self.proj_in(hidden_states)
195
+ inner_dim = hidden_states.shape[1]
196
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
197
+ elif self.is_input_vectorized:
198
+ hidden_states = self.latent_image_embedding(hidden_states)
199
+
200
+ # 2. Blocks
201
+ for block in self.transformer_blocks:
202
+ hidden_states, cross_attn_prob, save_key = block(hidden_states, context=encoder_hidden_states, timestep=timestep, attn_map=attn_map, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
203
+
204
+ # 3. Output
205
+ if self.is_input_continuous:
206
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
207
+ hidden_states = self.proj_out(hidden_states)
208
+ output = hidden_states + residual
209
+ elif self.is_input_vectorized:
210
+ hidden_states = self.norm_out(hidden_states)
211
+ logits = self.out(hidden_states)
212
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
213
+ logits = logits.permute(0, 2, 1)
214
+
215
+ # log(p(x_0))
216
+ output = F.log_softmax(logits.double(), dim=1).float()
217
+
218
+ if not return_dict:
219
+ return (output,)
220
+
221
+ return Transformer2DModelOutput(sample=output), cross_attn_prob, save_key
222
+
223
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
224
+ for block in self.transformer_blocks:
225
+ block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
226
+
227
+
228
+ class AttentionBlock(nn.Module):
229
+ """
230
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
231
+ to the N-d case.
232
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
233
+ Uses three q, k, v linear layers to compute attention.
234
+
235
+ Parameters:
236
+ channels (`int`): The number of channels in the input and output.
237
+ num_head_channels (`int`, *optional*):
238
+ The number of channels in each head. If None, then `num_heads` = 1.
239
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
240
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
241
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ channels: int,
247
+ num_head_channels: Optional[int] = None,
248
+ norm_num_groups: int = 32,
249
+ rescale_output_factor: float = 1.0,
250
+ eps: float = 1e-5,
251
+ ):
252
+ super().__init__()
253
+ self.channels = channels
254
+
255
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
256
+ self.num_head_size = num_head_channels
257
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
258
+
259
+ # define q,k,v as linear layers
260
+ self.query = nn.Linear(channels, channels)
261
+ self.key = nn.Linear(channels, channels)
262
+ self.value = nn.Linear(channels, channels)
263
+
264
+ self.rescale_output_factor = rescale_output_factor
265
+ self.proj_attn = nn.Linear(channels, channels, 1)
266
+
267
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
268
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
269
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
270
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
271
+ return new_projection
272
+
273
+ def forward(self, hidden_states):
274
+ residual = hidden_states
275
+ batch, channel, height, width = hidden_states.shape
276
+
277
+ # norm
278
+ hidden_states = self.group_norm(hidden_states)
279
+
280
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
281
+
282
+ # proj to q, k, v
283
+ query_proj = self.query(hidden_states)
284
+ key_proj = self.key(hidden_states)
285
+ value_proj = self.value(hidden_states)
286
+
287
+ # transpose
288
+ query_states = self.transpose_for_scores(query_proj)
289
+ key_states = self.transpose_for_scores(key_proj)
290
+ value_states = self.transpose_for_scores(value_proj)
291
+
292
+ # get scores
293
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
294
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
295
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
296
+
297
+ # compute attention output
298
+ hidden_states = torch.matmul(attention_probs, value_states)
299
+
300
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
301
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
302
+ hidden_states = hidden_states.view(new_hidden_states_shape)
303
+
304
+ # compute next hidden_states
305
+ hidden_states = self.proj_attn(hidden_states)
306
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
307
+
308
+ # res connect and rescale
309
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
310
+ return hidden_states
311
+
312
+
313
+ class BasicTransformerBlock(nn.Module):
314
+ r"""
315
+ A basic Transformer block.
316
+
317
+ Parameters:
318
+ dim (`int`): The number of channels in the input and output.
319
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
320
+ attention_head_dim (`int`): The number of channels in each head.
321
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
322
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
323
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
324
+ num_embeds_ada_norm (:
325
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
326
+ attention_bias (:
327
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
328
+ """
329
+
330
+ def __init__(
331
+ self,
332
+ dim: int,
333
+ num_attention_heads: int,
334
+ attention_head_dim: int,
335
+ dropout=0.0,
336
+ cross_attention_dim: Optional[int] = None,
337
+ activation_fn: str = "geglu",
338
+ num_embeds_ada_norm: Optional[int] = None,
339
+ attention_bias: bool = False,
340
+ ):
341
+ super().__init__()
342
+ self.attn1 = CrossAttention(
343
+ query_dim=dim,
344
+ heads=num_attention_heads,
345
+ dim_head=attention_head_dim,
346
+ dropout=dropout,
347
+ bias=attention_bias,
348
+ ) # is a self-attention
349
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
350
+ self.attn2 = CrossAttention(
351
+ query_dim=dim,
352
+ cross_attention_dim=cross_attention_dim,
353
+ heads=num_attention_heads,
354
+ dim_head=attention_head_dim,
355
+ dropout=dropout,
356
+ bias=attention_bias,
357
+ ) # is self-attn if context is none
358
+
359
+ # layer norms
360
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
361
+ if self.use_ada_layer_norm:
362
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
363
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
364
+ else:
365
+ self.norm1 = nn.LayerNorm(dim)
366
+ self.norm2 = nn.LayerNorm(dim)
367
+ self.norm3 = nn.LayerNorm(dim)
368
+
369
+ def _set_attention_slice(self, slice_size):
370
+ self.attn1._slice_size = slice_size
371
+ self.attn2._slice_size = slice_size
372
+
373
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
374
+ if not is_xformers_available():
375
+ print("Here is how to install it")
376
+ raise ModuleNotFoundError(
377
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
378
+ " xformers",
379
+ name="xformers",
380
+ )
381
+ elif not torch.cuda.is_available():
382
+ raise ValueError(
383
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
384
+ " available for GPU "
385
+ )
386
+ else:
387
+ try:
388
+ # Make sure we can run the memory efficient attention
389
+ _ = xformers.ops.memory_efficient_attention(
390
+ torch.randn((1, 2, 40), device="cuda"),
391
+ torch.randn((1, 2, 40), device="cuda"),
392
+ torch.randn((1, 2, 40), device="cuda"),
393
+ )
394
+ except Exception as e:
395
+ raise e
396
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
397
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
398
+
399
+ def forward(self, hidden_states, context=None, timestep=None, attn_map=None, attn_shift=False, obj_ids=None, relationship=None):
400
+ # 1. Self-Attention
401
+ norm_hidden_states = (
402
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
403
+ )
404
+ tmp_hidden_states, cross_attn_prob, save_key = self.attn1(norm_hidden_states)
405
+ hidden_states = tmp_hidden_states + hidden_states
406
+
407
+ # 2. Cross-Attention
408
+ norm_hidden_states = (
409
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
410
+ )
411
+ tmp_hidden_states, cross_attn_prob, save_key = self.attn2(norm_hidden_states, context=context, attn_map=attn_map, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
412
+ hidden_states = tmp_hidden_states + hidden_states
413
+
414
+ # 3. Feed-forward
415
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
416
+
417
+ return hidden_states, cross_attn_prob, save_key
418
+
419
+
420
+ class CrossAttention(nn.Module):
421
+ r"""
422
+ A cross attention layer.
423
+
424
+ Parameters:
425
+ query_dim (`int`): The number of channels in the query.
426
+ cross_attention_dim (`int`, *optional*):
427
+ The number of channels in the context. If not given, defaults to `query_dim`.
428
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
429
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
430
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
431
+ bias (`bool`, *optional*, defaults to False):
432
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ query_dim: int,
438
+ cross_attention_dim: Optional[int] = None,
439
+ heads: int = 8,
440
+ dim_head: int = 64,
441
+ dropout: float = 0.0,
442
+ bias=False,
443
+ ):
444
+ super().__init__()
445
+ inner_dim = dim_head * heads
446
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
447
+
448
+ self.scale = dim_head**-0.5
449
+ self.heads = heads
450
+ # for slice_size > 0 the attention score computation
451
+ # is split across the batch axis to save memory
452
+ # You can set slice_size with `set_attention_slice`
453
+ self._slice_size = None
454
+ self._use_memory_efficient_attention_xformers = False
455
+
456
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
457
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
458
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
459
+
460
+ self.to_out = nn.ModuleList([])
461
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
462
+ self.to_out.append(nn.Dropout(dropout))
463
+
464
+ def reshape_heads_to_batch_dim(self, tensor):
465
+ batch_size, seq_len, dim = tensor.shape
466
+ head_size = self.heads
467
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
468
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
469
+ return tensor
470
+
471
+ def reshape_batch_dim_to_heads(self, tensor):
472
+ batch_size, seq_len, dim = tensor.shape
473
+ head_size = self.heads
474
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
475
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
476
+ return tensor
477
+
478
+ def forward(self, hidden_states, context=None, attn_map=None, use_prev_key=False, prev_key=None, mask=None, attn_shift=False, obj_ids=None, relationship=None):
479
+ batch_size, sequence_length, _ = hidden_states.shape
480
+
481
+ query = self.to_q(hidden_states)
482
+ context = context if context is not None else hidden_states
483
+ key = self.to_k(context)
484
+ value = self.to_v(context)
485
+
486
+ dim = query.shape[-1]
487
+
488
+ query = self.reshape_heads_to_batch_dim(query)
489
+ key = self.reshape_heads_to_batch_dim(key)
490
+ if use_prev_key:
491
+ key = prev_key
492
+ value = self.reshape_heads_to_batch_dim(value)
493
+
494
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
495
+
496
+ # attention, what we cannot get enough of
497
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value)
+ # xformers does not expose the attention probabilities; return None for them
+ attention_probs = None
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states, attention_probs = self._attention(query, key, value, attn_map=attn_map, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
+ else:
+ # the sliced path likewise does not return per-token probabilities
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+ attention_probs = None
504
+
505
+ # linear proj
506
+ hidden_states = self.to_out[0](hidden_states)
507
+ # dropout
508
+ hidden_states = self.to_out[1](hidden_states)
509
+ return hidden_states, attention_probs, key
510
+
511
+ def _attention(self, query, key, value, attn_map=None, attn_shift=False, attn_mask=None, obj_ids=None, relationship=None):
512
+ # TODO: use baddbmm for better performance
513
+ if query.device.type == "mps":
514
+ # Better performance on mps (~20-25%)
515
+ attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
516
+ else:
517
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
518
+ attention_probs = attention_scores.softmax(dim=-1)
519
+ # compute attention output
520
+
521
+ if query.device.type == "mps":
522
+ hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
523
+ else:
524
+ per_image_size = attention_probs.shape[0] // 2
525
+
526
+ if attn_map is not None:
527
+
528
+ print(attn_map.shape, attention_probs.shape)
529
+ # # hidden_states = torch.matmul(attention_probs, value)
530
+ # # print(attention_probs.shape, attn_map.shape)
531
+ # #
532
+ # b, i, j = attention_probs.shape
533
+ # H = W = int(math.sqrt(i))
534
+ # # # random_start = torch.randn(size=(b, j, i))
535
+ # # # random_start = (random_start/random_start.sum(-1).unsqueeze(-1)).permute(0, 2, 1).cuda()
536
+ # # # attention_probs[per_image_size:, :, 7:] = random_start[per_image_size:, :, 0].unsqueeze(-1)
537
+ # n = np.zeros((H, W))
538
+ # n[H//2, 1*W//4] = 1
539
+ # # n[3*H//4, 1*W//4] = 1
540
+ # # n[3*H//4, 3*W//4] = 1
541
+ #
542
+ # attention_weight_cat = torch.from_numpy(scipy.ndimage.gaussian_filter(n, sigma=H/12)).cuda().reshape(-1)
543
+ # # print(attention_probs[per_image_size:, :, :].shape, attention_weight_cat.shape )
544
+ # # attention_probs[per_image_size:, :, 4:] = attention_probs[per_image_size:, :, 4:] * attention_weight_cat.unsqueeze(-1)
545
+ # # attention_probs[per_image_size:, :, :] = attention_probs[per_image_size:, :, :]/attention_probs[per_image_size:, :, :].sum(dim=1).unsqueeze(1)
546
+ # #
547
+ # n = np.zeros((H, W))
548
+ # n[H//2, 3*W//4] = 1
549
+ # attention_weight_dog = torch.from_numpy(scipy.ndimage.gaussian_filter(n, sigma=H/12)).cuda().reshape(-1)
550
+ # #
551
+ # attention_weight_all = attention_weight_dog + attention_weight_cat
552
+ # # attention_weight_all = torch.ones(size=attention_weight_motorbike.shape)
553
+ # attention_weight_all_normalized = attention_weight_all/attention_weight_all.sum()
554
+ # # attention_probs[per_image_size:, :, 1:] = attention_weight_all_normalized.unsqueeze(-1)
555
+ # #
556
+ # attention_weight_bg = attention_weight_dog + attention_weight_cat
557
+ # # attention_weight_bg = torch.ones(size=attention_weight_motorbike.shape)
558
+ # attention_weight_all_normalized_bg = attention_weight_bg/attention_weight_bg.sum()
559
+ # attention_weight_all_normalized_bg_reverse = attention_weight_all_normalized_bg.max() - attention_weight_all_normalized_bg
560
+ # # attention_weight_all_normalized_bg_reverse = torch.ones(size=attention_weight_motorbike.shape)
561
+ # attention_weight_all_normalized_bg_reverse = attention_weight_all_normalized_bg_reverse/attention_weight_all_normalized_bg_reverse.sum()
562
+ # # attention_probs[per_image_size:, :, 0] = attention_weight_all_normalized_bg_reverse
563
+ # #
564
+ # # per_image_size = attention_probs.shape[0] // 2
565
+ #
566
+ # # attention_probs[per_image_size:] = attn_map if attn_map.shape[0] == per_image_size else attn_map[per_image_size:]
567
+ # # attention_probs_new = attention_probs.clone()
568
+ # # attention_probs[per_image_size:, :, 1] = attention_probs_new[per_image_size:, :, 3]
569
+ # # attention_probs[per_image_size:, :, 3] = attention_probs_new[per_image_size:, :, 1]
570
+
571
+ if attn_shift:
572
+ # print("???")
573
+
574
+
575
+ b, i, j = attention_probs.shape
576
+ H = W = int(math.sqrt(i))
577
+ strength = relationship['strength']
578
+ spatial_relationship = relationship['spatial_relationship']
579
+
580
+ # print(obj_ids, relationship)
581
+
582
+
583
+ ##### padding token one
584
+ if relationship['padding_token']:
585
+ # print("forward with padding_token")
586
+ n = np.zeros((H, W))
587
+ padding_token_start = relationship['padding_start']
588
+ # print(relationship)
589
+ if spatial_relationship == 0:
590
+ n[H // 2, 1 * W // 4] = 1/2
591
+ n[H // 2, 3 * W // 4] = 1/2
592
+ elif spatial_relationship == 1:
593
+ n[H // 2, 1 * W // 4] = 1/2
594
+ n[H // 2, 3 * W // 4] = 1/2
595
+ elif spatial_relationship == 2:
596
+ n[1 * H // 4, W // 2] = 1/2
597
+ n[3 * H // 4, W // 2] = 1/2
598
+ elif spatial_relationship == 3:
599
+ n[1 * H // 4, W // 2] = 1/2
600
+ n[3 * H // 4, W // 2] = 1/2
601
+ attention_weight_obj_a = torch.from_numpy(
602
+ scipy.ndimage.gaussian_filter(n, sigma=H / 8)).cuda().reshape(-1)
603
+ # print((attention_weight_obj_a / attention_weight_obj_a.sum()).shape)
604
+ attention_weight_obj_a_normalized = torch.tile(
605
+ (attention_weight_obj_a / attention_weight_obj_a.sum()).unsqueeze(0).unsqueeze(-1), (b // 2, 1, j-padding_token_start))
606
+ # print(attention_weight_obj_a_normalized.shape)
607
+
608
+ word_sum = torch.tile(attention_probs[per_image_size:, :, padding_token_start:].sum(dim=-2).unsqueeze(-2), (1, i, 1))
609
+ attention_probs[per_image_size:, :, padding_token_start:] = (1-strength)*attention_probs[per_image_size:, :, padding_token_start:] + strength * attention_weight_obj_a_normalized * word_sum
610
+
611
+ ### start token
612
+ n = np.zeros((H, W))
613
+ # print("use start token", relationship)
614
+ if spatial_relationship == 0:
615
+ n[H // 2, 1 * W // 4] = 1/2
616
+ n[H // 2, 3 * W // 4] = 1/2
617
+ elif spatial_relationship == 1:
618
+ n[H // 2, 1 * W // 4] = 1/2
619
+ n[H // 2, 3 * W // 4] = 1/2
620
+ elif spatial_relationship == 2:
621
+ n[1 * H // 4, W // 2] = 1/2
622
+ n[3 * H // 4, W // 2] = 1/2
623
+ elif spatial_relationship == 3:
624
+ n[1 * H // 4, W // 2] = 1/2
625
+ n[3 * H // 4, W // 2] = 1/2
626
+
627
+
628
+ attention_weight_obj_a = torch.from_numpy(
629
+ scipy.ndimage.gaussian_filter(n, sigma=H / 8)).cuda().reshape(-1)
630
+ attention_weight_obj_a = 1 - attention_weight_obj_a
631
+ # print((attention_weight_obj_a / attention_weight_obj_a.sum()).shape)
632
+ attention_weight_obj_a_normalized = torch.tile(
633
+ (attention_weight_obj_a / attention_weight_obj_a.sum()).unsqueeze(0),
634
+ (b // 2, 1))
635
+ # print(attention_weight_obj_a_normalized.shape)
636
+
637
+ word_sum = attention_probs[per_image_size:, :, 0].sum(dim=-1)
638
+ # print("before the adding", attention_probs[per_image_size:, :, 0].sum(dim=-1)[0])
639
+ # print("adding noise" , (attention_weight_obj_a_normalized * word_sum.unsqueeze(-1)).sum(dim=-1)[0])
640
+ # print("before the adding" ,attention_probs[per_image_size:, :, 0].sum(dim=-1)[0], )
641
+
642
+ attention_probs[per_image_size:, :, 0] = (1 - strength) * attention_probs[per_image_size:, :, 0] + strength * attention_weight_obj_a_normalized * word_sum.unsqueeze(-1)
643
+ # print("after the adding", attention_probs[per_image_size:, :, 0].sum(dim=-1)[0])
644
+ ### end
645
+
646
+
647
+ ### one token
648
+ #
649
+ # n = np.zeros((H, W))
650
+ # n[3 * H // 4, 1 * W // 4] = 1
651
+ # obj_a_ids = 5
652
+ # # obj_b_ids = obj_ids[1]
653
+ # attention_weight_obj_a = torch.from_numpy(
654
+ # scipy.ndimage.gaussian_filter(n, sigma=H / 8)).cuda().reshape(-1)
655
+ # # print((attention_weight_obj_a / attention_weight_obj_a.sum()).shape)
656
+ # attention_weight_obj_a_normalized = torch.tile(
657
+ # (attention_weight_obj_a / attention_weight_obj_a.sum()).unsqueeze(0),
658
+ # (b // 2, 1))
659
+ # # print(attention_weight_obj_a_normalized.shape)
660
+ # word_sum = attention_probs[per_image_size:, :, obj_a_ids].sum(dim=-1)
661
+ # print(word_sum.shape, attention_weight_obj_a_normalized.shape)
662
+ #
663
+ # attention_probs[per_image_size:, :, obj_a_ids] = (1 - strength) * attention_probs[per_image_size:, :,
664
+ # obj_a_ids] + strength * attention_weight_obj_a_normalized * word_sum.unsqueeze(-1)
665
+
666
+
667
+ ###### Normal one
668
+ obj_a_ids = obj_ids[0]
669
+ obj_b_ids = obj_ids[1]
670
+ # obj_a_ids = [2]
671
+ # obj_b_ids = [8]
672
+ strength = relationship['strength']
673
+ spatial_relationship = relationship['spatial_relationship']
674
+ # print("use_normal_one")
675
+ for obj_a_id in obj_a_ids:
676
+ n = np.zeros((H, W))
677
+ if spatial_relationship == 0:
678
+ n[H // 2, 1 * W // 4] = 1
679
+ elif spatial_relationship == 1:
680
+ n[H // 2, 3 * W // 4] = 1
681
+ elif spatial_relationship == 2:
682
+ n[1 * H // 4, W // 2] = 1
683
+ elif spatial_relationship == 3:
684
+ n[3 * H // 4, W // 2] = 1
685
+
686
+ attention_weight_obj_a = torch.from_numpy(
687
+ scipy.ndimage.gaussian_filter(n, sigma=H / 8)).cuda().reshape(-1)
688
+ attention_weight_obj_a_normalized = torch.tile(
689
+ (attention_weight_obj_a / attention_weight_obj_a.sum()).unsqueeze(0), (b // 2, 1))
690
+
691
+ word_sum = attention_probs[per_image_size:, :, obj_a_id].sum(dim=-1)
692
+ attention_probs[per_image_size:, :, obj_a_id] = (1-strength)*attention_probs[per_image_size:, :, obj_a_id] + strength * attention_weight_obj_a_normalized * word_sum.unsqueeze(-1)
693
+
694
+ for obj_id in obj_b_ids:
695
+ n = np.zeros((H, W))
696
+ if spatial_relationship == 0:
697
+ n[H // 2, 3 * W // 4] = 1
698
+ elif spatial_relationship == 1:
699
+ n[H // 2, 1 * W // 4] = 1
700
+ elif spatial_relationship == 2:
701
+ n[3 * H // 4, W // 2] = 1
702
+ elif spatial_relationship == 3:
703
+ n[1 * H // 4, W // 2] = 1
704
+
705
+ attention_weight_obj = torch.from_numpy(
706
+ scipy.ndimage.gaussian_filter(n, sigma=H / 8)).cuda().reshape(-1)
707
+ attention_weight_obj_normalized = torch.tile(
708
+ (attention_weight_obj / attention_weight_obj.sum()).unsqueeze(0), (b // 2, 1))
709
+
710
+ word_sum = attention_probs[per_image_size:, :, obj_id].sum(dim=-1)
711
+ attention_probs[per_image_size:, :, obj_id] = (1-strength) * attention_probs[per_image_size:, :,obj_id] + strength * attention_weight_obj_normalized * word_sum.unsqueeze(-1)
712
+
713
+
714
+
715
+
716
+ # n = np.zeros((H, W))
717
+
718
+
719
+ # if relationship =
720
+ # n[H//2, 1*W//4] = 1
721
+
722
+ # attention_weight_dog = torch.from_numpy(scipy.ndimage.gaussian_filter(n, sigma=H/8)).cuda().reshape(-1)
723
+ # attention_weight_dog_normalized = torch.tile((attention_weight_dog/attention_weight_dog.sum()).unsqueeze(0),(b//2, 1))
724
+ # word_sum = attention_probs[per_image_size:, :, 8].sum(dim=-1)
725
+ # attention_probs[per_image_size:, :, 8] = 0 * attention_probs[per_image_size:, :, 1] + 1 * attention_weight_dog_normalized * word_sum.unsqueeze(-1)
726
+
727
+
728
+ # attention_weight_motorbike = torch.from_numpy(scipy.ndimage.gaussian_filter(n, sigma=H/12)).cuda().reshape(-1)
729
+ # attention_weight_motorbike_normalized = torch.tile(attention_weight_motorbike/attention_weight_motorbike.sum().unsqueeze(0), (b//2, 1))
730
+ #
731
+ #
732
+ # # print('attention_probs', attention_probs[per_image_size:, :, 3].sum(dim=-1))
733
+ # print(attention_weight_motorbike_normalized.shape, attention_probs[per_image_size:, :, 3].sum(dim=-1))
734
+
735
+ # attention_probs[per_image_size:, :, 3] = 0.9 * attention_weight_motorbike_normalized * attention_probs[per_image_size:, :, 3].sum(dim=-1).unsqueeze(-1) + 0.1 * attention_probs[per_image_size:, :, 3]
736
+
737
+ # attention_weight_all = attention_weight_motorbike + attention_weight_cat
738
+ # attention_weight_all_normalized = attention_weight_all/attention_weight_all.sum()
739
+ # attention_probs[per_image_size:, :, 4:] = attention_weight_all_normalized.unsqueeze(-1)
740
+
741
+
742
+
743
+ # b, i, j = attention_probs.shape
744
+ #
745
+ # H = W = int(math.sqrt(i))
746
+ # attention_probs_reshape = attention_probs.permute(0, 2, 1).reshape(b, j, H, W)
747
+ # if attn_mask is None:
748
+ # attn_mask = torch.zeros(size=attention_probs_reshape.shape).cuda()
749
+ # attn_mask[:, :, H//2:, W//2:] = 1
750
+ # # print(attention_probs_reshape.is_cuda, attention_probs_reshape.get_device())
751
+ # # attn_mask.cuda()
752
+ # attention_probs_reshape = attention_probs_reshape * attn_mask
753
+ # else:
754
+ # attn_mask.cuda()
755
+ # attention_probs_reshape = attention_probs_reshape * attn_mask
756
+ # attention_probs_reshape = attention_probs_reshape.reshape(b, j, i)
757
+ # attention_probs_reshape = attention_probs_reshape/(attention_probs_reshape.sum(dim=-1).unsqueeze(-1))
758
+ # attention_probs[per_image_size:] = attention_probs_reshape.permute(0, 2, 1)[per_image_size:]
759
+ # if attn_shift:
760
+ # b, i, j = attention_probs.shape
761
+ # H = W = int(math.sqrt(i))
762
+ # attention_map_hw = attention_probs.permute(0, 2, 1).reshape(b, j, H, W)
763
+ # # print("attention_map_hw", attention_map_hw.shape)
764
+ # attention_map_hw_pad = F.pad(attention_map_hw, (W//2, W//2), "constant", 0)
765
+ # # print("attention_map_hw_pad", attention_map_hw_pad.shape)
766
+ # attention_map_hw_pad = torch.roll(attention_map_hw_pad, W//4, -1)
767
+ # attention_map_hw_pad_crop = attention_map_hw_pad[:, :, :, W//2:W//2 + W].reshape(b, j, i)
768
+ # attention_map_flatten_pad_crop_sum = attention_map_hw_pad_crop.sum(dim=-1)
769
+ # attention_map_hw_pad_crop = (attention_map_hw_pad_crop/attention_map_flatten_pad_crop_sum.unsqueeze(-1)).permute(0, 2, 1)
770
+ # # attention_map_hw_pad_crop = attention_map_hw_pad_crop.reshape(b, j, i).permute(0, 2, 1)
771
+ # # attention_map_hw_pad_crop_sum = attention_map_hw_pad_crop.sum(dim=-2)
772
+ # # print(attention_map_hw_pad_crop.min())
773
+ # # print("attention_map_hw_pad_crop", attention_map_hw_pad_crop.shape)
774
+ # # attention_probs[per_image_size:, :, (2, 6)] = attention_map_hw_pad_crop.softmax(dim=-1)[per_image_size:, :, (2, 6)]
775
+ # # attention_probs[per_image_size:, :, (2, 6)] = attention_map_hw_pad_crop[per_image_size:, :, (2, 6)]
776
+ # attention_probs[per_image_size:] = attention_map_hw_pad_crop[per_image_size:]
777
+
778
+
779
+ # if attn_blob:
780
+ # n = np.zeros((21, 21))
781
+ # n[10, 10] = 1
782
+ # k = scipy.ndimage.gaussian_filter(n, sigma=3)
783
+ # else:
784
+ # # print(attention_probs.shape)
785
+ # hidden_states = torch.matmul(attention_probs, value)
786
+ hidden_states = torch.matmul(attention_probs, value)
787
+
788
+ # reshape hidden_states
789
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
790
+ return hidden_states, attention_probs
791
+
792
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
793
+ batch_size_attention = query.shape[0]
794
+ hidden_states = torch.zeros(
795
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
796
+ )
797
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
798
+ for i in range(hidden_states.shape[0] // slice_size):
799
+ start_idx = i * slice_size
800
+ end_idx = (i + 1) * slice_size
801
+ if query.device.type == "mps":
802
+ # Better performance on mps (~20-25%)
803
+ attn_slice = (
804
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
805
+ * self.scale
806
+ )
807
+ else:
808
+ attn_slice = (
809
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
810
+ ) # TODO: use baddbmm for better performance
811
+ attn_slice = attn_slice.softmax(dim=-1)
812
+ if query.device.type == "mps":
813
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
814
+ else:
815
+ attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
816
+
817
+ hidden_states[start_idx:end_idx] = attn_slice
818
+
819
+ # reshape hidden_states
820
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
821
+ return hidden_states
822
+
823
+ def _memory_efficient_attention_xformers(self, query, key, value):
824
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
825
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
826
+ return hidden_states
827
+
828
+
829
+ class FeedForward(nn.Module):
830
+ r"""
831
+ A feed-forward layer.
832
+
833
+ Parameters:
834
+ dim (`int`): The number of channels in the input.
835
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
836
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
837
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
838
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
839
+ """
840
+
841
+ def __init__(
842
+ self,
843
+ dim: int,
844
+ dim_out: Optional[int] = None,
845
+ mult: int = 4,
846
+ dropout: float = 0.0,
847
+ activation_fn: str = "geglu",
848
+ ):
849
+ super().__init__()
850
+ inner_dim = int(dim * mult)
851
+ dim_out = dim_out if dim_out is not None else dim
852
+
853
+ if activation_fn == "geglu":
854
+ geglu = GEGLU(dim, inner_dim)
855
+ elif activation_fn == "geglu-approximate":
856
+ geglu = ApproximateGELU(dim, inner_dim)
857
+
858
+ self.net = nn.ModuleList([])
859
+ # project in
860
+ self.net.append(geglu)
861
+ # project dropout
862
+ self.net.append(nn.Dropout(dropout))
863
+ # project out
864
+ self.net.append(nn.Linear(inner_dim, dim_out))
865
+
866
+ def forward(self, hidden_states):
867
+ for module in self.net:
868
+ hidden_states = module(hidden_states)
869
+ return hidden_states
870
+
871
+
872
+ # feedforward
873
+ class GEGLU(nn.Module):
874
+ r"""
875
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
876
+
877
+ Parameters:
878
+ dim_in (`int`): The number of channels in the input.
879
+ dim_out (`int`): The number of channels in the output.
880
+ """
881
+
882
+ def __init__(self, dim_in: int, dim_out: int):
883
+ super().__init__()
884
+ self.proj = nn.Linear(dim_in, dim_out * 2)
885
+
886
+ def gelu(self, gate):
887
+ if gate.device.type != "mps":
888
+ return F.gelu(gate)
889
+ # mps: gelu is not implemented for float16
890
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
891
+
892
+ def forward(self, hidden_states):
893
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
894
+ return hidden_states * self.gelu(gate)
895
+
896
+
897
+ class ApproximateGELU(nn.Module):
898
+ """
899
+ The approximate form of Gaussian Error Linear Unit (GELU)
900
+
901
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
902
+ """
903
+
904
+ def __init__(self, dim_in: int, dim_out: int):
905
+ super().__init__()
906
+ self.proj = nn.Linear(dim_in, dim_out)
907
+
908
+ def forward(self, x):
909
+ x = self.proj(x)
910
+ return x * torch.sigmoid(1.702 * x)
911
+
912
+
913
+ class AdaLayerNorm(nn.Module):
914
+ """
915
+ Norm layer modified to incorporate timestep embeddings.
916
+ """
917
+
918
+ def __init__(self, embedding_dim, num_embeddings):
919
+ super().__init__()
920
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
921
+ self.silu = nn.SiLU()
922
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
923
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
924
+
925
+ def forward(self, x, timestep):
926
+ emb = self.linear(self.silu(self.emb(timestep)))
927
+ scale, shift = torch.chunk(emb, 2)
928
+ x = self.norm(x) * (1 + scale) + shift
929
+ return x
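The spatial control in `CrossAttention._attention` above boils down to blending each object token's attention column toward a normalized Gaussian bump centered at the target location, while keeping that token's total attention mass unchanged. A minimal sketch of that step for a single token column of shape (H*W,); the helper names and the example values are assumptions for illustration:

import numpy as np
import scipy.ndimage
import torch

def gaussian_prior(H: int, W: int, row: int, col: int) -> torch.Tensor:
    """Normalized Gaussian bump over an H x W grid, flattened to (H*W,)."""
    n = np.zeros((H, W))
    n[row, col] = 1.0
    g = scipy.ndimage.gaussian_filter(n, sigma=H / 8)   # same sigma as used in _attention
    g = torch.from_numpy(g).reshape(-1).float()
    return g / g.sum()

def redistribute(attn_col: torch.Tensor, prior: torch.Tensor, strength: float) -> torch.Tensor:
    """Blend one token's attention column (H*W,) toward the spatial prior.

    Mirrors (1 - strength) * attn + strength * prior * attn.sum(); since the prior
    sums to 1, the total attention mass assigned to this token is preserved.
    """
    return (1.0 - strength) * attn_col + strength * prior * attn_col.sum()

# Example: push a token toward the left-center quarter of a 16x16 latent grid.
H = W = 16
attn_col = torch.softmax(torch.randn(H * W), dim=0)
edited = redistribute(attn_col, gaussian_prior(H, W, H // 2, W // 4), strength=0.8)
assert torch.isclose(edited.sum(), attn_col.sum())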
my_model/unet_2d_blocks.py ADDED
@@ -0,0 +1,1612 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
+
21
+
22
+ def get_down_block(
23
+ down_block_type,
24
+ num_layers,
25
+ in_channels,
26
+ out_channels,
27
+ temb_channels,
28
+ add_downsample,
29
+ resnet_eps,
30
+ resnet_act_fn,
31
+ attn_num_head_channels,
32
+ resnet_groups=None,
33
+ cross_attention_dim=None,
34
+ downsample_padding=None,
35
+ ):
36
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
37
+ if down_block_type == "DownBlock2D":
38
+ return DownBlock2D(
39
+ num_layers=num_layers,
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ temb_channels=temb_channels,
43
+ add_downsample=add_downsample,
44
+ resnet_eps=resnet_eps,
45
+ resnet_act_fn=resnet_act_fn,
46
+ resnet_groups=resnet_groups,
47
+ downsample_padding=downsample_padding,
48
+ )
49
+ elif down_block_type == "AttnDownBlock2D":
50
+ return AttnDownBlock2D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ resnet_groups=resnet_groups,
59
+ downsample_padding=downsample_padding,
60
+ attn_num_head_channels=attn_num_head_channels,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock2D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
+ return CrossAttnDownBlock2D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ attn_num_head_channels=attn_num_head_channels,
77
+ )
78
+ elif down_block_type == "SkipDownBlock2D":
79
+ return SkipDownBlock2D(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ downsample_padding=downsample_padding,
88
+ )
89
+ elif down_block_type == "AttnSkipDownBlock2D":
90
+ return AttnSkipDownBlock2D(
91
+ num_layers=num_layers,
92
+ in_channels=in_channels,
93
+ out_channels=out_channels,
94
+ temb_channels=temb_channels,
95
+ add_downsample=add_downsample,
96
+ resnet_eps=resnet_eps,
97
+ resnet_act_fn=resnet_act_fn,
98
+ downsample_padding=downsample_padding,
99
+ attn_num_head_channels=attn_num_head_channels,
100
+ )
101
+ elif down_block_type == "DownEncoderBlock2D":
102
+ return DownEncoderBlock2D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ downsample_padding=downsample_padding,
111
+ )
112
+ elif down_block_type == "AttnDownEncoderBlock2D":
113
+ return AttnDownEncoderBlock2D(
114
+ num_layers=num_layers,
115
+ in_channels=in_channels,
116
+ out_channels=out_channels,
117
+ add_downsample=add_downsample,
118
+ resnet_eps=resnet_eps,
119
+ resnet_act_fn=resnet_act_fn,
120
+ resnet_groups=resnet_groups,
121
+ downsample_padding=downsample_padding,
122
+ attn_num_head_channels=attn_num_head_channels,
123
+ )
124
+ raise ValueError(f"{down_block_type} does not exist.")
125
+
126
+
127
+ def get_up_block(
128
+ up_block_type,
129
+ num_layers,
130
+ in_channels,
131
+ out_channels,
132
+ prev_output_channel,
133
+ temb_channels,
134
+ add_upsample,
135
+ resnet_eps,
136
+ resnet_act_fn,
137
+ attn_num_head_channels,
138
+ resnet_groups=None,
139
+ cross_attention_dim=None,
140
+ ):
141
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
142
+ if up_block_type == "UpBlock2D":
143
+ return UpBlock2D(
144
+ num_layers=num_layers,
145
+ in_channels=in_channels,
146
+ out_channels=out_channels,
147
+ prev_output_channel=prev_output_channel,
148
+ temb_channels=temb_channels,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ )
154
+ elif up_block_type == "CrossAttnUpBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
157
+ return CrossAttnUpBlock2D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ prev_output_channel=prev_output_channel,
162
+ temb_channels=temb_channels,
163
+ add_upsample=add_upsample,
164
+ resnet_eps=resnet_eps,
165
+ resnet_act_fn=resnet_act_fn,
166
+ resnet_groups=resnet_groups,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ )
170
+ elif up_block_type == "AttnUpBlock2D":
171
+ return AttnUpBlock2D(
172
+ num_layers=num_layers,
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ prev_output_channel=prev_output_channel,
176
+ temb_channels=temb_channels,
177
+ add_upsample=add_upsample,
178
+ resnet_eps=resnet_eps,
179
+ resnet_act_fn=resnet_act_fn,
180
+ resnet_groups=resnet_groups,
181
+ attn_num_head_channels=attn_num_head_channels,
182
+ )
183
+ elif up_block_type == "SkipUpBlock2D":
184
+ return SkipUpBlock2D(
185
+ num_layers=num_layers,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ )
194
+ elif up_block_type == "AttnSkipUpBlock2D":
195
+ return AttnSkipUpBlock2D(
196
+ num_layers=num_layers,
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ prev_output_channel=prev_output_channel,
200
+ temb_channels=temb_channels,
201
+ add_upsample=add_upsample,
202
+ resnet_eps=resnet_eps,
203
+ resnet_act_fn=resnet_act_fn,
204
+ attn_num_head_channels=attn_num_head_channels,
205
+ )
206
+ elif up_block_type == "UpDecoderBlock2D":
207
+ return UpDecoderBlock2D(
208
+ num_layers=num_layers,
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ add_upsample=add_upsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ resnet_groups=resnet_groups,
215
+ )
216
+ elif up_block_type == "AttnUpDecoderBlock2D":
217
+ return AttnUpDecoderBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ add_upsample=add_upsample,
222
+ resnet_eps=resnet_eps,
223
+ resnet_act_fn=resnet_act_fn,
224
+ resnet_groups=resnet_groups,
225
+ attn_num_head_channels=attn_num_head_channels,
226
+ )
227
+ raise ValueError(f"{up_block_type} does not exist.")
228
+
229
+
230
+ class UNetMidBlock2D(nn.Module):
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default",
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ resnet_pre_norm: bool = True,
242
+ attn_num_head_channels=1,
243
+ attention_type="default",
244
+ output_scale_factor=1.0,
245
+ **kwargs,
246
+ ):
247
+ super().__init__()
248
+
249
+ self.attention_type = attention_type
250
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
251
+
252
+ # there is always at least one resnet
253
+ resnets = [
254
+ ResnetBlock2D(
255
+ in_channels=in_channels,
256
+ out_channels=in_channels,
257
+ temb_channels=temb_channels,
258
+ eps=resnet_eps,
259
+ groups=resnet_groups,
260
+ dropout=dropout,
261
+ time_embedding_norm=resnet_time_scale_shift,
262
+ non_linearity=resnet_act_fn,
263
+ output_scale_factor=output_scale_factor,
264
+ pre_norm=resnet_pre_norm,
265
+ )
266
+ ]
267
+ attentions = []
268
+
269
+ for _ in range(num_layers):
270
+ attentions.append(
271
+ AttentionBlock(
272
+ in_channels,
273
+ num_head_channels=attn_num_head_channels,
274
+ rescale_output_factor=output_scale_factor,
275
+ eps=resnet_eps,
276
+ norm_num_groups=resnet_groups,
277
+ )
278
+ )
279
+ resnets.append(
280
+ ResnetBlock2D(
281
+ in_channels=in_channels,
282
+ out_channels=in_channels,
283
+ temb_channels=temb_channels,
284
+ eps=resnet_eps,
285
+ groups=resnet_groups,
286
+ dropout=dropout,
287
+ time_embedding_norm=resnet_time_scale_shift,
288
+ non_linearity=resnet_act_fn,
289
+ output_scale_factor=output_scale_factor,
290
+ pre_norm=resnet_pre_norm,
291
+ )
292
+ )
293
+
294
+ self.attentions = nn.ModuleList(attentions)
295
+ self.resnets = nn.ModuleList(resnets)
296
+
297
+ def forward(self, hidden_states, temb=None, encoder_states=None):
298
+ hidden_states = self.resnets[0](hidden_states, temb)
299
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
300
+ if self.attention_type == "default":
301
+ hidden_states = attn(hidden_states)
302
+ else:
303
+ hidden_states = attn(hidden_states, encoder_states)
304
+ hidden_states = resnet(hidden_states, temb)
305
+
306
+ return hidden_states
307
+
308
+
309
+ class UNetMidBlock2DCrossAttn(nn.Module):
310
+ def __init__(
311
+ self,
312
+ in_channels: int,
313
+ temb_channels: int,
314
+ dropout: float = 0.0,
315
+ num_layers: int = 1,
316
+ resnet_eps: float = 1e-6,
317
+ resnet_time_scale_shift: str = "default",
318
+ resnet_act_fn: str = "swish",
319
+ resnet_groups: int = 32,
320
+ resnet_pre_norm: bool = True,
321
+ attn_num_head_channels=1,
322
+ attention_type="default",
323
+ output_scale_factor=1.0,
324
+ cross_attention_dim=1280,
325
+ **kwargs,
326
+ ):
327
+ super().__init__()
328
+
329
+ self.attention_type = attention_type
330
+ self.attn_num_head_channels = attn_num_head_channels
331
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
332
+
333
+ # there is always at least one resnet
334
+ resnets = [
335
+ ResnetBlock2D(
336
+ in_channels=in_channels,
337
+ out_channels=in_channels,
338
+ temb_channels=temb_channels,
339
+ eps=resnet_eps,
340
+ groups=resnet_groups,
341
+ dropout=dropout,
342
+ time_embedding_norm=resnet_time_scale_shift,
343
+ non_linearity=resnet_act_fn,
344
+ output_scale_factor=output_scale_factor,
345
+ pre_norm=resnet_pre_norm,
346
+ )
347
+ ]
348
+ attentions = []
349
+
350
+ for _ in range(num_layers):
351
+ attentions.append(
352
+ Transformer2DModel(
353
+ attn_num_head_channels,
354
+ in_channels // attn_num_head_channels,
355
+ in_channels=in_channels,
356
+ num_layers=1,
357
+ cross_attention_dim=cross_attention_dim,
358
+ norm_num_groups=resnet_groups,
359
+ )
360
+ )
361
+ resnets.append(
362
+ ResnetBlock2D(
363
+ in_channels=in_channels,
364
+ out_channels=in_channels,
365
+ temb_channels=temb_channels,
366
+ eps=resnet_eps,
367
+ groups=resnet_groups,
368
+ dropout=dropout,
369
+ time_embedding_norm=resnet_time_scale_shift,
370
+ non_linearity=resnet_act_fn,
371
+ output_scale_factor=output_scale_factor,
372
+ pre_norm=resnet_pre_norm,
373
+ )
374
+ )
375
+
376
+ self.attentions = nn.ModuleList(attentions)
377
+ self.resnets = nn.ModuleList(resnets)
378
+
379
+ def set_attention_slice(self, slice_size):
380
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
381
+ raise ValueError(
382
+ f"Make sure slice_size {slice_size} is a divisor of "
383
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
384
+ )
385
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
386
+ raise ValueError(
387
+ f"Chunk_size {slice_size} has to be smaller or equal to "
388
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
389
+ )
390
+
391
+ for attn in self.attentions:
392
+ attn._set_attention_slice(slice_size)
393
+
394
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
395
+ for attn in self.attentions:
396
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
397
+
398
+ def forward(self, hidden_states, index, temb=None, encoder_hidden_states=None, attn_map=None, attn_shift=False, attn_map_step=20, obj_ids=None, relationship=None):
399
+ device = hidden_states.get_device() if hidden_states.is_cuda else 'cpu'
400
+ hidden_states = self.resnets[0](hidden_states, temb)
401
+ mid_attn = []
402
+ mid_value = []
403
+ for layer_idx, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
404
+ hidden_states, cross_attn_prob, save_value = attn(hidden_states, encoder_hidden_states, attn_map=attn_map[layer_idx].chunk(2)[1].to(device) if index < attn_map_step else None, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
405
+ hidden_states = hidden_states.sample
406
+ hidden_states = resnet(hidden_states, temb)
407
+ mid_attn.append(cross_attn_prob)
408
+ mid_value.append(save_value)
409
+ return hidden_states, mid_attn, mid_value
410
+
411
+
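Note on the block above: unlike the stock diffusers UNetMidBlock2DCrossAttn, this variant's forward() also takes the denoising step `index` plus an optional list of per-layer attention maps, and it returns the per-layer cross-attention probabilities and saved values alongside the hidden states. A minimal driver sketch, assuming the modified Transformer2DModel that this file imports accepts the extra keyword arguments used above; the shapes are illustrative only, not taken from the repository:

    import torch
    from my_model.unet_2d_blocks import UNetMidBlock2DCrossAttn

    mid = UNetMidBlock2DCrossAttn(in_channels=1280, temb_channels=1280,
                                  attn_num_head_channels=8, cross_attention_dim=768)
    hidden = torch.randn(2, 1280, 8, 8)   # bottleneck features
    temb = torch.randn(2, 1280)           # time embedding
    text = torch.randn(2, 77, 768)        # text-encoder hidden states
    # index >= attn_map_step disables the attention-map injection path,
    # so attn_map can stay None in this standalone call.
    hidden, mid_attn, mid_value = mid(hidden, index=50, temb=temb,
                                      encoder_hidden_states=text,
                                      attn_map=None, attn_map_step=20)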
412
+ class AttnDownBlock2D(nn.Module):
413
+ def __init__(
414
+ self,
415
+ in_channels: int,
416
+ out_channels: int,
417
+ temb_channels: int,
418
+ dropout: float = 0.0,
419
+ num_layers: int = 1,
420
+ resnet_eps: float = 1e-6,
421
+ resnet_time_scale_shift: str = "default",
422
+ resnet_act_fn: str = "swish",
423
+ resnet_groups: int = 32,
424
+ resnet_pre_norm: bool = True,
425
+ attn_num_head_channels=1,
426
+ attention_type="default",
427
+ output_scale_factor=1.0,
428
+ downsample_padding=1,
429
+ add_downsample=True,
430
+ ):
431
+ super().__init__()
432
+ resnets = []
433
+ attentions = []
434
+
435
+ self.attention_type = attention_type
436
+
437
+ for i in range(num_layers):
438
+ in_channels = in_channels if i == 0 else out_channels
439
+ resnets.append(
440
+ ResnetBlock2D(
441
+ in_channels=in_channels,
442
+ out_channels=out_channels,
443
+ temb_channels=temb_channels,
444
+ eps=resnet_eps,
445
+ groups=resnet_groups,
446
+ dropout=dropout,
447
+ time_embedding_norm=resnet_time_scale_shift,
448
+ non_linearity=resnet_act_fn,
449
+ output_scale_factor=output_scale_factor,
450
+ pre_norm=resnet_pre_norm,
451
+ )
452
+ )
453
+ attentions.append(
454
+ AttentionBlock(
455
+ out_channels,
456
+ num_head_channels=attn_num_head_channels,
457
+ rescale_output_factor=output_scale_factor,
458
+ eps=resnet_eps,
459
+ norm_num_groups=resnet_groups,
460
+ )
461
+ )
462
+
463
+ self.attentions = nn.ModuleList(attentions)
464
+ self.resnets = nn.ModuleList(resnets)
465
+
466
+ if add_downsample:
467
+ self.downsamplers = nn.ModuleList(
468
+ [
469
+ Downsample2D(
470
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
471
+ )
472
+ ]
473
+ )
474
+ else:
475
+ self.downsamplers = None
476
+
477
+ def forward(self, hidden_states, temb=None):
478
+ output_states = ()
479
+
480
+ for resnet, attn in zip(self.resnets, self.attentions):
481
+ hidden_states = resnet(hidden_states, temb)
482
+ hidden_states = attn(hidden_states)
483
+ output_states += (hidden_states,)
484
+
485
+ if self.downsamplers is not None:
486
+ for downsampler in self.downsamplers:
487
+ hidden_states = downsampler(hidden_states)
488
+
489
+ output_states += (hidden_states,)
490
+
491
+ return hidden_states, output_states
492
+
493
+
494
+ class CrossAttnDownBlock2D(nn.Module):
495
+ def __init__(
496
+ self,
497
+ in_channels: int,
498
+ out_channels: int,
499
+ temb_channels: int,
500
+ dropout: float = 0.0,
501
+ num_layers: int = 1,
502
+ resnet_eps: float = 1e-6,
503
+ resnet_time_scale_shift: str = "default",
504
+ resnet_act_fn: str = "swish",
505
+ resnet_groups: int = 32,
506
+ resnet_pre_norm: bool = True,
507
+ attn_num_head_channels=1,
508
+ cross_attention_dim=1280,
509
+ attention_type="default",
510
+ output_scale_factor=1.0,
511
+ downsample_padding=1,
512
+ add_downsample=True,
513
+ ):
514
+ super().__init__()
515
+ resnets = []
516
+ attentions = []
517
+
518
+ self.attention_type = attention_type
519
+ self.attn_num_head_channels = attn_num_head_channels
520
+
521
+ for i in range(num_layers):
522
+ in_channels = in_channels if i == 0 else out_channels
523
+ resnets.append(
524
+ ResnetBlock2D(
525
+ in_channels=in_channels,
526
+ out_channels=out_channels,
527
+ temb_channels=temb_channels,
528
+ eps=resnet_eps,
529
+ groups=resnet_groups,
530
+ dropout=dropout,
531
+ time_embedding_norm=resnet_time_scale_shift,
532
+ non_linearity=resnet_act_fn,
533
+ output_scale_factor=output_scale_factor,
534
+ pre_norm=resnet_pre_norm,
535
+ )
536
+ )
537
+ attentions.append(
538
+ Transformer2DModel(
539
+ attn_num_head_channels,
540
+ out_channels // attn_num_head_channels,
541
+ in_channels=out_channels,
542
+ num_layers=1,
543
+ cross_attention_dim=cross_attention_dim,
544
+ norm_num_groups=resnet_groups,
545
+ )
546
+ )
547
+ self.attentions = nn.ModuleList(attentions)
548
+ self.resnets = nn.ModuleList(resnets)
549
+
550
+ if add_downsample:
551
+ self.downsamplers = nn.ModuleList(
552
+ [
553
+ Downsample2D(
554
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
555
+ )
556
+ ]
557
+ )
558
+ else:
559
+ self.downsamplers = None
560
+
561
+ self.gradient_checkpointing = False
562
+
563
+ def set_attention_slice(self, slice_size):
564
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
565
+ raise ValueError(
566
+ f"Make sure slice_size {slice_size} is a divisor of "
567
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
568
+ )
569
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
570
+ raise ValueError(
571
+ f"Chunk_size {slice_size} has to be smaller or equal to "
572
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
573
+ )
574
+
575
+ for attn in self.attentions:
576
+ attn._set_attention_slice(slice_size)
577
+
578
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
579
+ for attn in self.attentions:
580
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
581
+
582
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attn_map=None, attn_shift=False, obj_ids=None, relationship=None):
583
+ output_states = ()
584
+ cross_attn_prob_list = []
585
+ for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
586
+ if self.training and self.gradient_checkpointing:
587
+
588
+ def create_custom_forward(module, return_dict=None):
589
+ def custom_forward(*inputs):
590
+ if return_dict is not None:
591
+ return module(*inputs, return_dict=return_dict)
592
+ else:
593
+ return module(*inputs)
594
+
595
+ return custom_forward
596
+
597
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
598
+ hidden_states = torch.utils.checkpoint.checkpoint(
599
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
600
+ )[0]
601
+ else:
602
+ hidden_states = resnet(hidden_states, temb)
603
+
604
+ tmp_hidden_states, cross_attn_prob, save_value = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_map=attn_map[layer_idx] if attn_map is not None else None, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
605
+ hidden_states = tmp_hidden_states.sample
606
+ # hidden_states, cross_attn_prob = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
607
+ # hidden_states = hidden_states.sample
608
+
609
+ output_states += (hidden_states,)
610
+ cross_attn_prob_list.append(cross_attn_prob)
611
+ if self.downsamplers is not None:
612
+ for downsampler in self.downsamplers:
613
+ hidden_states = downsampler(hidden_states)
614
+
615
+ output_states += (hidden_states,)
616
+
617
+ return hidden_states, output_states, cross_attn_prob_list, save_value
618
+
619
+
620
+ class DownBlock2D(nn.Module):
621
+ def __init__(
622
+ self,
623
+ in_channels: int,
624
+ out_channels: int,
625
+ temb_channels: int,
626
+ dropout: float = 0.0,
627
+ num_layers: int = 1,
628
+ resnet_eps: float = 1e-6,
629
+ resnet_time_scale_shift: str = "default",
630
+ resnet_act_fn: str = "swish",
631
+ resnet_groups: int = 32,
632
+ resnet_pre_norm: bool = True,
633
+ output_scale_factor=1.0,
634
+ add_downsample=True,
635
+ downsample_padding=1,
636
+ ):
637
+ super().__init__()
638
+ resnets = []
639
+
640
+ for i in range(num_layers):
641
+ in_channels = in_channels if i == 0 else out_channels
642
+ resnets.append(
643
+ ResnetBlock2D(
644
+ in_channels=in_channels,
645
+ out_channels=out_channels,
646
+ temb_channels=temb_channels,
647
+ eps=resnet_eps,
648
+ groups=resnet_groups,
649
+ dropout=dropout,
650
+ time_embedding_norm=resnet_time_scale_shift,
651
+ non_linearity=resnet_act_fn,
652
+ output_scale_factor=output_scale_factor,
653
+ pre_norm=resnet_pre_norm,
654
+ )
655
+ )
656
+
657
+ self.resnets = nn.ModuleList(resnets)
658
+
659
+ if add_downsample:
660
+ self.downsamplers = nn.ModuleList(
661
+ [
662
+ Downsample2D(
663
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
664
+ )
665
+ ]
666
+ )
667
+ else:
668
+ self.downsamplers = None
669
+
670
+ self.gradient_checkpointing = False
671
+
672
+ def forward(self, hidden_states, temb=None):
673
+ output_states = ()
674
+
675
+ for resnet in self.resnets:
676
+ if self.training and self.gradient_checkpointing:
677
+
678
+ def create_custom_forward(module):
679
+ def custom_forward(*inputs):
680
+ return module(*inputs)
681
+
682
+ return custom_forward
683
+
684
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
685
+ else:
686
+ hidden_states = resnet(hidden_states, temb)
687
+
688
+ output_states += (hidden_states,)
689
+
690
+ if self.downsamplers is not None:
691
+ for downsampler in self.downsamplers:
692
+ hidden_states = downsampler(hidden_states)
693
+
694
+ output_states += (hidden_states,)
695
+
696
+ return hidden_states, output_states
697
+
698
+
699
+ class DownEncoderBlock2D(nn.Module):
700
+ def __init__(
701
+ self,
702
+ in_channels: int,
703
+ out_channels: int,
704
+ dropout: float = 0.0,
705
+ num_layers: int = 1,
706
+ resnet_eps: float = 1e-6,
707
+ resnet_time_scale_shift: str = "default",
708
+ resnet_act_fn: str = "swish",
709
+ resnet_groups: int = 32,
710
+ resnet_pre_norm: bool = True,
711
+ output_scale_factor=1.0,
712
+ add_downsample=True,
713
+ downsample_padding=1,
714
+ ):
715
+ super().__init__()
716
+ resnets = []
717
+
718
+ for i in range(num_layers):
719
+ in_channels = in_channels if i == 0 else out_channels
720
+ resnets.append(
721
+ ResnetBlock2D(
722
+ in_channels=in_channels,
723
+ out_channels=out_channels,
724
+ temb_channels=None,
725
+ eps=resnet_eps,
726
+ groups=resnet_groups,
727
+ dropout=dropout,
728
+ time_embedding_norm=resnet_time_scale_shift,
729
+ non_linearity=resnet_act_fn,
730
+ output_scale_factor=output_scale_factor,
731
+ pre_norm=resnet_pre_norm,
732
+ )
733
+ )
734
+
735
+ self.resnets = nn.ModuleList(resnets)
736
+
737
+ if add_downsample:
738
+ self.downsamplers = nn.ModuleList(
739
+ [
740
+ Downsample2D(
741
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
742
+ )
743
+ ]
744
+ )
745
+ else:
746
+ self.downsamplers = None
747
+
748
+ def forward(self, hidden_states):
749
+ for resnet in self.resnets:
750
+ hidden_states = resnet(hidden_states, temb=None)
751
+
752
+ if self.downsamplers is not None:
753
+ for downsampler in self.downsamplers:
754
+ hidden_states = downsampler(hidden_states)
755
+
756
+ return hidden_states
757
+
758
+
759
+ class AttnDownEncoderBlock2D(nn.Module):
760
+ def __init__(
761
+ self,
762
+ in_channels: int,
763
+ out_channels: int,
764
+ dropout: float = 0.0,
765
+ num_layers: int = 1,
766
+ resnet_eps: float = 1e-6,
767
+ resnet_time_scale_shift: str = "default",
768
+ resnet_act_fn: str = "swish",
769
+ resnet_groups: int = 32,
770
+ resnet_pre_norm: bool = True,
771
+ attn_num_head_channels=1,
772
+ output_scale_factor=1.0,
773
+ add_downsample=True,
774
+ downsample_padding=1,
775
+ ):
776
+ super().__init__()
777
+ resnets = []
778
+ attentions = []
779
+
780
+ for i in range(num_layers):
781
+ in_channels = in_channels if i == 0 else out_channels
782
+ resnets.append(
783
+ ResnetBlock2D(
784
+ in_channels=in_channels,
785
+ out_channels=out_channels,
786
+ temb_channels=None,
787
+ eps=resnet_eps,
788
+ groups=resnet_groups,
789
+ dropout=dropout,
790
+ time_embedding_norm=resnet_time_scale_shift,
791
+ non_linearity=resnet_act_fn,
792
+ output_scale_factor=output_scale_factor,
793
+ pre_norm=resnet_pre_norm,
794
+ )
795
+ )
796
+ attentions.append(
797
+ AttentionBlock(
798
+ out_channels,
799
+ num_head_channels=attn_num_head_channels,
800
+ rescale_output_factor=output_scale_factor,
801
+ eps=resnet_eps,
802
+ norm_num_groups=resnet_groups,
803
+ )
804
+ )
805
+
806
+ self.attentions = nn.ModuleList(attentions)
807
+ self.resnets = nn.ModuleList(resnets)
808
+
809
+ if add_downsample:
810
+ self.downsamplers = nn.ModuleList(
811
+ [
812
+ Downsample2D(
813
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
814
+ )
815
+ ]
816
+ )
817
+ else:
818
+ self.downsamplers = None
819
+
820
+ def forward(self, hidden_states):
821
+ for resnet, attn in zip(self.resnets, self.attentions):
822
+ hidden_states = resnet(hidden_states, temb=None)
823
+ hidden_states = attn(hidden_states)
824
+
825
+ if self.downsamplers is not None:
826
+ for downsampler in self.downsamplers:
827
+ hidden_states = downsampler(hidden_states)
828
+
829
+ return hidden_states
830
+
831
+
832
+ class AttnSkipDownBlock2D(nn.Module):
833
+ def __init__(
834
+ self,
835
+ in_channels: int,
836
+ out_channels: int,
837
+ temb_channels: int,
838
+ dropout: float = 0.0,
839
+ num_layers: int = 1,
840
+ resnet_eps: float = 1e-6,
841
+ resnet_time_scale_shift: str = "default",
842
+ resnet_act_fn: str = "swish",
843
+ resnet_pre_norm: bool = True,
844
+ attn_num_head_channels=1,
845
+ attention_type="default",
846
+ output_scale_factor=np.sqrt(2.0),
847
+ downsample_padding=1,
848
+ add_downsample=True,
849
+ ):
850
+ super().__init__()
851
+ self.attentions = nn.ModuleList([])
852
+ self.resnets = nn.ModuleList([])
853
+
854
+ self.attention_type = attention_type
855
+
856
+ for i in range(num_layers):
857
+ in_channels = in_channels if i == 0 else out_channels
858
+ self.resnets.append(
859
+ ResnetBlock2D(
860
+ in_channels=in_channels,
861
+ out_channels=out_channels,
862
+ temb_channels=temb_channels,
863
+ eps=resnet_eps,
864
+ groups=min(in_channels // 4, 32),
865
+ groups_out=min(out_channels // 4, 32),
866
+ dropout=dropout,
867
+ time_embedding_norm=resnet_time_scale_shift,
868
+ non_linearity=resnet_act_fn,
869
+ output_scale_factor=output_scale_factor,
870
+ pre_norm=resnet_pre_norm,
871
+ )
872
+ )
873
+ self.attentions.append(
874
+ AttentionBlock(
875
+ out_channels,
876
+ num_head_channels=attn_num_head_channels,
877
+ rescale_output_factor=output_scale_factor,
878
+ eps=resnet_eps,
879
+ )
880
+ )
881
+
882
+ if add_downsample:
883
+ self.resnet_down = ResnetBlock2D(
884
+ in_channels=out_channels,
885
+ out_channels=out_channels,
886
+ temb_channels=temb_channels,
887
+ eps=resnet_eps,
888
+ groups=min(out_channels // 4, 32),
889
+ dropout=dropout,
890
+ time_embedding_norm=resnet_time_scale_shift,
891
+ non_linearity=resnet_act_fn,
892
+ output_scale_factor=output_scale_factor,
893
+ pre_norm=resnet_pre_norm,
894
+ use_in_shortcut=True,
895
+ down=True,
896
+ kernel="fir",
897
+ )
898
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
899
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
900
+ else:
901
+ self.resnet_down = None
902
+ self.downsamplers = None
903
+ self.skip_conv = None
904
+
905
+ def forward(self, hidden_states, temb=None, skip_sample=None):
906
+ output_states = ()
907
+
908
+ for resnet, attn in zip(self.resnets, self.attentions):
909
+ hidden_states = resnet(hidden_states, temb)
910
+ hidden_states = attn(hidden_states)
911
+ output_states += (hidden_states,)
912
+
913
+ if self.downsamplers is not None:
914
+ hidden_states = self.resnet_down(hidden_states, temb)
915
+ for downsampler in self.downsamplers:
916
+ skip_sample = downsampler(skip_sample)
917
+
918
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
919
+
920
+ output_states += (hidden_states,)
921
+
922
+ return hidden_states, output_states, skip_sample
923
+
924
+
925
+ class SkipDownBlock2D(nn.Module):
926
+ def __init__(
927
+ self,
928
+ in_channels: int,
929
+ out_channels: int,
930
+ temb_channels: int,
931
+ dropout: float = 0.0,
932
+ num_layers: int = 1,
933
+ resnet_eps: float = 1e-6,
934
+ resnet_time_scale_shift: str = "default",
935
+ resnet_act_fn: str = "swish",
936
+ resnet_pre_norm: bool = True,
937
+ output_scale_factor=np.sqrt(2.0),
938
+ add_downsample=True,
939
+ downsample_padding=1,
940
+ ):
941
+ super().__init__()
942
+ self.resnets = nn.ModuleList([])
943
+
944
+ for i in range(num_layers):
945
+ in_channels = in_channels if i == 0 else out_channels
946
+ self.resnets.append(
947
+ ResnetBlock2D(
948
+ in_channels=in_channels,
949
+ out_channels=out_channels,
950
+ temb_channels=temb_channels,
951
+ eps=resnet_eps,
952
+ groups=min(in_channels // 4, 32),
953
+ groups_out=min(out_channels // 4, 32),
954
+ dropout=dropout,
955
+ time_embedding_norm=resnet_time_scale_shift,
956
+ non_linearity=resnet_act_fn,
957
+ output_scale_factor=output_scale_factor,
958
+ pre_norm=resnet_pre_norm,
959
+ )
960
+ )
961
+
962
+ if add_downsample:
963
+ self.resnet_down = ResnetBlock2D(
964
+ in_channels=out_channels,
965
+ out_channels=out_channels,
966
+ temb_channels=temb_channels,
967
+ eps=resnet_eps,
968
+ groups=min(out_channels // 4, 32),
969
+ dropout=dropout,
970
+ time_embedding_norm=resnet_time_scale_shift,
971
+ non_linearity=resnet_act_fn,
972
+ output_scale_factor=output_scale_factor,
973
+ pre_norm=resnet_pre_norm,
974
+ use_in_shortcut=True,
975
+ down=True,
976
+ kernel="fir",
977
+ )
978
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
979
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
980
+ else:
981
+ self.resnet_down = None
982
+ self.downsamplers = None
983
+ self.skip_conv = None
984
+
985
+ def forward(self, hidden_states, temb=None, skip_sample=None):
986
+ output_states = ()
987
+
988
+ for resnet in self.resnets:
989
+ hidden_states = resnet(hidden_states, temb)
990
+ output_states += (hidden_states,)
991
+
992
+ if self.downsamplers is not None:
993
+ hidden_states = self.resnet_down(hidden_states, temb)
994
+ for downsampler in self.downsamplers:
995
+ skip_sample = downsampler(skip_sample)
996
+
997
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
998
+
999
+ output_states += (hidden_states,)
1000
+
1001
+ return hidden_states, output_states, skip_sample
1002
+
1003
+
1004
+ class AttnUpBlock2D(nn.Module):
1005
+ def __init__(
1006
+ self,
1007
+ in_channels: int,
1008
+ prev_output_channel: int,
1009
+ out_channels: int,
1010
+ temb_channels: int,
1011
+ dropout: float = 0.0,
1012
+ num_layers: int = 1,
1013
+ resnet_eps: float = 1e-6,
1014
+ resnet_time_scale_shift: str = "default",
1015
+ resnet_act_fn: str = "swish",
1016
+ resnet_groups: int = 32,
1017
+ resnet_pre_norm: bool = True,
1018
+ attention_type="default",
1019
+ attn_num_head_channels=1,
1020
+ output_scale_factor=1.0,
1021
+ add_upsample=True,
1022
+ ):
1023
+ super().__init__()
1024
+ resnets = []
1025
+ attentions = []
1026
+
1027
+ self.attention_type = attention_type
1028
+
1029
+ for i in range(num_layers):
1030
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1031
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1032
+
1033
+ resnets.append(
1034
+ ResnetBlock2D(
1035
+ in_channels=resnet_in_channels + res_skip_channels,
1036
+ out_channels=out_channels,
1037
+ temb_channels=temb_channels,
1038
+ eps=resnet_eps,
1039
+ groups=resnet_groups,
1040
+ dropout=dropout,
1041
+ time_embedding_norm=resnet_time_scale_shift,
1042
+ non_linearity=resnet_act_fn,
1043
+ output_scale_factor=output_scale_factor,
1044
+ pre_norm=resnet_pre_norm,
1045
+ )
1046
+ )
1047
+ attentions.append(
1048
+ AttentionBlock(
1049
+ out_channels,
1050
+ num_head_channels=attn_num_head_channels,
1051
+ rescale_output_factor=output_scale_factor,
1052
+ eps=resnet_eps,
1053
+ norm_num_groups=resnet_groups,
1054
+ )
1055
+ )
1056
+
1057
+ self.attentions = nn.ModuleList(attentions)
1058
+ self.resnets = nn.ModuleList(resnets)
1059
+
1060
+ if add_upsample:
1061
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1062
+ else:
1063
+ self.upsamplers = None
1064
+
1065
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1066
+ for resnet, attn in zip(self.resnets, self.attentions):
1067
+ # pop res hidden states
1068
+ res_hidden_states = res_hidden_states_tuple[-1]
1069
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1070
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1071
+
1072
+ hidden_states = resnet(hidden_states, temb)
1073
+ hidden_states = attn(hidden_states)
1074
+
1075
+ if self.upsamplers is not None:
1076
+ for upsampler in self.upsamplers:
1077
+ hidden_states = upsampler(hidden_states)
1078
+
1079
+ return hidden_states
1080
+
1081
+
1082
+ class CrossAttnUpBlock2D(nn.Module):
1083
+ def __init__(
1084
+ self,
1085
+ in_channels: int,
1086
+ out_channels: int,
1087
+ prev_output_channel: int,
1088
+ temb_channels: int,
1089
+ dropout: float = 0.0,
1090
+ num_layers: int = 1,
1091
+ resnet_eps: float = 1e-6,
1092
+ resnet_time_scale_shift: str = "default",
1093
+ resnet_act_fn: str = "swish",
1094
+ resnet_groups: int = 32,
1095
+ resnet_pre_norm: bool = True,
1096
+ attn_num_head_channels=1,
1097
+ cross_attention_dim=1280,
1098
+ attention_type="default",
1099
+ output_scale_factor=1.0,
1100
+ add_upsample=True,
1101
+ ):
1102
+ super().__init__()
1103
+ resnets = []
1104
+ attentions = []
1105
+
1106
+ self.attention_type = attention_type
1107
+ self.attn_num_head_channels = attn_num_head_channels
1108
+
1109
+ for i in range(num_layers):
1110
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1111
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1112
+
1113
+ resnets.append(
1114
+ ResnetBlock2D(
1115
+ in_channels=resnet_in_channels + res_skip_channels,
1116
+ out_channels=out_channels,
1117
+ temb_channels=temb_channels,
1118
+ eps=resnet_eps,
1119
+ groups=resnet_groups,
1120
+ dropout=dropout,
1121
+ time_embedding_norm=resnet_time_scale_shift,
1122
+ non_linearity=resnet_act_fn,
1123
+ output_scale_factor=output_scale_factor,
1124
+ pre_norm=resnet_pre_norm,
1125
+ )
1126
+ )
1127
+ attentions.append(
1128
+ Transformer2DModel(
1129
+ attn_num_head_channels,
1130
+ out_channels // attn_num_head_channels,
1131
+ in_channels=out_channels,
1132
+ num_layers=1,
1133
+ cross_attention_dim=cross_attention_dim,
1134
+ norm_num_groups=resnet_groups,
1135
+ )
1136
+ )
1137
+ self.attentions = nn.ModuleList(attentions)
1138
+ self.resnets = nn.ModuleList(resnets)
1139
+
1140
+ if add_upsample:
1141
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1142
+ else:
1143
+ self.upsamplers = None
1144
+
1145
+ self.gradient_checkpointing = False
1146
+
1147
+ def set_attention_slice(self, slice_size):
1148
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1149
+ raise ValueError(
1150
+ f"Make sure slice_size {slice_size} is a divisor of "
1151
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1152
+ )
1153
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1154
+ raise ValueError(
1155
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1156
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1157
+ )
1158
+
1159
+ for attn in self.attentions:
1160
+ attn._set_attention_slice(slice_size)
1161
+
1162
+ self.gradient_checkpointing = False
1163
+
1164
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1165
+ for attn in self.attentions:
1166
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
1167
+
1168
+ def forward(
1169
+ self,
1170
+ hidden_states,
1171
+ res_hidden_states_tuple,
1172
+ temb=None,
1173
+ encoder_hidden_states=None,
1174
+ upsample_size=None,
1175
+ attn_map=None,
1176
+ attn_shift=False,
1177
+ obj_ids=None,
1178
+ relationship=None
1179
+ ):
1180
+ cross_attn_prob_list = list()
1181
+ for layer_idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
1182
+ # pop res hidden states
1183
+ res_hidden_states = res_hidden_states_tuple[-1]
1184
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1185
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1186
+
1187
+ if self.training and self.gradient_checkpointing:
1188
+
1189
+ def create_custom_forward(module, return_dict=None):
1190
+ def custom_forward(*inputs):
1191
+ if return_dict is not None:
1192
+ return module(*inputs, return_dict=return_dict)
1193
+ else:
1194
+ return module(*inputs)
1195
+
1196
+ return custom_forward
1197
+
1198
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1199
+ hidden_states = torch.utils.checkpoint.checkpoint(
1200
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1201
+ )[0]
1202
+ else:
1203
+ hidden_states = resnet(hidden_states, temb)
1204
+ tmp_hidden_states, cross_attn_prob, save_value = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_map=attn_map[layer_idx] if attn_map is not None else None, attn_shift=attn_shift, obj_ids=obj_ids, relationship=relationship)
1205
+ hidden_states = tmp_hidden_states.sample
1206
+ cross_attn_prob_list.append(cross_attn_prob)
1207
+ if self.upsamplers is not None:
1208
+ for upsampler in self.upsamplers:
1209
+ hidden_states = upsampler(hidden_states, upsample_size)
1210
+
1211
+ return hidden_states, cross_attn_prob_list, save_value
1212
+
1213
+
1214
+ class UpBlock2D(nn.Module):
1215
+ def __init__(
1216
+ self,
1217
+ in_channels: int,
1218
+ prev_output_channel: int,
1219
+ out_channels: int,
1220
+ temb_channels: int,
1221
+ dropout: float = 0.0,
1222
+ num_layers: int = 1,
1223
+ resnet_eps: float = 1e-6,
1224
+ resnet_time_scale_shift: str = "default",
1225
+ resnet_act_fn: str = "swish",
1226
+ resnet_groups: int = 32,
1227
+ resnet_pre_norm: bool = True,
1228
+ output_scale_factor=1.0,
1229
+ add_upsample=True,
1230
+ ):
1231
+ super().__init__()
1232
+ resnets = []
1233
+
1234
+ for i in range(num_layers):
1235
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1236
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1237
+
1238
+ resnets.append(
1239
+ ResnetBlock2D(
1240
+ in_channels=resnet_in_channels + res_skip_channels,
1241
+ out_channels=out_channels,
1242
+ temb_channels=temb_channels,
1243
+ eps=resnet_eps,
1244
+ groups=resnet_groups,
1245
+ dropout=dropout,
1246
+ time_embedding_norm=resnet_time_scale_shift,
1247
+ non_linearity=resnet_act_fn,
1248
+ output_scale_factor=output_scale_factor,
1249
+ pre_norm=resnet_pre_norm,
1250
+ )
1251
+ )
1252
+
1253
+ self.resnets = nn.ModuleList(resnets)
1254
+
1255
+ if add_upsample:
1256
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1257
+ else:
1258
+ self.upsamplers = None
1259
+
1260
+ self.gradient_checkpointing = False
1261
+
1262
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1263
+ for resnet in self.resnets:
1264
+ # pop res hidden states
1265
+ res_hidden_states = res_hidden_states_tuple[-1]
1266
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1267
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1268
+
1269
+ if self.training and self.gradient_checkpointing:
1270
+
1271
+ def create_custom_forward(module):
1272
+ def custom_forward(*inputs):
1273
+ return module(*inputs)
1274
+
1275
+ return custom_forward
1276
+
1277
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1278
+ else:
1279
+ hidden_states = resnet(hidden_states, temb)
1280
+
1281
+ if self.upsamplers is not None:
1282
+ for upsampler in self.upsamplers:
1283
+ hidden_states = upsampler(hidden_states, upsample_size)
1284
+
1285
+ return hidden_states
1286
+
1287
+
1288
+ class UpDecoderBlock2D(nn.Module):
1289
+ def __init__(
1290
+ self,
1291
+ in_channels: int,
1292
+ out_channels: int,
1293
+ dropout: float = 0.0,
1294
+ num_layers: int = 1,
1295
+ resnet_eps: float = 1e-6,
1296
+ resnet_time_scale_shift: str = "default",
1297
+ resnet_act_fn: str = "swish",
1298
+ resnet_groups: int = 32,
1299
+ resnet_pre_norm: bool = True,
1300
+ output_scale_factor=1.0,
1301
+ add_upsample=True,
1302
+ ):
1303
+ super().__init__()
1304
+ resnets = []
1305
+
1306
+ for i in range(num_layers):
1307
+ input_channels = in_channels if i == 0 else out_channels
1308
+
1309
+ resnets.append(
1310
+ ResnetBlock2D(
1311
+ in_channels=input_channels,
1312
+ out_channels=out_channels,
1313
+ temb_channels=None,
1314
+ eps=resnet_eps,
1315
+ groups=resnet_groups,
1316
+ dropout=dropout,
1317
+ time_embedding_norm=resnet_time_scale_shift,
1318
+ non_linearity=resnet_act_fn,
1319
+ output_scale_factor=output_scale_factor,
1320
+ pre_norm=resnet_pre_norm,
1321
+ )
1322
+ )
1323
+
1324
+ self.resnets = nn.ModuleList(resnets)
1325
+
1326
+ if add_upsample:
1327
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1328
+ else:
1329
+ self.upsamplers = None
1330
+
1331
+ def forward(self, hidden_states):
1332
+ for resnet in self.resnets:
1333
+ hidden_states = resnet(hidden_states, temb=None)
1334
+
1335
+ if self.upsamplers is not None:
1336
+ for upsampler in self.upsamplers:
1337
+ hidden_states = upsampler(hidden_states)
1338
+
1339
+ return hidden_states
1340
+
1341
+
1342
+ class AttnUpDecoderBlock2D(nn.Module):
1343
+ def __init__(
1344
+ self,
1345
+ in_channels: int,
1346
+ out_channels: int,
1347
+ dropout: float = 0.0,
1348
+ num_layers: int = 1,
1349
+ resnet_eps: float = 1e-6,
1350
+ resnet_time_scale_shift: str = "default",
1351
+ resnet_act_fn: str = "swish",
1352
+ resnet_groups: int = 32,
1353
+ resnet_pre_norm: bool = True,
1354
+ attn_num_head_channels=1,
1355
+ output_scale_factor=1.0,
1356
+ add_upsample=True,
1357
+ ):
1358
+ super().__init__()
1359
+ resnets = []
1360
+ attentions = []
1361
+
1362
+ for i in range(num_layers):
1363
+ input_channels = in_channels if i == 0 else out_channels
1364
+
1365
+ resnets.append(
1366
+ ResnetBlock2D(
1367
+ in_channels=input_channels,
1368
+ out_channels=out_channels,
1369
+ temb_channels=None,
1370
+ eps=resnet_eps,
1371
+ groups=resnet_groups,
1372
+ dropout=dropout,
1373
+ time_embedding_norm=resnet_time_scale_shift,
1374
+ non_linearity=resnet_act_fn,
1375
+ output_scale_factor=output_scale_factor,
1376
+ pre_norm=resnet_pre_norm,
1377
+ )
1378
+ )
1379
+ attentions.append(
1380
+ AttentionBlock(
1381
+ out_channels,
1382
+ num_head_channels=attn_num_head_channels,
1383
+ rescale_output_factor=output_scale_factor,
1384
+ eps=resnet_eps,
1385
+ norm_num_groups=resnet_groups,
1386
+ )
1387
+ )
1388
+
1389
+ self.attentions = nn.ModuleList(attentions)
1390
+ self.resnets = nn.ModuleList(resnets)
1391
+
1392
+ if add_upsample:
1393
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1394
+ else:
1395
+ self.upsamplers = None
1396
+
1397
+ def forward(self, hidden_states):
1398
+ for resnet, attn in zip(self.resnets, self.attentions):
1399
+ hidden_states = resnet(hidden_states, temb=None)
1400
+ hidden_states = attn(hidden_states)
1401
+
1402
+ if self.upsamplers is not None:
1403
+ for upsampler in self.upsamplers:
1404
+ hidden_states = upsampler(hidden_states)
1405
+
1406
+ return hidden_states
1407
+
1408
+
1409
+ class AttnSkipUpBlock2D(nn.Module):
1410
+ def __init__(
1411
+ self,
1412
+ in_channels: int,
1413
+ prev_output_channel: int,
1414
+ out_channels: int,
1415
+ temb_channels: int,
1416
+ dropout: float = 0.0,
1417
+ num_layers: int = 1,
1418
+ resnet_eps: float = 1e-6,
1419
+ resnet_time_scale_shift: str = "default",
1420
+ resnet_act_fn: str = "swish",
1421
+ resnet_pre_norm: bool = True,
1422
+ attn_num_head_channels=1,
1423
+ attention_type="default",
1424
+ output_scale_factor=np.sqrt(2.0),
1425
+ upsample_padding=1,
1426
+ add_upsample=True,
1427
+ ):
1428
+ super().__init__()
1429
+ self.attentions = nn.ModuleList([])
1430
+ self.resnets = nn.ModuleList([])
1431
+
1432
+ self.attention_type = attention_type
1433
+
1434
+ for i in range(num_layers):
1435
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1436
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1437
+
1438
+ self.resnets.append(
1439
+ ResnetBlock2D(
1440
+ in_channels=resnet_in_channels + res_skip_channels,
1441
+ out_channels=out_channels,
1442
+ temb_channels=temb_channels,
1443
+ eps=resnet_eps,
1444
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1445
+ groups_out=min(out_channels // 4, 32),
1446
+ dropout=dropout,
1447
+ time_embedding_norm=resnet_time_scale_shift,
1448
+ non_linearity=resnet_act_fn,
1449
+ output_scale_factor=output_scale_factor,
1450
+ pre_norm=resnet_pre_norm,
1451
+ )
1452
+ )
1453
+
1454
+ self.attentions.append(
1455
+ AttentionBlock(
1456
+ out_channels,
1457
+ num_head_channels=attn_num_head_channels,
1458
+ rescale_output_factor=output_scale_factor,
1459
+ eps=resnet_eps,
1460
+ )
1461
+ )
1462
+
1463
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1464
+ if add_upsample:
1465
+ self.resnet_up = ResnetBlock2D(
1466
+ in_channels=out_channels,
1467
+ out_channels=out_channels,
1468
+ temb_channels=temb_channels,
1469
+ eps=resnet_eps,
1470
+ groups=min(out_channels // 4, 32),
1471
+ groups_out=min(out_channels // 4, 32),
1472
+ dropout=dropout,
1473
+ time_embedding_norm=resnet_time_scale_shift,
1474
+ non_linearity=resnet_act_fn,
1475
+ output_scale_factor=output_scale_factor,
1476
+ pre_norm=resnet_pre_norm,
1477
+ use_in_shortcut=True,
1478
+ up=True,
1479
+ kernel="fir",
1480
+ )
1481
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1482
+ self.skip_norm = torch.nn.GroupNorm(
1483
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1484
+ )
1485
+ self.act = nn.SiLU()
1486
+ else:
1487
+ self.resnet_up = None
1488
+ self.skip_conv = None
1489
+ self.skip_norm = None
1490
+ self.act = None
1491
+
1492
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1493
+ for resnet in self.resnets:
1494
+ # pop res hidden states
1495
+ res_hidden_states = res_hidden_states_tuple[-1]
1496
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1497
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1498
+
1499
+ hidden_states = resnet(hidden_states, temb)
1500
+
1501
+ hidden_states = self.attentions[0](hidden_states)
1502
+
1503
+ if skip_sample is not None:
1504
+ skip_sample = self.upsampler(skip_sample)
1505
+ else:
1506
+ skip_sample = 0
1507
+
1508
+ if self.resnet_up is not None:
1509
+ skip_sample_states = self.skip_norm(hidden_states)
1510
+ skip_sample_states = self.act(skip_sample_states)
1511
+ skip_sample_states = self.skip_conv(skip_sample_states)
1512
+
1513
+ skip_sample = skip_sample + skip_sample_states
1514
+
1515
+ hidden_states = self.resnet_up(hidden_states, temb)
1516
+
1517
+ return hidden_states, skip_sample
1518
+
1519
+
1520
+ class SkipUpBlock2D(nn.Module):
1521
+ def __init__(
1522
+ self,
1523
+ in_channels: int,
1524
+ prev_output_channel: int,
1525
+ out_channels: int,
1526
+ temb_channels: int,
1527
+ dropout: float = 0.0,
1528
+ num_layers: int = 1,
1529
+ resnet_eps: float = 1e-6,
1530
+ resnet_time_scale_shift: str = "default",
1531
+ resnet_act_fn: str = "swish",
1532
+ resnet_pre_norm: bool = True,
1533
+ output_scale_factor=np.sqrt(2.0),
1534
+ add_upsample=True,
1535
+ upsample_padding=1,
1536
+ ):
1537
+ super().__init__()
1538
+ self.resnets = nn.ModuleList([])
1539
+
1540
+ for i in range(num_layers):
1541
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1542
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1543
+
1544
+ self.resnets.append(
1545
+ ResnetBlock2D(
1546
+ in_channels=resnet_in_channels + res_skip_channels,
1547
+ out_channels=out_channels,
1548
+ temb_channels=temb_channels,
1549
+ eps=resnet_eps,
1550
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1551
+ groups_out=min(out_channels // 4, 32),
1552
+ dropout=dropout,
1553
+ time_embedding_norm=resnet_time_scale_shift,
1554
+ non_linearity=resnet_act_fn,
1555
+ output_scale_factor=output_scale_factor,
1556
+ pre_norm=resnet_pre_norm,
1557
+ )
1558
+ )
1559
+
1560
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1561
+ if add_upsample:
1562
+ self.resnet_up = ResnetBlock2D(
1563
+ in_channels=out_channels,
1564
+ out_channels=out_channels,
1565
+ temb_channels=temb_channels,
1566
+ eps=resnet_eps,
1567
+ groups=min(out_channels // 4, 32),
1568
+ groups_out=min(out_channels // 4, 32),
1569
+ dropout=dropout,
1570
+ time_embedding_norm=resnet_time_scale_shift,
1571
+ non_linearity=resnet_act_fn,
1572
+ output_scale_factor=output_scale_factor,
1573
+ pre_norm=resnet_pre_norm,
1574
+ use_in_shortcut=True,
1575
+ up=True,
1576
+ kernel="fir",
1577
+ )
1578
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1579
+ self.skip_norm = torch.nn.GroupNorm(
1580
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1581
+ )
1582
+ self.act = nn.SiLU()
1583
+ else:
1584
+ self.resnet_up = None
1585
+ self.skip_conv = None
1586
+ self.skip_norm = None
1587
+ self.act = None
1588
+
1589
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1590
+ for resnet in self.resnets:
1591
+ # pop res hidden states
1592
+ res_hidden_states = res_hidden_states_tuple[-1]
1593
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1594
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1595
+
1596
+ hidden_states = resnet(hidden_states, temb)
1597
+
1598
+ if skip_sample is not None:
1599
+ skip_sample = self.upsampler(skip_sample)
1600
+ else:
1601
+ skip_sample = 0
1602
+
1603
+ if self.resnet_up is not None:
1604
+ skip_sample_states = self.skip_norm(hidden_states)
1605
+ skip_sample_states = self.act(skip_sample_states)
1606
+ skip_sample_states = self.skip_conv(skip_sample_states)
1607
+
1608
+ skip_sample = skip_sample + skip_sample_states
1609
+
1610
+ hidden_states = self.resnet_up(hidden_states, temb)
1611
+
1612
+ return hidden_states, skip_sample
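Taken together, the cross-attention down/mid/up blocks in this file keep the diffusers layer layout but change the forward contract: each call also returns the per-layer cross-attention probabilities and saved values so that the UNet wrapper in the next file can collect them. A sketch for the down block, under the same assumptions as the mid-block example above (modified Transformer2DModel available, illustrative shapes):

    import torch
    from my_model.unet_2d_blocks import CrossAttnDownBlock2D

    down = CrossAttnDownBlock2D(in_channels=320, out_channels=320, temb_channels=1280,
                                num_layers=2, attn_num_head_channels=8,
                                cross_attention_dim=768)
    hidden = torch.randn(2, 320, 64, 64)
    temb = torch.randn(2, 1280)
    text = torch.randn(2, 77, 768)
    # attn_map=None skips the attention-injection path entirely.
    hidden, skips, cross_attn_probs, save_value = down(hidden, temb=temb,
                                                       encoder_hidden_states=text,
                                                       attn_map=None)
    # `skips` is the usual tuple of residual states handed to the matching up block.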
my_model/unet_2d_condition.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import pdb
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from .unet_2d_blocks import (
27
+ CrossAttnDownBlock2D,
28
+ CrossAttnUpBlock2D,
29
+ DownBlock2D,
30
+ UNetMidBlock2DCrossAttn,
31
+ UpBlock2D,
32
+ get_down_block,
33
+ get_up_block,
34
+ )
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ @dataclass
41
+ class UNet2DConditionOutput(BaseOutput):
42
+ """
43
+ Args:
44
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
45
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
46
+ """
47
+
48
+ sample: torch.FloatTensor
49
+
50
+
51
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
52
+ r"""
53
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
54
+ and returns sample shaped output.
55
+
56
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
57
+ implements for all the models (such as downloading or saving, etc.)
58
+
59
+ Parameters:
60
+ sample_size (`int`, *optional*): The size of the input sample.
61
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
62
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
63
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
64
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
65
+ Whether to flip the sin to cos in the time embedding.
66
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
67
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
+ The tuple of downsample blocks to use.
69
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
70
+ The tuple of upsample blocks to use.
71
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
+ The tuple of output channels for each block.
73
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
74
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
75
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
80
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ sample_size: Optional[int] = None,
89
+ in_channels: int = 4,
90
+ out_channels: int = 4,
91
+ center_input_sample: bool = False,
92
+ flip_sin_to_cos: bool = True,
93
+ freq_shift: int = 0,
94
+ down_block_types: Tuple[str] = (
95
+ "CrossAttnDownBlock2D",
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "DownBlock2D",
99
+ ),
100
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
101
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
102
+ layers_per_block: int = 2,
103
+ downsample_padding: int = 1,
104
+ mid_block_scale_factor: float = 1,
105
+ act_fn: str = "silu",
106
+ norm_num_groups: int = 32,
107
+ norm_eps: float = 1e-5,
108
+ cross_attention_dim: int = 1280,
109
+ attention_head_dim: int = 8,
110
+ ):
111
+ super().__init__()
112
+
113
+ self.sample_size = sample_size
114
+ time_embed_dim = block_out_channels[0] * 4
115
+
116
+ # input
117
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
118
+
119
+ # time
120
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
121
+ timestep_input_dim = block_out_channels[0]
122
+
123
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
124
+
125
+ self.down_blocks = nn.ModuleList([])
126
+ self.mid_block = None
127
+ self.up_blocks = nn.ModuleList([])
128
+
129
+ # down
130
+ output_channel = block_out_channels[0]
131
+ for i, down_block_type in enumerate(down_block_types):
132
+ input_channel = output_channel
133
+ output_channel = block_out_channels[i]
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=time_embed_dim,
142
+ add_downsample=not is_final_block,
143
+ resnet_eps=norm_eps,
144
+ resnet_act_fn=act_fn,
145
+ resnet_groups=norm_num_groups,
146
+ cross_attention_dim=cross_attention_dim,
147
+ attn_num_head_channels=attention_head_dim,
148
+ downsample_padding=downsample_padding,
149
+ )
150
+ self.down_blocks.append(down_block)
151
+
152
+ # mid
153
+ self.mid_block = UNetMidBlock2DCrossAttn(
154
+ in_channels=block_out_channels[-1],
155
+ temb_channels=time_embed_dim,
156
+ resnet_eps=norm_eps,
157
+ resnet_act_fn=act_fn,
158
+ output_scale_factor=mid_block_scale_factor,
159
+ resnet_time_scale_shift="default",
160
+ cross_attention_dim=cross_attention_dim,
161
+ attn_num_head_channels=attention_head_dim,
162
+ resnet_groups=norm_num_groups,
163
+ )
164
+
165
+ # count how many layers upsample the images
166
+ self.num_upsamplers = 0
167
+
168
+ # up
169
+ reversed_block_out_channels = list(reversed(block_out_channels))
170
+ output_channel = reversed_block_out_channels[0]
171
+ for i, up_block_type in enumerate(up_block_types):
172
+ is_final_block = i == len(block_out_channels) - 1
173
+
174
+ prev_output_channel = output_channel
175
+ output_channel = reversed_block_out_channels[i]
176
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
177
+
178
+ # add upsample block for all BUT final layer
179
+ if not is_final_block:
180
+ add_upsample = True
181
+ self.num_upsamplers += 1
182
+ else:
183
+ add_upsample = False
184
+
185
+ up_block = get_up_block(
186
+ up_block_type,
187
+ num_layers=layers_per_block + 1,
188
+ in_channels=input_channel,
189
+ out_channels=output_channel,
190
+ prev_output_channel=prev_output_channel,
191
+ temb_channels=time_embed_dim,
192
+ add_upsample=add_upsample,
193
+ resnet_eps=norm_eps,
194
+ resnet_act_fn=act_fn,
195
+ resnet_groups=norm_num_groups,
196
+ cross_attention_dim=cross_attention_dim,
197
+ attn_num_head_channels=attention_head_dim,
198
+ )
199
+ self.up_blocks.append(up_block)
200
+ prev_output_channel = output_channel
201
+
202
+ # out
203
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
204
+ self.conv_act = nn.SiLU()
205
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
206
+
207
+ def set_attention_slice(self, slice_size):
208
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
209
+ raise ValueError(
210
+ f"Make sure slice_size {slice_size} is a divisor of "
211
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
212
+ )
213
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
214
+ raise ValueError(
215
+ f"Chunk_size {slice_size} has to be smaller or equal to "
216
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
217
+ )
218
+
219
+ for block in self.down_blocks:
220
+ if hasattr(block, "attentions") and block.attentions is not None:
221
+ block.set_attention_slice(slice_size)
222
+
223
+ self.mid_block.set_attention_slice(slice_size)
224
+
225
+ for block in self.up_blocks:
226
+ if hasattr(block, "attentions") and block.attentions is not None:
227
+ block.set_attention_slice(slice_size)
228
+
229
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
230
+ for block in self.down_blocks:
231
+ if hasattr(block, "attentions") and block.attentions is not None:
232
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
233
+
234
+ self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
235
+
236
+ for block in self.up_blocks:
237
+ if hasattr(block, "attentions") and block.attentions is not None:
238
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
239
+
240
+ def _set_gradient_checkpointing(self, module, value=False):
241
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
242
+ module.gradient_checkpointing = value
243
+
244
+ def forward(
245
+ self,
246
+ sample: torch.FloatTensor,
247
+ timestep: Union[torch.Tensor, float, int],
248
+ index,
249
+ encoder_hidden_states: torch.Tensor,
250
+ attn_map: Union[torch.Tensor],
251
+ cfg,
252
+ return_dict: bool = True,
253
+ ) -> Union[UNet2DConditionOutput, Tuple]:
254
+ r"""
255
+ Args:
256
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
257
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
258
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, cross_attention_dim) encoder hidden states
259
+ return_dict (`bool`, *optional*, defaults to `True`):
260
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
261
+
262
+ Returns:
263
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
264
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
265
+ returning a tuple, the first element is the sample tensor.
266
+ """
267
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
268
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
269
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
270
+ # on the fly if necessary.
271
+ device = sample.get_device() if sample.is_cuda else 'cpu'
272
+ default_overall_up_factor = 2**self.num_upsamplers
273
+
274
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
275
+ forward_upsample_size = False
276
+ upsample_size = None
277
+
278
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
279
+ logger.info("Forward upsample size to force interpolation output size.")
280
+         forward_upsample_size = True
+
+         # 0. center input if necessary
+         if self.config.center_input_sample:
+             sample = 2 * sample - 1.0
+
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timesteps = timesteps.expand(sample.shape[0])
+
+         t_emb = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=self.dtype)
+         emb = self.time_embedding(t_emb)
+         # attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
+         # 2. pre-process
+         sample = self.conv_in(sample)
+
+         # 3. down
+         attn_down = []
+         value_down = []
+         down_block_res_samples = (sample,)
+         for block_idx, downsample_block in enumerate(self.down_blocks):
+             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+                 if block_idx < 5:
+                     sample, res_samples, cross_atten_prob, save_value = downsample_block(
+                         hidden_states=sample,
+                         temb=emb,
+                         encoder_hidden_states=encoder_hidden_states,
+                         attn_map=attn_map['attn_down'][index][block_idx] if index < cfg.training.down_attn_map else None,
+                         attn_shift=True if index < cfg.training.down_attn_shift else False,
+                         obj_ids=cfg.inference.obj_ids if 'obj_ids' in cfg.inference else None,
+                         relationship=cfg.inference.relationship if 'relationship' in cfg.inference else None
+                     )
+                 else:
+                     sample, res_samples, cross_atten_prob, save_value = downsample_block(
+                         hidden_states=sample,
+                         temb=emb,
+                         encoder_hidden_states=encoder_hidden_states,
+                         attn_map=None
+                     )
+                 attn_down.append(cross_atten_prob)
+                 value_down.append(save_value)
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+             down_block_res_samples += res_samples
+
+         # 4. mid
+         sample, attn_mid, value_mid = self.mid_block(
+             sample, index, emb, encoder_hidden_states=encoder_hidden_states,
+             attn_map=attn_map['attn_mid'][index] if index < cfg.training.mid_attn_map else None,
+             attn_shift=True if index < cfg.training.mid_attn_shift else False, attn_map_step=cfg.training.mid_attn_map,
+             obj_ids=cfg.inference.obj_ids if 'obj_ids' in cfg.inference else None,
+             relationship=cfg.inference.relationship if 'relationship' in cfg.inference else None
+         )
+
+         # 5. up
+         attn_up = []
+         value_up = []
+         for i, upsample_block in enumerate(self.up_blocks):
+             is_final_block = i == len(self.up_blocks) - 1
+
+             res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+             # if we have not reached the final block and need to forward the
+             # upsample size, we do it here
+             if not is_final_block and forward_upsample_size:
+                 upsample_size = down_block_res_samples[-1].shape[2:]
+
+             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+                 sample, cross_atten_prob, save_value = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     encoder_hidden_states=encoder_hidden_states,
+                     upsample_size=upsample_size,
+                     attn_shift=True if index < cfg.training.up_attn_shift else False,
+                     attn_map=attn_map['attn_up'][index][i - 1] if index < cfg.training.up_attn_map else None,
+                     obj_ids=cfg.inference.obj_ids if 'obj_ids' in cfg.inference else None,
+                     relationship=cfg.inference.relationship if 'relationship' in cfg.inference else None
+                 )
+                 attn_up.append(cross_atten_prob)
+                 value_up.append(save_value)
+             else:
+                 sample = upsample_block(
+                     hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                 )
+         # 6. post-process
+         sample = self.conv_norm_out(sample)
+         sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+
+         if not return_dict:
+             return (sample,)
+
+         return UNet2DConditionOutput(sample=sample), attn_up, attn_mid, attn_down, value_up, value_mid, value_down
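For reference, every entry collected into attn_down, attn_mid, and attn_up above is a cross-attention probability map from the corresponding block, and the compute_ca_loss helper added in utils.py below consumes the mid and up maps. Judging from how that helper indexes the tensors, each map is expected to have shape (2 * batch, H * W, num_text_tokens), with the second half along the batch dimension being the conditional branch of classifier-free guidance. A minimal sketch of that contract, with all concrete sizes chosen purely for illustration:

import math
import torch

batch, num_tokens, res = 2, 77, 16          # illustrative sizes, not taken from the model config
attn_map_integrated = torch.rand(2 * batch, res * res, num_tokens)  # [uncond; cond] stacked on dim 0

attn_map = attn_map_integrated.chunk(2)[1]  # keep the conditional half, as compute_ca_loss does
b, i, j = attn_map.shape
H = W = int(math.sqrt(i))                   # spatial resolution recovered from the flattened axis
token_map = attn_map[:, :, 5].reshape(b, H, W)  # spatial attention map for a single text token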
utils.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import math
+ def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
+     loss = 0
+     object_number = len(bboxes)
+     if object_number == 0:
+         return torch.tensor(0).float().cuda()
+     for attn_map_integrated in attn_maps_mid:
+         attn_map = attn_map_integrated.chunk(2)[1]
+
+         #
+         b, i, j = attn_map.shape
+         H = W = int(math.sqrt(i))
+         for obj_idx in range(object_number):
+             obj_loss = 0
+             mask = torch.zeros(size=(H, W)).cuda()
+             for obj_box in bboxes[obj_idx]:
+
+                 x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
+                     int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+                 mask[y_min: y_max, x_min: x_max] = 1
+
+             for obj_position in object_positions[obj_idx]:
+                 ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
+
+                 activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
+
+                 obj_loss += torch.mean((1 - activation_value) ** 2)
+             loss += (obj_loss / len(object_positions[obj_idx]))
+
+         # compute loss on padding tokens
+         # activation_value = torch.zeros(size=(b, )).cuda()
+         # for obj_idx in range(object_number):
+         #     bbox = bboxes[obj_idx]
+         #     ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
+         #     activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
+         #                         int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
+         #
+         # loss += torch.mean((1 - activation_value) ** 2)
+
+
+     for attn_map_integrated in attn_maps_up[0]:
+         attn_map = attn_map_integrated.chunk(2)[1]
+         #
+         b, i, j = attn_map.shape
+         H = W = int(math.sqrt(i))
+
+         for obj_idx in range(object_number):
+             obj_loss = 0
+             mask = torch.zeros(size=(H, W)).cuda()
+             for obj_box in bboxes[obj_idx]:
+                 x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
+                     int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+                 mask[y_min: y_max, x_min: x_max] = 1
+
+             for obj_position in object_positions[obj_idx]:
+                 ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
+                 # ca_map_obj = attn_map[:, :, object_positions[obj_position]].reshape(b, H, W)
+
+                 activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(
+                     dim=-1)
+
+                 obj_loss += torch.mean((1 - activation_value) ** 2)
+             loss += (obj_loss / len(object_positions[obj_idx]))
+
+         # compute loss on padding tokens
+         # activation_value = torch.zeros(size=(b, )).cuda()
+         # for obj_idx in range(object_number):
+         #     bbox = bboxes[obj_idx]
+         #     ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
+         #     activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
+         #                         int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
+         #
+         # loss += torch.mean((1 - activation_value) ** 2)
+     loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
+     return loss
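To make the expected inputs concrete, here is a minimal, hypothetical invocation of compute_ca_loss with random attention maps. The sizes, box coordinates, and token index are illustrative only, and a CUDA device is required because the function allocates its masks with .cuda():

import torch

from utils import compute_ca_loss  # the helper added above

batch, num_tokens, res = 1, 77, 16   # assumed sizes: CLIP text length 77, 16x16 attention maps

# Doubled batch: first half unconditional, second half conditional (classifier-free guidance layout).
mid_maps = [torch.rand(2 * batch, res * res, num_tokens, device="cuda", requires_grad=True)]
up_maps = [[torch.rand(2 * batch, res * res, num_tokens, device="cuda", requires_grad=True)]]

bboxes = [[[0.1, 0.2, 0.5, 0.8]]]    # one object, one normalized (x_min, y_min, x_max, y_max) box
object_positions = [[2]]             # index of that object's token in the prompt

loss = compute_ca_loss(mid_maps, up_maps, bboxes, object_positions)
loss.backward()                      # gradients flow back into the attention maps
print(loss.item())

In the real pipeline these maps come from the modified UNet forward pass shown earlier, and the gradient of this loss is typically taken with respect to the latents so that attention is steered toward the requested boxes.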