wzhouxiff commited on
Commit
65aedb0
·
2 Parent(s): 6ed0ad9 73a3d97

Track all PNG and MP4 files in assets/examples with Git LFS

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -3
  2. app copy.py +740 -0
  3. assets/examples/00010/generated_video.mp4 +0 -0
  4. assets/examples/00010/mask.png +0 -0
  5. assets/examples/00010/raw_image.png +0 -0
  6. assets/examples/00010/seg_image.png +0 -0
  7. assets/examples/00011/generated_video.mp4 +0 -0
  8. assets/examples/00011/mask.png +0 -0
  9. assets/examples/00011/raw_image.png +0 -0
  10. assets/examples/00011/seg_image.png +0 -0
  11. assets/examples/00012/generated_video.mp4 +0 -0
  12. assets/examples/00012/mask.png +0 -0
  13. assets/examples/00012/raw_image.png +0 -0
  14. assets/examples/00012/seg_image.png +0 -0
  15. assets/examples/00013/generated_video.mp4 +0 -0
  16. assets/examples/00013/mask.png +0 -0
  17. assets/examples/00013/raw_image.png +0 -0
  18. assets/examples/00013/seg_image.png +0 -0
  19. assets/examples/00014/generated_video.mp4 +0 -0
  20. assets/examples/00014/mask.png +0 -0
  21. assets/examples/00014/raw_image.png +0 -0
  22. assets/examples/00014/seg_image.png +0 -0
  23. assets/examples/00015/generated_video.mp4 +0 -0
  24. assets/examples/00015/mask.png +0 -0
  25. assets/examples/00015/raw_image.png +0 -0
  26. assets/examples/00015/seg_image.png +0 -0
  27. assets/examples/00016/generated_video.mp4 +0 -0
  28. assets/examples/00016/mask.png +0 -0
  29. assets/examples/00016/raw_image.png +0 -0
  30. assets/examples/00016/seg_image.png +0 -0
  31. assets/examples/00017/generated_video.mp4 +0 -0
  32. assets/examples/00017/mask.png +0 -0
  33. assets/examples/00017/raw_image.png +0 -0
  34. assets/examples/00017/seg_image.png +0 -0
  35. assets/examples/00018/generated_video.mp4 +0 -0
  36. assets/examples/00018/mask.png +0 -0
  37. assets/examples/00018/raw_image.png +0 -0
  38. assets/examples/00018/seg_image.png +0 -0
  39. assets/examples/00019/generated_video.mp4 +0 -0
  40. assets/examples/00019/mask.png +0 -0
  41. assets/examples/00019/raw_image.png +0 -0
  42. assets/examples/00019/seg_image.png +0 -0
  43. assets/examples/00020/generated_video.mp4 +0 -0
  44. assets/examples/00020/mask.png +0 -0
  45. assets/examples/00020/raw_image.png +0 -0
  46. assets/examples/00020/seg_image.png +0 -0
  47. assets/examples/00021/generated_video.mp4 +0 -0
  48. assets/examples/00021/mask.png +0 -0
  49. assets/examples/00021/raw_image.png +0 -0
  50. assets/examples/00021/seg_image.png +0 -0
