byeongjun-park commited on
Commit
f679b0c
1 Parent(s): 0c0d385

HarmonyView update

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. ldm/models/diffusion/sync_dreamer.py +25 -47
app.py CHANGED
@@ -225,8 +225,8 @@ def run_demo():
225
  input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
226
  elevation.render()
227
  with gr.Accordion('Advanced options', open=False):
228
- cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
229
- cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance', interactive=True)
230
  sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
231
  sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
232
  batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
 
225
  input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
226
  elevation.render()
227
  with gr.Accordion('Advanced options', open=False):
228
+ cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance 1', interactive=True)
229
+ cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance 2', interactive=True)
230
  sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
231
  sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
232
  batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
ldm/models/diffusion/sync_dreamer.py CHANGED
@@ -100,26 +100,6 @@ class UNetWrapper(nn.Module):
100
  pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)
101
  return pred
102
 
103
- def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale):
104
- x_ = torch.cat([x] * 2, 0)
105
- t_ = torch.cat([t] * 2, 0)
106
- clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0)
107
-
108
- v_ = {}
109
- for k, v in volume_feats.items():
110
- v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
111
-
112
- x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
113
-
114
- if self.use_zero_123:
115
- # zero123 does not multiply this when encoding, maybe a bug for zero123
116
- first_stage_scale_factor = 0.18215
117
- x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
118
- x_ = torch.cat([x_, x_concat_], 1)
119
- s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2)
120
- s = s_uc + unconditional_scale * (s - s_uc)
121
- return s
122
-
123
  def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
124
  x_ = torch.cat([x] * 3, 0)
125
  t_ = torch.cat([t] * 3, 0)
@@ -139,6 +119,7 @@ class UNetWrapper(nn.Module):
139
  s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
140
  return s
141
 
 
142
  class SpatialVolumeNet(nn.Module):
143
  def __init__(self, time_dim, view_dim, view_num,
144
  input_image_size=256, frustum_volume_depth=48,
@@ -175,12 +156,13 @@ class SpatialVolumeNet(nn.Module):
175
  device = x.device
176
 
177
  spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
178
- spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts, indexing='ij'), -1)
179
  spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
180
  spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
181
 
182
  # encode source features
183
  t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
 
184
  v_embed_ = v_embed
185
  target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
186
  target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
@@ -245,8 +227,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
245
  view_num=16, image_size=256,
246
  cfg_scale=3.0, output_num=8, batch_view_num=4,
247
  drop_conditions=False, drop_scheme='default',
248
- clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt",
249
- sample_type='ddim', sample_steps=200):
250
  super().__init__()
251
 
252
  self.finetune_unet = finetune_unet
@@ -274,10 +255,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
274
  self.scheduler_config = scheduler_config
275
 
276
  latent_size = image_size//8
277
- if sample_type=='ddim':
278
- self.sampler = SyncDDIMSampler(self, sample_steps , "uniform", 1.0, latent_size=latent_size)
279
- else:
280
- raise NotImplementedError
281
 
282
  def _init_clip_projection(self):
283
  self.cc_projection = nn.Linear(772, 768)
@@ -490,9 +468,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
490
  x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
491
  return x_noisy, noise
492
 
493
- def sample(self, sampler, batch, cfg_scale, return_inter_results=False, inter_interval=50, inter_view_interval=2):
494
  _, clip_embed, input_info = self.prepare(batch)
495
- x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval)
496
 
497
  N = x_sample.shape[1]
498
  x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
@@ -531,7 +509,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
531
  step = self.global_step
532
  batch_ = {}
533
  for k, v in batch.items(): batch_[k] = v[:self.output_num]
534
- x_sample = self.sample(self.sampler, batch_, self.cfg_scale)
535
  output_dir = Path(self.image_dir) / 'images' / 'val'
536
  output_dir.mkdir(exist_ok=True, parents=True)
537
  self.log_image(x_sample, batch, step, output_dir=output_dir)
@@ -610,7 +588,7 @@ class SyncDDIMSampler:
610
  return x_prev
611
 
612
  @torch.no_grad()
613
- def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=False):
614
  """
615
  @param x_target_noisy: B,N,4,H,W
616
  @param input_info:
@@ -618,6 +596,7 @@ class SyncDDIMSampler:
618
  @param time_steps: B,
619
  @param index: int
620
  @param unconditional_scale:
 
621
  @param is_step0: bool
622
  @return:
623
  """
@@ -629,34 +608,33 @@ class SyncDDIMSampler:
629
  t_embed = self.model.embed_time(time_steps) # B,t_dim
630
  spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
631
 
632
- target_indices_ = torch.arange(N).unsqueeze(0).repeat(B, 1)
633
- x_target_noisy_ = x_target_noisy.reshape(B*N,C,H,W)
634
-
635
- time_steps_ = repeat_to_batch(time_steps, B, N)
636
- clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
 
