Commit 9a4072e (1 parent: 19b2060), committed by Stable-X

Fix scheduler and preprocessor bug

app.py CHANGED
@@ -28,6 +28,7 @@ import imageio as imageio
 import numpy as np
 import spaces
 import torch as torch
+torch.backends.cuda.matmul.allow_tf32 = True
 from PIL import Image
 from gradio_imageslider import ImageSlider
 from tqdm import tqdm
@@ -55,7 +56,7 @@ default_image_processing_resolution = 768
 
 default_video_num_inference_steps = 10
 default_video_processing_resolution = 768
-default_video_out_max_frames = 450
+default_video_out_max_frames = 60
 
 def process_image_check(path_input):
     if path_input is None:
@@ -99,7 +100,6 @@ def process_image(
 
     path_output_dir = tempfile.mkdtemp()
     path_out_png = os.path.join(path_output_dir, f"{name_base}_normal_colored.png")
-    yield None
     input_image = Image.open(path_input)
     input_image = resize_image(input_image, default_image_processing_resolution)
 
@@ -132,7 +132,7 @@ def process_video(
     pipe,
     path_input,
     out_max_frames=default_video_out_max_frames,
-    target_fps=3,
+    target_fps=10,
     progress=gr.Progress(),
 ):
     if path_input is None:
@@ -146,6 +146,7 @@ def process_video(
     path_output_dir = tempfile.mkdtemp()
     path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.mp4")
 
+    init_latents = None
     reader, writer = None, None
     try:
         reader = imageio.get_reader(path_input)
@@ -174,8 +175,11 @@ def process_video(
             pipe_out = pipe(
                 frame_pil,
                 match_input_resolution=False,
+                latents=init_latents
             )
 
+            if init_latents is None:
+                init_latents = pipe_out.gaus_noise
             processed_frame = pipe.image_processor.visualize_normals(  # noqa
                 pipe_out.prediction
             )[0]
@@ -333,7 +337,7 @@ def run_demo_server(pipe):
            inputs=[video_input],
            outputs=[processed_frames, video_output_files],
            directory_name="examples_video",
-           cache_examples=True,
+           cache_examples=False,
        )
 
        with gr.Tab("Panorama"):
@@ -407,108 +411,22 @@ def run_demo_server(pipe):
        server_port=7860,
    )
 
-from einops import rearrange
-class DINOv2_Encoder:
-    IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
-    IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
-
-    def __init__(
-        self,
-        model_name = 'dinov2_vitl14',
-        freeze = True,
-        antialias=True,
-        device="cuda",
-        size = 448,
-    ):
-
-        super(DINOv2_Encoder).__init__()
-
-        self.model = torch.hub.load('facebookresearch/dinov2', model_name)
-        self.model.eval()
-        self.device = device
-        self.antialias = antialias
-        self.dtype = torch.float32
-
-        self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN)
-        self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD)
-        self.size = size
-        if freeze:
-            self.freeze()
-
-
-    def freeze(self):
-        for param in self.model.parameters():
-            param.requires_grad = False
-
-    @torch.no_grad()
-    def encoder(self, x):
-        '''
-        x: [b h w c], range from (-1, 1), rbg
-        '''
-
-        x = self.preprocess(x).to(self.device, self.dtype)
-
-        b, c, h, w = x.shape
-        patch_h, patch_w = h // 14, w // 14
-
-        embeddings = self.model.forward_features(x)['x_norm_patchtokens']
-        embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w)
-
-        return rearrange(embeddings, 'b h w c -> b c h w')
-
-    def preprocess(self, x):
-        ''' x
-        '''
-        # normalize to [0,1],
-        x = torch.nn.functional.interpolate(
-            x,
-            size=(self.size, self.size),
-            mode='bicubic',
-            align_corners=True,
-            antialias=self.antialias,
-        )
-
-        x = (x + 1.0) / 2.0
-        # renormalize according to dino
-        mean = self.mean.view(1, 3, 1, 1).to(x.device)
-        std = self.std.view(1, 3, 1, 1).to(x.device)
-        x = (x - mean) / std
-
-        return x
-
-    def to(self, device, dtype=None):
-        if dtype is not None:
-            self.dtype = dtype
-            self.model.to(device, dtype)
-            self.mean.to(device, dtype)
-            self.std.to(device, dtype)
-        else:
-            self.model.to(device)
-            self.mean.to(device)
-            self.std.to(device)
-        return self
-
-    def __call__(self, x, **kwargs):
-        return self.encoder(x, **kwargs)
-
 def main():
     os.system("pip freeze")
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     x_start_pipeline = YOSONormalsPipeline.from_pretrained(
-        'Stable-X/yoso-normal-v0-1', trust_remote_code=True,
-        t_start=300).to(device)
-    dinov2_prior = DINOv2_Encoder(size=672)
-    dinov2_prior.to(device)
-
-    pipe = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', t_start=300, trust_remote_code=True,
+        'weights/yoso-normal-v0-2', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16).to(device)
+    pipe = StableNormalPipeline.from_pretrained('weights/stable-normal-v0-1', trust_remote_code=True,
+                                                variant="fp16", torch_dtype=torch.float16,
                                                 scheduler=HEURI_DDIMScheduler(prediction_type='sample',
                                                                               beta_start=0.00085, beta_end=0.0120,
                                                                               beta_schedule = "scaled_linear"))
     pipe.x_start_pipeline = x_start_pipeline
-    pipe.prior = dinov2_prior
     pipe.to(device)
+    pipe.prior.to(device, torch.float16)
 
     try:
         import xformers
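
Note on the video path above: the demo now draws the Gaussian noise once, on the first frame, and feeds it back into every later call through the new latents argument, so the random initialization stays fixed across frames. A minimal sketch of that pattern, assuming a loaded pipe and an iterable of PIL frames (both names are illustrative):

def predict_video_normals(pipe, frames):
    # Reuse the noise sampled for the first frame on all subsequent frames.
    init_latents = None
    predictions = []
    for frame in frames:
        out = pipe(frame, match_input_resolution=False, latents=init_latents)
        if init_latents is None:
            init_latents = out.gaus_noise  # noise exposed by the updated pipeline output
        predictions.append(out.prediction)
    return predictions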
stablenormal/pipeline_stablenormal.py CHANGED
@@ -99,7 +99,90 @@ class StableNormalOutput(BaseOutput):
 
     prediction: Union[np.ndarray, torch.Tensor]
     latent: Union[None, torch.Tensor]
+    gaus_noise: Union[None, torch.Tensor]
 
+from einops import rearrange
+class DINOv2_Encoder(torch.nn.Module):
+    IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
+    IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
+
+    def __init__(
+        self,
+        model_name = 'dinov2_vitl14',
+        freeze = True,
+        antialias=True,
+        device="cuda",
+        size = 448,
+    ):
+        super(DINOv2_Encoder, self).__init__()
+
+        self.model = torch.hub.load('facebookresearch/dinov2', model_name)
+        self.model.eval().to(device)
+        self.device = device
+        self.antialias = antialias
+        self.dtype = torch.float32
+
+        self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN)
+        self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD)
+        self.size = size
+        if freeze:
+            self.freeze()
+
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def encoder(self, x):
+        '''
+        x: [b h w c], range from (-1, 1), rbg
+        '''
+
+        x = self.preprocess(x).to(self.device, self.dtype)
+
+        b, c, h, w = x.shape
+        patch_h, patch_w = h // 14, w // 14
+
+        embeddings = self.model.forward_features(x)['x_norm_patchtokens']
+        embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w)
+
+        return rearrange(embeddings, 'b h w c -> b c h w')
+
+    def preprocess(self, x):
+        ''' x
+        '''
+        # normalize to [0,1],
+        x = torch.nn.functional.interpolate(
+            x,
+            size=(self.size, self.size),
+            mode='bicubic',
+            align_corners=True,
+            antialias=self.antialias,
+        )
+
+        x = (x + 1.0) / 2.0
+        # renormalize according to dino
+        mean = self.mean.view(1, 3, 1, 1).to(x.device)
+        std = self.std.view(1, 3, 1, 1).to(x.device)
+        x = (x - mean) / std
+
+        return x
+
+    def to(self, device, dtype=None):
+        if dtype is not None:
+            self.dtype = dtype
+            self.model.to(device, dtype)
+            self.mean.to(device, dtype)
+            self.std.to(device, dtype)
+        else:
+            self.model.to(device)
+            self.mean.to(device)
+            self.std.to(device)
+        return self
+
+    def __call__(self, x, **kwargs):
+        return self.encoder(x, **kwargs)
 
 class StableNormalPipeline(StableDiffusionControlNetPipeline):
     """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
@@ -163,7 +246,6 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
         default_processing_resolution: Optional[int] = 768,
         prompt="The normal map",
         empty_text_embedding=None,
-        t_start: Optional[int] = 401,
     ):
         super().__init__(
             vae,
@@ -190,8 +272,7 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
         self.prompt = prompt
         self.prompt_embeds = None
         self.empty_text_embedding = empty_text_embedding
-        self.t_start= torch.tensor(t_start) # target_out latents
-
+        self.prior = DINOv2_Encoder(size=672)
 
     def check_inputs(
         self,
@@ -346,7 +427,6 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
         num_inference_steps: Optional[int] = None,
         ensemble_size: int = 1,
         processing_resolution: Optional[int] = None,
-        return_intermediate_result: bool = False,
         match_input_resolution: bool = True,
         resample_method_input: str = "bilinear",
         resample_method_output: str = "bilinear",
@@ -441,10 +521,14 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
             image, processing_resolution, resample_method_input, device, dtype
         ) # [N,3,PPH,PPW]
 
+        image_latent, gaus_noise = self.prepare_latents(
+            image, latents, generator, ensemble_size, batch_size
+        ) # [N,4,h,w], [N,4,h,w]
+
         # 0. X_start latent obtain
-        predictor = self.x_start_pipeline(image, skip_preprocess=True)
+        predictor = self.x_start_pipeline(image, latents=gaus_noise,
+                                          processing_resolution=processing_resolution, skip_preprocess=True)
         x_start_latent = predictor.latent
-        gauss_latent = predictor.gauss_latent
 
         # 1. Check inputs.
         num_images = self.check_inputs(
@@ -503,28 +587,14 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
         dino_features = self.dino_controlnet.dino_controlnet_cond_embedding(dino_features)
         dino_features = self.match_noisy(dino_features, x_start_latent)
 
-        # 6. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
-        # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
-        # Latents of each such predictions across all input images and all ensemble members are represented in the
-        # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
-        # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
-        # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
-        # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
-        # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
-        # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
-        # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
-        # Model invocation: self.vae.encoder.
-        image_latent, pred_latent = self.prepare_latents(
-            image, latents, generator, ensemble_size, batch_size
-        ) # [N*E,4,h,w], [N*E,4,h,w]
-
-
         del (
             image,
         )
 
         # 7. denoise sampling, using heuritic sampling proposed by Ye.
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        t_start = self.x_start_pipeline.t_start
+        self.scheduler.set_timesteps(num_inference_steps, t_start=t_start,device=device)
 
         cond_scale =controlnet_conditioning_scale
         pred_latent = x_start_latent
@@ -544,50 +614,58 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
 
         pred_latents = []
 
-        down_block_res_samples, mid_block_res_sample = self.controlnet(
-            image_latent.detach(),
-            self.t_start,
-            encoder_hidden_states=self.prompt_embeds,
-            conditioning_scale=cond_scale,
-            guess_mode=False,
-            return_dict=False,
-        )
         last_pred_latent = pred_latent
-        for i in range(4):
+        for (t, prev_t) in self.progress_bar(zip(self.scheduler.timesteps,self.scheduler.prev_timesteps), leave=False, desc="Diffusion steps..."):
+
             _dino_down_block_res_samples = [dino_down_block_res_sample for dino_down_block_res_sample in dino_down_block_res_samples] # copy, avoid repeat quiery
-
-            model_output = self.dino_unet_forward(
+
+            # controlnet
+            down_block_res_samples, mid_block_res_sample = self.controlnet(
+                image_latent.detach(),
+                t,
+                encoder_hidden_states=self.prompt_embeds,
+                conditioning_scale=cond_scale,
+                guess_mode=False,
+                return_dict=False,
+            )
+
+            # SG-DRN
+            noise = self.dino_unet_forward(
                 self.unet,
                 pred_latent,
-                self.t_start,
+                t,
                 encoder_hidden_states=self.prompt_embeds,
                 down_block_additional_residuals=down_block_res_samples,
                 mid_block_additional_residual=mid_block_res_sample,
                 dino_down_block_additional_residuals= _dino_down_block_res_samples,
                 return_dict=False,
             )[0] # [B,4,h,w]
-            pred_latents.append(model_output)
-            pred_latent = self.scheduler.add_noise(model_output, gauss_latent, self.t_start)
-            pred_latent = 0.4 * pred_latent + 0.6 * last_pred_latent
-            last_pred_latent = pred_latent
-        pred_latents = torch.cat(pred_latents, dim=0)
+
+            pred_latents.append(noise)
+            # ddim steps
+            out = self.scheduler.step(
+                noise, t, prev_t, pred_latent, gaus_noise = gaus_noise, generator=generator, cur_step=cur_step+1 # NOTE that cur_step dirs to next_step
+            )# [B,4,h,w]
+            pred_latent = out.prev_sample
+
+            cur_step += 1
+
         del (
             image_latent,
             dino_features,
         )
-
+        pred_latent = pred_latents[-1] # using x0
 
         # decoder
-        if return_intermediate_result:
-            prediction = []
-            for _pred_latent in pred_latents:
-                _prediction = self.decode_prediction(_pred_latent.unsqueeze(dim=0))
-                prediction.append(_prediction)
-            prediction = torch.cat(prediction, dim=0)
-        else:
-            prediction = self.decode_prediction(pred_latents[-1].unsqueeze(dim=0))
+        prediction = self.decode_prediction(pred_latent)
         prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
-
+        prediction = self.image_processor.resize_antialias(prediction, original_resolution, resample_method_output, is_aa=False) # [N,3,H,W]
+
+        if match_input_resolution:
+            prediction = self.image_processor.resize_antialias(
+                prediction, original_resolution, resample_method_output, is_aa=False
+            ) # [N,3,H,W]
+
         if match_input_resolution:
             prediction = self.image_processor.resize_antialias(
                 prediction, original_resolution, resample_method_output, is_aa=False
@@ -604,6 +682,7 @@ class StableNormalPipeline(StableDiffusionControlNetPipeline):
         return StableNormalOutput(
             prediction=prediction,
             latent=pred_latent,
+            gaus_noise=gaus_noise
         )
 
     # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
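
Note on the preprocessor change above: the DINOv2 prior now lives inside the pipeline (self.prior = DINOv2_Encoder(size=672)) instead of being built by the app. Its preprocess step maps pipeline inputs from (-1, 1) to ImageNet-normalized values before the ViT-L/14 forward pass; a small self-contained check of that normalization (the input values are chosen only for illustration):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # ImageNet mean
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)   # ImageNet std

x = torch.zeros(1, 3, 14, 14)    # a mid-gray input in the pipeline's (-1, 1) range
x = (x + 1.0) / 2.0              # -> 0.5 everywhere, i.e. the [0, 1] range
x = (x - mean) / std             # -> roughly (0.066, 0.196, 0.418) per channel
print(x[0, :, 0, 0])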
stablenormal/pipeline_yoso_normal.py CHANGED
@@ -93,7 +93,7 @@ class YosoNormalsOutput(BaseOutput):
 
     prediction: Union[np.ndarray, torch.Tensor]
     latent: Union[None, torch.Tensor]
-    gauss_latent: Union[None, torch.Tensor]
+    gaus_noise: Union[None, torch.Tensor]
 
 
 class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
@@ -502,10 +502,11 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
         # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
         # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
         # Model invocation: self.vae.encoder.
-        image_latent, gauss_latent = self.prepare_latents(
+        image_latent, pred_latent = self.prepare_latents(
             image, latents, generator, ensemble_size, batch_size
         ) # [N*E,4,h,w], [N*E,4,h,w]
 
+        gaus_noise = pred_latent.detach().clone()
         del image
 
 
@@ -523,7 +524,7 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
 
         # 7. YOSO sampling
         latent_x_t = self.unet(
-            gauss_latent,
+            pred_latent,
             self.t_start,
             encoder_hidden_states=self.prompt_embeds,
             down_block_additional_residuals=down_block_res_samples,
@@ -533,6 +534,7 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
 
 
         del (
+            pred_latent,
             image_latent,
         )
 
@@ -554,7 +556,7 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
         return YosoNormalsOutput(
             prediction=prediction,
             latent=latent_x_t,
-            gauss_latent=gauss_latent,
+            gaus_noise=gaus_noise,
         )
 
     # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
@@ -585,7 +587,15 @@ class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
         ) # [N,4,h,w]
         image_latent = image_latent * self.vae.config.scaling_factor
         image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
-        pred_latent = torch.randn_like(image_latent)
+
+        pred_latent = latents
+        if pred_latent is None:
+            pred_latent = randn_tensor(
+                image_latent.shape,
+                generator=generator,
+                device=image_latent.device,
+                dtype=image_latent.dtype,
+            ) # [N*E,4,h,w]
 
         return image_latent, pred_latent
 
@@ -714,4 +724,4 @@ def retrieve_timesteps(
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
     timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps
+    return timesteps, num_inference_steps
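
Note on prepare_latents above: it now honors an explicit latents argument instead of always drawing fresh noise, which is what lets StableNormalPipeline hand its own gaus_noise to the x_start pass. A standalone sketch of that branch, using the randn_tensor helper from diffusers as in the diff:

from diffusers.utils.torch_utils import randn_tensor

def pick_pred_latent(image_latent, latents=None, generator=None):
    # Reuse caller-provided latents when given; otherwise sample Gaussian noise
    # with the same shape, device, and dtype as the encoded image.
    pred_latent = latents
    if pred_latent is None:
        pred_latent = randn_tensor(
            image_latent.shape,
            generator=generator,
            device=image_latent.device,
            dtype=image_latent.dtype,
        )
    return pred_latent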
stablenormal/scheduler/heuristics_ddimsampler.py CHANGED
@@ -12,7 +12,7 @@ import pdb
 
 class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, t_start: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -56,8 +56,13 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
         )
 
         timesteps = torch.from_numpy(timesteps).to(device)
+
+
         naive_sampling_step = num_inference_steps //2
 
+        # TODO for debug
+        # naive_sampling_step = 0
+
         self.naive_sampling_step = naive_sampling_step
 
         timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] # refine on step 5 for 5 steps, then backward from step 6
@@ -79,8 +84,8 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
         use_clipped_model_output: bool = False,
         generator=None,
         cur_step=None,
-        gauss_latent=None,
         variance_noise: Optional[torch.Tensor] = None,
+        gaus_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
@@ -134,10 +139,12 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
         # - pred_prev_sample -> "x_t-1"
 
         # 1. get previous step value (=t-1)
+
         # trick from heuri_sampling
         if cur_step == self.naive_sampling_step and timestep == prev_timestep:
             timestep += self.gap
 
+
         prev_timestep = prev_timestep # NOTE naive sampling
 
         # 2. compute alphas, betas
@@ -172,6 +179,7 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
         variance = self._get_variance(timestep, prev_timestep)
         std_dev_t = eta * variance ** (0.5)
 
+
         if use_clipped_model_output:
             # the pred_epsilon is always re-derived from the clipped x_0 in Glide
             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
@@ -180,8 +188,6 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
 
         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        if gauss_latent == None:
-            gauss_latent = torch.randn_like(pred_original_sample)
         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
         if eta > 0:
@@ -200,11 +206,12 @@ class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
             prev_sample = prev_sample + variance
 
         if cur_step < self.naive_sampling_step:
-            prev_sample = self.add_noise(pred_original_sample, gauss_latent, timestep)
+            prev_sample = self.add_noise(pred_original_sample, torch.randn_like(pred_original_sample), timestep)
 
         if not return_dict:
             return (prev_sample,)
 
+
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
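
Note on the heuristic schedule above: the first half of the steps is pinned to a single timestep (repeated refinement of the x0 estimate) before the usual DDIM descent, which is what the timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] line implements. A self-contained illustration of that layout; the evenly spaced 10-step schedule over 1000 training steps is an assumption made for the example, and the real set_timesteps additionally folds in t_start:

import numpy as np

num_inference_steps = 10
num_train_timesteps = 1000

# evenly spaced DDIM-style timesteps, high to low (illustrative spacing only)
step_ratio = num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

naive_sampling_step = num_inference_steps // 2
timesteps[:naive_sampling_step] = timesteps[naive_sampling_step]  # refine at one t, then walk down

print(timesteps)  # -> [400 400 400 400 400 400 300 200 100   0]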