.gitattributes CHANGED
@@ -33,6 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *.png filter=lfs diff=lfs merge=lfs -text
37
- *.mp4 filter=lfs diff=lfs merge=lfs -text
38
- assets/examples/ filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/examples/*.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/examples/*.mp4 filter=lfs diff=lfs merge=lfs -text
 
app copy.py ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import spaces
3
+ except:
4
+ pass
5
+
6
+ import os
7
+ import gradio as gr
8
+
9
+ import torch
10
+ from gradio_image_prompter import ImagePrompter
11
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ import numpy as np
15
+ from copy import deepcopy
16
+ import cv2
17
+
18
+ import torch.nn.functional as F
19
+ import torchvision
20
+ from einops import rearrange
21
+ import tempfile
22
+
23
+ from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
24
+ from ZoeDepth.zoedepth.utils.misc import colorize
25
+
26
+ from cameractrl.inference import get_pipeline
27
+ from objctrl_2_5d.utils.examples import examples, sync_points
28
+
29
+ from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
30
+ from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
31
+
32
+
33
+ ### Title and Description ###
34
+ #### Description ####
35
+ title = r"""<h1 align="center">ObjCtrl-2.5D: Training-free Object Control with Camera Poses</h1>"""
36
+ # subtitle = r"""<h2 align="center">Deployed on SVD Generation</h2>"""
37
+ important_link = r"""
38
+ <div align='center'>
39
+ <a href='https://wzhouxiff.github.io/projects/MotionCtrl/assets/paper/MotionCtrl.pdf'>[Paper]</a>
40
+ &ensp; <a href='https://wzhouxiff.github.io/projects/MotionCtrl/'>[Project Page]</a>
41
+ &ensp; <a href='https://github.com/TencentARC/MotionCtrl'>[Code]</a>
42
+ </div>
43
+ """
44
+
45
+ authors = r"""
46
+ <div align='center'>
47
+ <a href='https://wzhouxiff.github.io/'>Zhouxia Wang</a>
48
+ &ensp; <a href='https://nirvanalan.github.io/'>Yushi Lan</a>
49
+ &ensp; <a href='https://shangchenzhou.com/'>Shanchen Zhou</a>
50
+ &ensp; <a href='https://www.mmlab-ntu.com/person/ccloy/index.html'>Chen Change Loy</a>
51
+ </div>
52
+ """
53
+
54
+ affiliation = r"""
55
+ <div align='center'>
56
+ <a href='https://www.mmlab-ntu.com/'>S-Lab, NTU Singapore</a>
57
+ </div>
58
+ """
59
+
60
+ description = r"""
61
+ <b>Official Gradio demo</b> for <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'><b>ObjCtrl-2.5D: Training-free Object Control with Camera Poses</b></a>.<br>
62
+ 🔥 ObjCtrl2.5D enables object motion control in a I2V generated video via transforming 2D trajectories to 3D using depth, subsequently converting them into camera poses,
63
+ thereby leveraging the exisitng camera motion control module for object motion control without requiring additional training.<br>
64
+ """
65
+
66
+ article = r"""
67
+ If ObjCtrl2.5D is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/MotionCtrl' target='_blank'>Github Repo</a>. Thanks!
68
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC%2FMotionCtrl
69
+ )](https://github.com/TencentARC/MotionCtrl)
70
+
71
+ ---
72
+
73
+ 📝 **Citation**
74
+ <br>
75
+ If our work is useful for your research, please consider citing:
76
+ ```bibtex
77
+ @inproceedings{wang2024motionctrl,
78
+ title={Motionctrl: A unified and flexible motion controller for video generation},
79
+ author={Wang, Zhouxia and Yuan, Ziyang and Wang, Xintao and Li, Yaowei and Chen, Tianshui and Xia, Menghan and Luo, Ping and Shan, Ying},
80
+ booktitle={ACM SIGGRAPH 2024 Conference Papers},
81
+ pages={1--11},
82
+ year={2024}
83
+ }
84
+ ```
85
+
86
+ 📧 **Contact**
87
+ <br>
88
+ If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
89
+
90
+ """
91
+
92
+ # -------------- initialization --------------
93
+
94
+ CAMERA_MODE = ["Traj2Cam", "Rotate", "Clockwise", "Translate"]
95
+
96
+ # select the device for computation
97
+ if torch.cuda.is_available():
98
+ device = torch.device("cuda")
99
+ elif torch.backends.mps.is_available():
100
+ device = torch.device("mps")
101
+ else:
102
+ device = torch.device("cpu")
103
+ print(f"using device: {device}")
104
+
105
+ # segmentation model
106
+ segmentor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny", cache_dir="ckpt", device=device)
107
+
108
+ # depth model
109
+ d_model_NK = torch.hub.load('./ZoeDepth', 'ZoeD_NK', source='local', pretrained=True).to(device)
110
+
111
+ # cameractrl model
112
+ config = "configs/svd_320_576_cameractrl.yaml"
113
+ model_id = "stabilityai/stable-video-diffusion-img2vid"
114
+ ckpt = "checkpoints/CameraCtrl_svd.ckpt"
115
+ if not os.path.exists(ckpt):
116
+ os.makedirs("checkpoints", exist_ok=True)
117
+ os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true")
118
+ os.system("mv CameraCtrl_svd.ckpt?download=true checkpoints/CameraCtrl_svd.ckpt")
119
+ model_config = OmegaConf.load(config)
120
+
121
+
122
+ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
123
+ model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
124
+ ckpt, True, device)
125
+
126
+ # segmentor = None
127
+ # d_model_NK = None
128
+ # pipeline = None
129
+
130
+ ### run the demo ##
131
+ # @spaces.GPU(duration=5)
132
+ def segment(canvas, image, logits):
133
+ if logits is not None:
134
+ logits *= 32.0
135
+ _, points = get_subject_points(canvas)
136
+ image = np.array(image)
137
+
138
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
139
+ segmentor.set_image(image)
140
+ input_points = []
141
+ input_boxes = []
142
+ for p in points:
143
+ [x1, y1, _, x2, y2, _] = p
144
+ if x2==0 and y2==0:
145
+ input_points.append([x1, y1])
146
+ else:
147
+ input_boxes.append([x1, y1, x2, y2])
148
+ if len(input_points) == 0:
149
+ input_points = None
150
+ input_labels = None
151
+ else:
152
+ input_points = np.array(input_points)
153
+ input_labels = np.ones(len(input_points))
154
+ if len(input_boxes) == 0:
155
+ input_boxes = None
156
+ else:
157
+ input_boxes = np.array(input_boxes)
158
+ masks, _, logits = segmentor.predict(
159
+ point_coords=input_points,
160
+ point_labels=input_labels,
161
+ box=input_boxes,
162
+ multimask_output=False,
163
+ return_logits=True,
164
+ mask_input=logits,
165
+ )
166
+ mask = masks > 0
167
+ masked_img = mask_image(image, mask[0], color=[252, 140, 90], alpha=0.9)
168
+ masked_img = Image.fromarray(masked_img)
169
+
170
+ return mask[0], masked_img, masked_img, logits / 32.0
171
+
172
+ # @spaces.GPU(duration=5)
173
+ def get_depth(image, points):
174
+
175
+ depth = d_model_NK.infer_pil(image)
176
+ colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
177
+
178
+ depth_img = deepcopy(colored_depth[:, :, :3])
179
+ if len(points) > 0:
180
+ for idx, point in enumerate(points):
181
+ if idx % 2 == 0:
182
+ cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
183
+ else:
184
+ cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
185
+ if idx > 0:
186
+ cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
187
+
188
+ return depth, depth_img, colored_depth[:, :, :3]
189
+
190
+
191
+ # @spaces.GPU(duration=80)
192
+ def run_objctrl_2_5d(condition_image,
193
+ mask,
194
+ depth,
195
+ RTs,
196
+ bg_mode,
197
+ shared_wapring_latents,
198
+ scale_wise_masks,
199
+ rescale,
200
+ seed,
201
+ ds, dt,
202
+ num_inference_steps=25):
203
+
204
+ DEBUG = False
205
+
206
+ if DEBUG:
207
+ cur_OUTPUT_PATH = 'outputs/tmp'
208
+ os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
209
+
210
+ # num_inference_steps=25
211
+ min_guidance_scale = 1.0
212
+ max_guidance_scale = 3.0
213
+
214
+ area_ratio = 0.3
215
+ depth_scale_ = 5.2
216
+ center_margin = 10
217
+
218
+ height, width = 320, 576
219
+ num_frames = 14
220
+
221
+ intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
222
+ intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
223
+ fx = intrinsics[0, 0] / width
224
+ fy = intrinsics[0, 1] / height
225
+ cx = intrinsics[0, 2] / width
226
+ cy = intrinsics[0, 3] / height
227
+
228
+ down_scale = 8
229
+ H, W = height // down_scale, width // down_scale
230
+ K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
231
+
232
+ seed = int(seed)
233
+
234
+ center_h_margin, center_w_margin = center_margin, center_margin
235
+ depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
236
+
237
+ if rescale > 0:
238
+ depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
239
+ else:
240
+ depth_rescale = 1.0
241
+
242
+ depth = depth * depth_rescale
243
+
244
+ depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
245
+ (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
246
+
247
+ ## latent
248
+ generator = torch.Generator()
249
+ generator.manual_seed(seed)
250
+
251
+ latents_org = pipeline.prepare_latents(
252
+ 1,
253
+ 14,
254
+ 8,
255
+ height,
256
+ width,
257
+ pipeline.dtype,
258
+ device,
259
+ generator,
260
+ None,
261
+ )
262
+ latents_org = latents_org / pipeline.scheduler.init_noise_sigma
263
+
264
+ cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
265
+ cur_plucker_embedding = cur_plucker_embedding.to(device)
266
+ cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
267
+ cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
268
+ cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
269
+ cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
270
+
271
+ # bg_mode = ["Fixed", "Reverse", "Free"]
272
+ if bg_mode == "Fixed":
273
+ fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 4, 3]
274
+ fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
275
+ fix_plucker_embedding = fix_plucker_embedding.to(device)
276
+ fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
277
+ fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
278
+ fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
279
+ fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
280
+
281
+ elif bg_mode == "Reverse":
282
+ bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
283
+ bg_plucker_embedding = bg_plucker_embedding.to(device)
284
+ bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
285
+ bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
286
+ bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
287
+ fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
288
+
289
+ else:
290
+ fix_pose_features = None
291
+
292
+ #### preparing mask
293
+
294
+ mask = Image.fromarray(mask)
295
+ mask = mask.resize((W, H))
296
+ mask = np.array(mask).astype(np.float32)
297
+ mask = np.expand_dims(mask, axis=-1)
298
+
299
+ # visulize mask
300
+ if DEBUG:
301
+ mask_sum_vis = mask[..., 0]
302
+ mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
303
+ mask_sum_vis = Image.fromarray(mask_sum_vis)
304
+
305
+ mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
306
+
307
+ try:
308
+ warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
309
+
310
+ warped_masks.insert(0, mask)
311
+
312
+ except:
313
+ # mask to bbox
314
+ print(f'!!! Mask is too small to warp; mask to bbox')
315
+ mask = mask[:, :, 0]
316
+ coords = cv2.findNonZero(mask)
317
+ x, y, w, h = cv2.boundingRect(coords)
318
+ # mask[y:y+h, x:x+w] = 1.0
319
+
320
+ center_x, center_y = x + w // 2, y + h // 2
321
+ center_z = depth_down[center_y, center_x]
322
+
323
+ # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
324
+ RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
325
+
326
+ # RTs: world to camera
327
+ P0 = np.array([center_x, center_y, 1])
328
+ Pc0 = np.linalg.inv(K) @ P0 * center_z
329
+ pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
330
+
331
+ P = [np.array([center_x, center_y])]
332
+ for i in range(1, num_frames):
333
+ Pci = RTs[i] @ pw
334
+ Pi = K @ Pci[:3] / Pci[2]
335
+ P.append(Pi[:2])
336
+
337
+ warped_masks = [mask]
338
+ for i in range(1, num_frames):
339
+ shift_x = int(round(P[i][0] - P[0][0]))
340
+ shift_y = int(round(P[i][1] - P[0][1]))
341
+
342
+ cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
343
+ warped_masks.append(cur_mask)
344
+
345
+
346
+ warped_masks = [v[..., None] for v in warped_masks]
347
+
348
+ warped_masks = np.stack(warped_masks, axis=0) # [f, h, w]
349
+ warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
350
+
351
+ mask_sum = np.sum(warped_masks, axis=0, keepdims=True) # [1, H, W, 3]
352
+ mask_sum[mask_sum > 1.0] = 1.0
353
+ mask_sum = mask_sum[0,:,:, 0]
354
+
355
+ if DEBUG:
356
+ ## visulize warp mask
357
+ warp_masks_vis = torch.tensor(warped_masks)
358
+ warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
359
+ torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
360
+
361
+ # visulize mask
362
+ mask_sum_vis = mask_sum
363
+ mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
364
+ mask_sum_vis = Image.fromarray(mask_sum_vis)
365
+
366
+ mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
367
+
368
+ if scale_wise_masks:
369
+ min_area = H * W * area_ratio # cal in downscale
370
+ non_zero_len = mask_sum.sum()
371
+
372
+ print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
373
+
374
+ if non_zero_len > min_area:
375
+ kernel_sizes = [1, 1, 1, 3]
376
+ elif non_zero_len > min_area * 0.5:
377
+ kernel_sizes = [3, 1, 1, 5]
378
+ else:
379
+ kernel_sizes = [5, 3, 3, 7]
380
+ else:
381
+ kernel_sizes = [1, 1, 1, 1]
382
+
383
+ mask = torch.from_numpy(mask_sum) # [h, w]
384
+ mask = mask[None, None, ...] # [1, 1, h, w]
385
+ mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
386
+ # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
387
+ mask = mask.to(pipeline.dtype).to(device)
388
+
389
+ ##### Mask End ######
390
+
391
+ ### Got blending pose features Start ###
392
+
393
+ pose_features = []
394
+ for i in range(0, len(cur_pose_features)):
395
+ kernel_size = kernel_sizes[i]
396
+ h, w = cur_pose_features[i].shape[-2:]
397
+
398
+ if fix_pose_features is None:
399
+ pose_features.append(torch.zeros_like(cur_pose_features[i]))
400
+ else:
401
+ pose_features.append(fix_pose_features[i])
402
+
403
+ cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
404
+ cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, H, W]
405
+ cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
406
+
407
+ if DEBUG:
408
+ # visulize mask
409
+ mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
410
+ mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
411
+ mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
412
+
413
+ cur_mask = cur_mask[None, ...] # [1, 1, f, H, W]
414
+ pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)
415
+
416
+ ### Got blending pose features End ###
417
+
418
+ ##### Warp Noise Start ######
419
+
420
+ if shared_wapring_latents:
421
+ noise = latents_org[0, 0].data.cpu().numpy().copy() #[14, 4, 40, 72]
422
+ noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]
423
+
424
+ try:
425
+ warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
426
+ warp_noise.insert(0, noise)
427
+ except:
428
+ print(f'!!! Noise is too small to warp; mask to bbox')
429
+
430
+ warp_noise = [noise]
431
+ for i in range(1, num_frames):
432
+ shift_x = int(round(P[i][0] - P[0][0]))
433
+ shift_y = int(round(P[i][1] - P[0][1]))
434
+
435
+ cur_noise= roll_with_ignore_multidim(noise, [shift_y, shift_x])
436
+ warp_noise.append(cur_noise)
437
+
438
+ warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
439
+
440
+ if DEBUG:
441
+ ## visulize warp noise
442
+ warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
443
+ warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
444
+ warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
445
+
446
+ torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
447
+
448
+
449
+ warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
450
+ warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
451
+
452
+ warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
453
+ mask_extend = torch.concat([warped_masks, warped_masks[:,:,0:1]], dim=2) # [1, frame, 4, H, W]
454
+ mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
455
+
456
+ warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
457
+ warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
458
+ random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
459
+
460
+ filter_shape = warp_latents.shape
461
+
462
+ freq_filter = get_freq_filter(
463
+ filter_shape,
464
+ device = device,
465
+ filter_type='butterworth',
466
+ n=4,
467
+ d_s=ds,
468
+ d_t=dt
469
+ )
470
+
471
+ warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
472
+ warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
473
+
474
+ else:
475
+ warp_latents = latents_org.clone()
476
+
477
+ generator.manual_seed(42)
478
+
479
+ with torch.no_grad():
480
+ result = pipeline(
481
+ image=condition_image,
482
+ pose_embedding=cur_plucker_embedding,
483
+ height=height,
484
+ width=width,
485
+ num_frames=num_frames,
486
+ num_inference_steps=num_inference_steps,
487
+ min_guidance_scale=min_guidance_scale,
488
+ max_guidance_scale=max_guidance_scale,
489
+ do_image_process=True,
490
+ generator=generator,
491
+ output_type='pt',
492
+ pose_features= pose_features,
493
+ latents = warp_latents
494
+ ).frames[0].cpu() #[f, c, h, w]
495
+
496
+
497
+ result = rearrange(result, 'f c h w -> f h w c')
498
+ result = (result * 255.0).to(torch.uint8)
499
+
500
+ video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
501
+ torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
502
+
503
+ return video_path
504
+
505
+ # -------------- UI definition --------------
506
+ with gr.Blocks() as demo:
507
+ # layout definition
508
+ gr.Markdown(title)
509
+ gr.Markdown(authors)
510
+ gr.Markdown(affiliation)
511
+ gr.Markdown(important_link)
512
+ gr.Markdown(description)
513
+
514
+
515
+ # with gr.Row():
516
+ # gr.Markdown("""# <center>Repositioning the Subject within Image </center>""")
517
+ mask = gr.State(value=None) # store mask
518
+ removal_mask = gr.State(value=None) # store removal mask
519
+ selected_points = gr.State([]) # store points
520
+ selected_points_text = gr.Textbox(label="Selected Points", visible=False)
521
+
522
+ original_image = gr.State(value=None) # store original input image
523
+ masked_original_image = gr.State(value=None) # store masked input image
524
+ mask_logits = gr.State(value=None) # store mask logits
525
+
526
+ depth = gr.State(value=None) # store depth
527
+ org_depth_image = gr.State(value=None) # store original depth image
528
+
529
+ camera_pose = gr.State(value=None) # store camera pose
530
+
531
+ with gr.Column():
532
+
533
+ outlines = """
534
+ <font size="5"><b>There are total 5 steps to complete the task.</b></font>
535
+ - Step 1: Input an image and Crop it to a suitable size;
536
+ - Step 2: Attain the subject mask;
537
+ - Step 3: Get depth and Draw Trajectory;
538
+ - Step 4: Get camera pose from trajectory or customize it;
539
+ - Step 5: Generate the final video.
540
+ """
541
+
542
+ gr.Markdown(outlines)
543
+
544
+
545
+ with gr.Row():
546
+ with gr.Column():
547
+ # Step 1: Input Image
548
+ step1_dec = """
549
+ <font size="4"><b>Step 1: Input Image</b></font>
550
+ - Select the region using a <mark>bounding box</mark>, aiming for a ratio close to </mark>320:576</mark> (height:width).
551
+ - All provided images in `Examples` are in 320 x 576 resolution. Simply press `Process` to proceed.
552
+ """
553
+ step1 = gr.Markdown(step1_dec)
554
+ raw_input = ImagePrompter(type="pil", label="Raw Image", show_label=True, interactive=True)
555
+ # left_up_point = gr.Textbox(value = "-1 -1", label="Left Up Point", interactive=True)
556
+ process_button = gr.Button("Process")
557
+
558
+ with gr.Column():
559
+ # Step 2: Get Subject Mask
560
+ step2_dec = """
561
+ <font size="4"><b>Step 2: Get Subject Mask</b></font>
562
+ - Use the <mark>bounding boxes</mark> or <mark>paints</mark> to select the subject.
563
+ - Press `Segment Subject` to get the mask. <mark>Can be refined iteratively by updating points<mark>.
564
+ """
565
+ step2 = gr.Markdown(step2_dec)
566
+ canvas = ImagePrompter(type="pil", label="Input Image", show_label=True, interactive=True) # for mask painting
567
+
568
+ select_button = gr.Button("Segment Subject")
569
+
570
+ with gr.Row():
571
+ with gr.Column():
572
+ mask_dec = """
573
+ <font size="4"><b>Mask Result</b></font>
574
+ - Just for visualization purpose. No need to interact.
575
+ """
576
+ mask_vis = gr.Markdown(mask_dec)
577
+ mask_output = gr.Image(type="pil", label="Mask", show_label=True, interactive=False)
578
+ with gr.Column():
579
+ # Step 3: Get Depth and Draw Trajectory
580
+ step3_dec = """
581
+ <font size="4"><b>Step 3: Get Depth and Draw Trajectory</b></font>
582
+ - Press `Get Depth` to get the depth image.
583
+ - Draw the trajectory by selecting points on the depth image. <mark>No more than 14 points</mark>.
584
+ - Press `Undo point` to remove all points.
585
+ """
586
+ step3 = gr.Markdown(step3_dec)
587
+ depth_image = gr.Image(type="pil", label="Depth Image", show_label=True, interactive=False)
588
+ with gr.Row():
589
+ depth_button = gr.Button("Get Depth")
590
+ undo_button = gr.Button("Undo point")
591
+
592
+ with gr.Row():
593
+ with gr.Column():
594
+ # Step 4: Trajectory to Camera Pose or Get Camera Pose
595
+ step4_dec = """
596
+ <font size="4"><b>Step 4: Get camera pose from trajectory or customize it</b></font>
597
+ - Option 1: Transform the 2D trajectory to camera poses with depth. <mark>`Rescale` is used for depth alignment. Larger value can speed up the object motion.</mark>
598
+ - Option 2: Rotate the camera with a specific `Angle`.
599
+ - Option 3: Rotate the camera clockwise or counterclockwise with a specific `Angle`.
600
+ - Option 4: Translate the camera with `Tx` (<mark>Pan Left/Right</mark>), `Ty` (<mark>Pan Up/Down</mark>), `Tz` (<mark>Zoom In/Out</mark>) and `Speed`.
601
+ """
602
+ step4 = gr.Markdown(step4_dec)
603
+ camera_pose_vis = gr.Plot(None, label='Camera Pose')
604
+ with gr.Row():
605
+ with gr.Column():
606
+ speed = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1.0, label="Speed", interactive=True)
607
+ rescale = gr.Slider(minimum=0.0, maximum=10, step=0.1, value=1.0, label="Rescale", interactive=True)
608
+ # traj2pose_button = gr.Button("Option1: Trajectory to Camera Pose")
609
+
610
+ angle = gr.Slider(minimum=-360, maximum=360, step=1, value=60, label="Angle", interactive=True)
611
+ # rotation_button = gr.Button("Option2: Rotate")
612
+ # clockwise_button = gr.Button("Option3: Clockwise")
613
+ with gr.Column():
614
+
615
+ Tx = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tx", interactive=True)
616
+ Ty = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Ty", interactive=True)
617
+ Tz = gr.Slider(minimum=-1, maximum=1, step=1, value=0, label="Tz", interactive=True)
618
+ # translation_button = gr.Button("Option4: Translate")
619
+ with gr.Row():
620
+ camera_option = gr.Radio(choices = CAMERA_MODE, label='Camera Options', value=CAMERA_MODE[0], interactive=True)
621
+ with gr.Row():
622
+ get_camera_pose_button = gr.Button("Get Camera Pose")
623
+
624
+ with gr.Column():
625
+ # Step 5: Get the final generated video
626
+ step5_dec = """
627
+ <font size="4"><b>Step 5: Get the final generated video</b></font>
628
+ - 3 modes for background: <mark>Fixed</mark>, <mark>Reverse</mark>, <mark>Free</mark>.
629
+ - Enable <mark>Scale-wise Masks</mark> for better object control.
630
+ - Option to enable <mark>Shared Warping Latents</mark> and set <mark>stop frequency</mark> for spatial (`ds`) and temporal (`dt`) dimensions. Larger stop frequency will lead to artifacts.
631
+ """
632
+ step5 = gr.Markdown(step5_dec)
633
+ generated_video = gr.Video(None, label='Generated Video')
634
+
635
+ with gr.Row():
636
+ seed = gr.Textbox(value = "42", label="Seed", interactive=True)
637
+ # num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=25, label="Number of Inference Steps", interactive=True)
638
+ bg_mode = gr.Radio(choices = ["Fixed", "Reverse", "Free"], label="Background Mode", value="Fixed", interactive=True)
639
+ # swl_mode = gr.Radio(choices = ["Enable SWL", "Disable SWL"], label="Shared Warping Latent", value="Disable SWL", interactive=True)
640
+ scale_wise_masks = gr.Checkbox(label="Enable Scale-wise Masks", interactive=True, value=True)
641
+ with gr.Row():
642
+ with gr.Column():
643
+ shared_wapring_latents = gr.Checkbox(label="Enable Shared Warping Latents", interactive=True)
644
+ with gr.Column():
645
+ ds = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="ds", interactive=True)
646
+ dt = gr.Slider(minimum=0.0, maximum=1, step=0.1, value=0.5, label="dt", interactive=True)
647
+
648
+ generated_button = gr.Button("Generate")
649
+
650
+
651
+
652
+ # # event definition
653
+ process_button.click(
654
+ fn = process_image,
655
+ inputs = [raw_input],
656
+ outputs = [original_image, canvas]
657
+ )
658
+
659
+ select_button.click(
660
+ segment,
661
+ [canvas, original_image, mask_logits],
662
+ [mask, mask_output, masked_original_image, mask_logits]
663
+ )
664
+
665
+ depth_button.click(
666
+ get_depth,
667
+ [original_image, selected_points],
668
+ [depth, depth_image, org_depth_image]
669
+ )
670
+
671
+ depth_image.select(
672
+ get_points,
673
+ [depth_image, selected_points],
674
+ [depth_image, selected_points],
675
+ )
676
+ undo_button.click(
677
+ undo_points,
678
+ [org_depth_image],
679
+ [depth_image, selected_points]
680
+ )
681
+
682
+ get_camera_pose_button.click(
683
+ get_camera_pose(CAMERA_MODE),
684
+ [camera_option, selected_points, depth, mask, rescale, angle, Tx, Ty, Tz, speed],
685
+ [camera_pose, camera_pose_vis, rescale]
686
+ )
687
+
688
+ generated_button.click(
689
+ run_objctrl_2_5d,
690
+ [
691
+ original_image,
692
+ mask,
693
+ depth,
694
+ camera_pose,
695
+ bg_mode,
696
+ shared_wapring_latents,
697
+ scale_wise_masks,
698
+ rescale,
699
+ seed,
700
+ ds,
701
+ dt,
702
+ # num_inference_steps
703
+ ],
704
+ [generated_video],
705
+ )
706
+
707
+ gr.Examples(
708
+ examples=examples,
709
+ inputs=[
710
+ raw_input,
711
+ rescale,
712
+ speed,
713
+ angle,
714
+ Tx,
715
+ Ty,
716
+ Tz,
717
+ camera_option,
718
+ bg_mode,
719
+ shared_wapring_latents,
720
+ scale_wise_masks,
721
+ ds,
722
+ dt,
723
+ seed,
724
+ selected_points_text # selected_points
725
+ ],
726
+ outputs=[generated_video],
727
+ examples_per_page=10
728
+ )
729
+
730
+ selected_points_text.change(
731
+ sync_points,
732
+ inputs=[selected_points_text],
733
+ outputs=[selected_points]
734
+ )
735
+
736
+
737
+ gr.Markdown(article)
738
+
739
+
740
+ demo.queue().launch(share=True)
assets/examples/00010/generated_video.mp4 ADDED
Binary file (483 kB). View file
 
assets/examples/00010/mask.png ADDED
assets/examples/00010/raw_image.png ADDED
assets/examples/00010/seg_image.png ADDED
assets/examples/00011/generated_video.mp4 ADDED
Binary file (557 kB). View file
 
assets/examples/00011/mask.png ADDED
assets/examples/00011/raw_image.png ADDED
assets/examples/00011/seg_image.png ADDED
assets/examples/00012/generated_video.mp4 ADDED
Binary file (348 kB). View file
 
assets/examples/00012/mask.png ADDED
assets/examples/00012/raw_image.png ADDED
assets/examples/00012/seg_image.png ADDED
assets/examples/00013/generated_video.mp4 ADDED
Binary file (298 kB). View file
 
assets/examples/00013/mask.png ADDED
assets/examples/00013/raw_image.png ADDED
assets/examples/00013/seg_image.png ADDED
assets/examples/00014/generated_video.mp4 ADDED
Binary file (417 kB). View file
 
assets/examples/00014/mask.png ADDED
assets/examples/00014/raw_image.png ADDED
assets/examples/00014/seg_image.png ADDED
assets/examples/00015/generated_video.mp4 ADDED
Binary file (854 kB). View file
 
assets/examples/00015/mask.png ADDED
assets/examples/00015/raw_image.png ADDED
assets/examples/00015/seg_image.png ADDED
assets/examples/00016/generated_video.mp4 ADDED
Binary file (848 kB). View file
 
assets/examples/00016/mask.png ADDED
assets/examples/00016/raw_image.png ADDED
assets/examples/00016/seg_image.png ADDED
assets/examples/00017/generated_video.mp4 ADDED
Binary file (796 kB). View file
 
assets/examples/00017/mask.png ADDED
assets/examples/00017/raw_image.png ADDED
assets/examples/00017/seg_image.png ADDED
assets/examples/00018/generated_video.mp4 ADDED
Binary file (410 kB). View file
 
assets/examples/00018/mask.png ADDED
assets/examples/00018/raw_image.png ADDED
assets/examples/00018/seg_image.png ADDED
assets/examples/00019/generated_video.mp4 ADDED
Binary file (283 kB). View file
 
assets/examples/00019/mask.png ADDED
assets/examples/00019/raw_image.png ADDED
assets/examples/00019/seg_image.png ADDED
assets/examples/00020/generated_video.mp4 ADDED
Binary file (282 kB). View file
 
assets/examples/00020/mask.png ADDED
assets/examples/00020/raw_image.png ADDED
assets/examples/00020/seg_image.png ADDED
assets/examples/00021/generated_video.mp4 ADDED
Binary file (263 kB). View file
 
assets/examples/00021/mask.png ADDED
assets/examples/00021/raw_image.png ADDED
assets/examples/00021/seg_image.png ADDED