637
 
638
- if type(unconditional_scale) == float: ## CFG
639
- if unconditional_scale != 1.0:
640
- noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
641
- else:
642
- noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
643
- else: ## DG
644
  noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
 
645
 
646
- noise = noise.reshape(B, N, 4, H, W)
647
- x_prev = self.denoise_apply_impl(x_target_noisy, index, noise, is_step0)
648
  return x_prev
649
 
650
  @torch.no_grad()
651
- def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50):
652
  """
653
  @param input_info: x, elevation
654
  @param clip_embed: B,M,768
655
  @param unconditional_scale:
656
  @param log_every_t:
 
657
  @return:
658
  """
659
-
660
  C, H, W = 4, self.latent_size, self.latent_size
661
  B = clip_embed.shape[0]
662
  N = self.model.view_num
@@ -672,7 +650,7 @@ class SyncDDIMSampler:
672
  for i, step in enumerate(iterator):
673
  index = total_steps - i - 1 # index in ddim state
674
  time_steps = torch.full((B,), step, device=device, dtype=torch.long)
675
- x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=index==0)
676
  if index % log_every_t == 0 or index == total_steps - 1:
677
  intermediates['x_inter'].append(x_target_noisy)
678
 
 
100
  pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)
101
  return pred
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
104
  x_ = torch.cat([x] * 3, 0)
105
  t_ = torch.cat([t] * 3, 0)
 
119
  s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
120
  return s
121
 
122
+
123
  class SpatialVolumeNet(nn.Module):
124
  def __init__(self, time_dim, view_dim, view_num,
125
  input_image_size=256, frustum_volume_depth=48,
 
156
  device = x.device
157
 
158
  spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
159
+ spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
160
  spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
161
  spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
162
 
163
  # encode source features
164
  t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
165
+ # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
166
  v_embed_ = v_embed
167
  target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
168
  target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
 
227
  view_num=16, image_size=256,
228
  cfg_scale=3.0, output_num=8, batch_view_num=4,
229
  drop_conditions=False, drop_scheme='default',
230
+ clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
 
231
  super().__init__()
232
 
233
  self.finetune_unet = finetune_unet
 
255
  self.scheduler_config = scheduler_config
256
 
257
  latent_size = image_size//8
258
+ self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)
 
 
 
259
 
260
  def _init_clip_projection(self):
261
  self.cc_projection = nn.Linear(772, 768)
 
468
  x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
469
  return x_noisy, noise
470
 
471
+ def sample(self, sampler, batch, cfg_scale_1, cfg_scale_2, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
472
  _, clip_embed, input_info = self.prepare(batch)
473
+ x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
474
 
475
  N = x_sample.shape[1]
476
  x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
 
509
  step = self.global_step
510
  batch_ = {}
511
  for k, v in batch.items(): batch_[k] = v[:self.output_num]
512
+ x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num)
513
  output_dir = Path(self.image_dir) / 'images' / 'val'
514
  output_dir.mkdir(exist_ok=True, parents=True)
515
  self.log_image(x_sample, batch, step, output_dir=output_dir)
 
588
  return x_prev
589
 
590
  @torch.no_grad()
591
+ def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
592
  """
593
  @param x_target_noisy: B,N,4,H,W
594
  @param input_info:
 
596
  @param time_steps: B,
597
  @param index: int
598
  @param unconditional_scale:
599
+ @param batch_view_num: int
600
  @param is_step0: bool
601
  @return:
602
  """
 
608
  t_embed = self.model.embed_time(time_steps) # B,t_dim
609
  spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
610
 
611
+ e_t = []
612
+ target_indices = torch.arange(N) # N
613
+ for ni in range(0, N, batch_view_num):
614
+ x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num]
615
+ VN = x_target_noisy_.shape[1]
616
+ x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W)
617
 
618
+ time_steps_ = repeat_to_batch(time_steps, B, VN)
619
+ target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
620
+ clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
 
 
 
621
  noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
622
+ e_t.append(noise.view(B,VN,4,H,W))
623
 
624
+ e_t = torch.cat(e_t, 1)
625
+ x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)
626
  return x_prev
627
 
628
  @torch.no_grad()
629
+ def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1):
630
  """
631
  @param input_info: x, elevation
632
  @param clip_embed: B,M,768
633
  @param unconditional_scale:
634
  @param log_every_t:
635
+ @param batch_view_num:
636
  @return:
637
  """
 
638
  C, H, W = 4, self.latent_size, self.latent_size
639
  B = clip_embed.shape[0]
640
  N = self.model.view_num
 
650
  for i, step in enumerate(iterator):
651
  index = total_steps - i - 1 # index in ddim state
652
  time_steps = torch.full((B,), step, device=device, dtype=torch.long)
653
+ x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0)
654
  if index % log_every_t == 0 or index == total_steps - 1:
655
  intermediates['x_inter'].append(x_target_noisy)
656