byeongjun-park committed
Commit • f679b0c
Parent(s): 0c0d385

HarmonyView update
Files changed:
- app.py (+2 -2)
- ldm/models/diffusion/sync_dreamer.py (+25 -47)
app.py CHANGED

@@ -225,8 +225,8 @@ def run_demo():
             input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
             elevation.render()
             with gr.Accordion('Advanced options', open=False):
-                cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
-                cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance', interactive=True)
+                cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance 1', interactive=True)
+                cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance 2', interactive=True)
                 sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
                 sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
                 batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
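Only the slider labels change above, so the two guidance weights can be told apart in the UI. For context, here is a minimal, self-contained sketch (illustration only, not code from this commit; the preview_scales handler and the Show button are invented names) of how two such sliders are passed to a callback as separate values:

    import gradio as gr

    def preview_scales(cfg_scale_1, cfg_scale_2):
        # In the real demo these two values are forwarded to the sampler,
        # which applies one weight per decomposed unconditional branch.
        return f"guidance 1 = {cfg_scale_1}, guidance 2 = {cfg_scale_2}"

    with gr.Blocks() as demo:
        with gr.Accordion('Advanced options', open=False):
            cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance 1', interactive=True)
            cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance 2', interactive=True)
        out = gr.Textbox(label='Selected scales')
        btn = gr.Button('Show')
        btn.click(fn=preview_scales, inputs=[cfg_scale_1, cfg_scale_2], outputs=out)

    demo.launch()

In the actual demo the two values feed the sampler's decomposed classifier-free guidance (see the sync_dreamer.py diff below).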
ldm/models/diffusion/sync_dreamer.py CHANGED

@@ -100,26 +100,6 @@ class UNetWrapper(nn.Module):
         pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)
         return pred
 
-    def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale):
-        x_ = torch.cat([x] * 2, 0)
-        t_ = torch.cat([t] * 2, 0)
-        clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0)
-
-        v_ = {}
-        for k, v in volume_feats.items():
-            v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
-
-        x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
-
-        if self.use_zero_123:
-            # zero123 does not multiply this when encoding, maybe a bug for zero123
-            first_stage_scale_factor = 0.18215
-            x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
-        x_ = torch.cat([x_, x_concat_], 1)
-        s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2)
-        s = s_uc + unconditional_scale * (s - s_uc)
-        return s
-
     def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
         x_ = torch.cat([x] * 3, 0)
         t_ = torch.cat([t] * 3, 0)
@@ -139,6 +119,7 @@ class UNetWrapper(nn.Module):
         s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
         return s
 
+
 class SpatialVolumeNet(nn.Module):
     def __init__(self, time_dim, view_dim, view_num,
                  input_image_size=256, frustum_volume_depth=48,
@@ -175,12 +156,13 @@ class SpatialVolumeNet(nn.Module):
         device = x.device
 
         spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
-        spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts
+        spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
         spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
         spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
 
         # encode source features
         t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
+        # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
         v_embed_ = v_embed
         target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
         target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
@@ -245,8 +227,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
                  view_num=16, image_size=256,
                  cfg_scale=3.0, output_num=8, batch_view_num=4,
                  drop_conditions=False, drop_scheme='default',
-                 clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt",
-                 sample_type='ddim', sample_steps=200):
+                 clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
         super().__init__()
 
         self.finetune_unet = finetune_unet
@@ -274,10 +255,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         self.scheduler_config = scheduler_config
 
         latent_size = image_size//8
-        if sample_type=='ddim':
-            self.sampler = SyncDDIMSampler(self, sample_steps , "uniform", 1.0, latent_size=latent_size)
-        else:
-            raise NotImplementedError
+        self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)
 
     def _init_clip_projection(self):
         self.cc_projection = nn.Linear(772, 768)
@@ -490,9 +468,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
         return x_noisy, noise
 
-    def sample(self, sampler, batch,
+    def sample(self, sampler, batch, cfg_scale_1, cfg_scale_2, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
         _, clip_embed, input_info = self.prepare(batch)
-        x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval)
+        x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
 
         N = x_sample.shape[1]
         x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
@@ -531,7 +509,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         step = self.global_step
         batch_ = {}
         for k, v in batch.items(): batch_[k] = v[:self.output_num]
-        x_sample = self.sample(self.
+        x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num)
         output_dir = Path(self.image_dir) / 'images' / 'val'
         output_dir.mkdir(exist_ok=True, parents=True)
         self.log_image(x_sample, batch, step, output_dir=output_dir)
@@ -610,7 +588,7 @@ class SyncDDIMSampler:
         return x_prev
 
     @torch.no_grad()
-    def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=False):
+    def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
         """
         @param x_target_noisy: B,N,4,H,W
         @param input_info:
@@ -618,6 +596,7 @@ class SyncDDIMSampler:
         @param time_steps: B,
         @param index: int
         @param unconditional_scale:
+        @param batch_view_num: int
         @param is_step0: bool
         @return:
         """
@@ -629,34 +608,33 @@ class SyncDDIMSampler:
         t_embed = self.model.embed_time(time_steps) # B,t_dim
         spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
 
-
-
-
-
-
+        e_t = []
+        target_indices = torch.arange(N) # N
+        for ni in range(0, N, batch_view_num):
+            x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num]
+            VN = x_target_noisy_.shape[1]
+            x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W)
 
-
-
-
-            else:
-                noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
-            else: ## DG
+            time_steps_ = repeat_to_batch(time_steps, B, VN)
+            target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
+            clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
             noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
+            e_t.append(noise.view(B,VN,4,H,W))
 
-
-            x_prev = self.denoise_apply_impl(x_target_noisy, index,
+        e_t = torch.cat(e_t, 1)
+        x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)
         return x_prev
 
     @torch.no_grad()
-    def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50):
+    def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1):
         """
         @param input_info: x, elevation
         @param clip_embed: B,M,768
         @param unconditional_scale:
         @param log_every_t:
+        @param batch_view_num:
         @return:
         """
-
         C, H, W = 4, self.latent_size, self.latent_size
         B = clip_embed.shape[0]
         N = self.model.view_num
@@ -672,7 +650,7 @@ class SyncDDIMSampler:
         for i, step in enumerate(iterator):
             index = total_steps - i - 1 # index in ddim state
             time_steps = torch.full((B,), step, device=device, dtype=torch.long)
-            x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=index==0)
+            x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0)
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(x_target_noisy)
 
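For reference, the predict_with_decomposed_unconditional_scales call that the new denoise_apply loop relies on combines one conditional prediction with two unconditional predictions, each weighted separately (line 119 above: s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)). A minimal, self-contained sketch of just that combination rule (illustration only; the tensor shapes and the combine_decomposed_cfg name are assumptions, not part of this commit):

    import torch

    def combine_decomposed_cfg(s: torch.Tensor, s_uc1: torch.Tensor, s_uc2: torch.Tensor,
                               scales: tuple) -> torch.Tensor:
        # Same combination rule as in the diff above:
        # s + w1 * (s - s_uc1) + w2 * (s - s_uc2)
        w1, w2 = scales
        return s + w1 * (s - s_uc1) + w2 * (s - s_uc2)

    if __name__ == "__main__":
        B, C, H, W = 1, 4, 32, 32
        # s: fully conditional prediction; s_uc1 / s_uc2: predictions with one
        # condition dropped each (dummy tensors here).
        s, s_uc1, s_uc2 = torch.randn(3, B, C, H, W)
        out = combine_decomposed_cfg(s, s_uc1, s_uc2, scales=(2.0, 1.0))
        print(out.shape)  # torch.Size([1, 4, 32, 32])

The two weights map to the cfg_scale_1 and cfg_scale_2 sliders exposed in app.py.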