MrAlex committed on
Commit
63652fd
1 Parent(s): 27f9e25

prepare_latents modification

Browse files
Files changed (1) hide show
  1. pipeline.py +59 -16
pipeline.py CHANGED
@@ -625,47 +625,90 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
625
 
626
  return timesteps, num_inference_steps - t_start
627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
629
  if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
630
  raise ValueError(
631
  f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
632
  )
633
-
634
- image = image.to(device=device, dtype=dtype)
635
-
 
 
 
 
 
 
 
 
 
636
  batch_size = batch_size * num_images_per_prompt
637
  if isinstance(generator, list) and len(generator) != batch_size:
638
  raise ValueError(
639
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
640
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
641
  )
642
-
643
  if isinstance(generator, list):
644
  init_latents = [
645
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
646
  ]
647
  init_latents = torch.cat(init_latents, dim=0)
648
  else:
649
  init_latents = self.vae.encode(image).latent_dist.sample(generator)
650
-
651
  init_latents = self.vae.config.scaling_factor * init_latents
652
-
653
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
654
- raise ValueError(
655
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
656
- )
657
- else:
658
- init_latents = torch.cat([init_latents], dim=0)
659
-
660
  shape = init_latents.shape
661
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
662
-
663
  # get latents
664
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
665
  latents = init_latents
666
-
667
  return latents
668
 
 
669
  def _default_height_width(self, height, width, image):
670
  if isinstance(image, list):
671
  image = image[0]
 
625
 
626
  return timesteps, num_inference_steps - t_start
627
 
628
+ # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
629
+ # if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
630
+ # raise ValueError(
631
+ # f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
632
+ # )
633
+
634
+ # image = image.to(device=device, dtype=dtype)
635
+
636
+ # batch_size = batch_size * num_images_per_prompt
637
+ # if isinstance(generator, list) and len(generator) != batch_size:
638
+ # raise ValueError(
639
+ # f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
640
+ # f" size of {batch_size}. Make sure the batch size matches the length of the generators."
641
+ # )
642
+
643
+ # if isinstance(generator, list):
644
+ # init_latents = [
645
+ # self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
646
+ # ]
647
+ # init_latents = torch.cat(init_latents, dim=0)
648
+ # else:
649
+ # init_latents = self.vae.encode(image).latent_dist.sample(generator)
650
+
651
+ # init_latents = self.vae.config.scaling_factor * init_latents
652
+
653
+ # if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
654
+ # raise ValueError(
655
+ # f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
656
+ # )
657
+ # else:
658
+ # init_latents = torch.cat([init_latents], dim=0)
659
+
660
+ # shape = init_latents.shape
661
+ # noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
662
+
663
+ # # get latents
664
+ # init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
665
+ # latents = init_latents
666
+
667
+ # return latents
668
+
669
  def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
670
  if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
671
  raise ValueError(
672
  f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
673
  )
674
+
675
+ if isinstance(image, list):
676
+ image_tensors = []
677
+ for img in image:
678
+ img_tensor = prepare_image(img)
679
+ img_tensor = img_tensor.to(device=device, dtype=dtype)
680
+ image_tensors.append(img_tensor)
681
+ image = torch.stack(image_tensors, dim=0)
682
+ else:
683
+ image = prepare_image(image)
684
+ image = image.to(device=device, dtype=dtype)
685
+
686
  batch_size = batch_size * num_images_per_prompt
687
  if isinstance(generator, list) and len(generator) != batch_size:
688
  raise ValueError(
689
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
690
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
691
  )
692
+
693
  if isinstance(generator, list):
694
  init_latents = [
695
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
696
  ]
697
  init_latents = torch.cat(init_latents, dim=0)
698
  else:
699
  init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
700
  init_latents = self.vae.config.scaling_factor * init_latents
701
+
 
 
 
 
 
 
 
702
  shape = init_latents.shape
703
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
704
+
705
  # get latents
706
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
707
  latents = init_latents
708
+
709
  return latents
710
 
711
+
712
  def _default_height_width(self, height, width, image):
713
  if isinstance(image, list):
714
  image = image[0]