JianyuanWang committed
Commit 2529861 · 1 Parent(s): febf487
Files changed (3):
  1. app.py +129 -88
  2. demo_hf.py +15 -12
  3. gradio_util.py +127 -129
app.py CHANGED
@@ -3,7 +3,6 @@ import cv2
  import torch
  import numpy as np
  import gradio as gr
- import spaces
  import sys
  import os
  import socket
@@ -11,42 +10,64 @@ import webbrowser
  sys.path.append('vggt/')
  import shutil
  from datetime import datetime
- from demo_hf import demo_fn
  from omegaconf import DictConfig, OmegaConf
  import glob
  import gc
  import time
  from viser_fn import viser_wrapper


- def get_free_port():
-     """Get a free port using socket."""
-     # return 80
-     # return 8080
-     # return 10088 # for debugging
-     # return 7860
-     # return 7888
-     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-         s.bind(('', 0))
-         port = s.getsockname()[1]
-         return port


  @spaces.GPU(duration=240)
  def vggt_demo(
      input_video,
      input_image,
  ):
      start_time = time.time()
      gc.collect()
      torch.cuda.empty_cache()

-
-     debug = False

-     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
      target_dir = f"input_images_{timestamp}"
      if os.path.exists(target_dir):
          shutil.rmtree(target_dir)
@@ -65,9 +86,6 @@ def vggt_demo(

      if input_image is not None:
          input_image = sorted(input_image)
-         # recon_num = len(input_image)
-
-         # Copy files to the new directory
          for file_name in input_image:
              shutil.copy(file_name, target_dir_images)
      elif input_video is not None:
@@ -90,26 +108,37 @@ def vggt_demo(

          if count % frame_interval == 0:
              cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame)
-             video_frame_num+=1
-
-     # recon_num = video_frame_num
-     # if recon_num<3:
-     #     return None, "Please input at least three frames"
      else:
-         return None, "Uploading not finished or Incorrect input format"
-
      print(f"Files have been copied to {target_dir_images}")
      cfg.SCENE_DIR = target_dir

-     predictions = demo_fn(cfg)
-
-     # Get a free port for viser
-     viser_port = get_free_port()

-     # Start viser visualization in a separate thread/process
-     viser_wrapper(predictions, port=viser_port)

      del predictions
      gc.collect()
      torch.cuda.empty_cache()
@@ -120,10 +149,31 @@ def vggt_demo(
      execution_time = end_time - start_time
      print(f"Execution time: {execution_time} seconds")

-     # Return None for the 3D model (since we're using viser) and the viser URL
      # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
      # print(viser_url) # Debug print
-     return None, viser_port

@@ -177,8 +227,17 @@ with gr.Blocks() as demo:
      gr.Markdown("""
      # 🏛️ VGGT: Visual Geometry Grounded Transformer

-     <div style="font-size: 16px; line-height: 1.2;">
-     Alpha version (testing).
      </div>
      """)

@@ -186,87 +245,69 @@ with gr.Blocks() as demo:
      with gr.Column(scale=1):
          input_video = gr.Video(label="Upload Video", interactive=True)
          input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
-
-
      with gr.Column(scale=3):
-         viser_output = gr.HTML(
-             label="Viser Visualization",
-             value='''<div style="height: 520px; border: 1px solid #e0e0e0;
-                      border-radius: 4px; padding: 16px;
-                      display: flex; align-items: center;
-                      justify-content: center">
-                      3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
-                      </div>'''
-         )
-
          log_output = gr.Textbox(label="Log")

      with gr.Row():
          submit_btn = gr.Button("Reconstruct", scale=1)
-         clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1) #Modified viser_output

      examples = [
-         [flower_video, flower_images],
-         [kitchen_video, kitchen_images],
          # [person_video, person_images],
          # [statue_video, statue_images],
          # [drums_video, drums_images],
-         [counter_video, counter_images],
-         [fern_video, fern_images],
-         [horns_video, horns_images],
          # [apple_video, apple_images],
          # [bonsai_video, bonsai_images],
      ]

-     def process_example(video, images):
-         """Wrapper function to ensure outputs are properly captured"""
-         model_output, log = vggt_demo(video, images)
-
-         # viser_wrapper(predictions, port=log)
-         # Get the hostname - use the actual hostname or IP where the server is running
-         # hostname = socket.gethostname()
-
-         # Extract port from log
-         port = log
-
-         # Create the viser URL using the hostname
-         # viser_url = f"http://{hostname}:{port}"
-
-         viser_url = f"http://localhost:{log}"
-         print(f"Viser URL: {viser_url}")
-
-         # Create the iframe HTML code. Set width and height appropriately.
-         iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'
-
-         # Return the iframe code to update the gr.HTML component
-         return iframe_code, f"Visualization running at {viser_url}"

-     # TODO: move the selection of port outside of the demo function
-     # so that we can cache examples

      gr.Examples(examples=examples,
-                 inputs=[input_video, input_images],
-                 outputs=[viser_output, log_output], # Output to viser_output
-                 fn=process_example, # Use our wrapper function
                  cache_examples=False,
                  examples_per_page=50,
      )

      submit_btn.click(
-         process_example, # Use the same wrapper function
-         [input_video, input_images],
-         [viser_output, log_output], # Output to viser_output
          # concurrency_limit=1
      )

  # demo.launch(debug=True, share=True)
  # demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False)
  # demo.queue(max_size=20).launch(show_error=True, share=True)
  demo.queue(max_size=20).launch(show_error=True) #, share=True, server_port=7888, server_name="0.0.0.0")
  # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
  ########################################################################################################################
 
  import torch
  import numpy as np
  import gradio as gr
  import sys
  import os
  import socket
  sys.path.append('vggt/')
  import shutil
  from datetime import datetime
+ from demo_hf import demo_fn #, initialize_model
  from omegaconf import DictConfig, OmegaConf
  import glob
  import gc
  import time
  from viser_fn import viser_wrapper
+ from gradio_util import demo_predictions_to_glb
+ from hydra.utils import instantiate
+ import spaces


+ # def get_free_port():
+ #     """Get a free port using socket."""
+ #     # return 80
+ #     # return 8080
+ #     # return 10088 # for debugging
+ #     # return 7860
+ #     # return 7888
+ #     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ #         s.bind(('', 0))
+ #         port = s.getsockname()[1]
+ #         return port


+ cfg_file = "config/base.yaml"
+ cfg = OmegaConf.load(cfg_file)
+ vggt_model = instantiate(cfg, _recursive_=False)
+ _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
+ # Reload vggt_model
+ pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

+ if "vggt_model" in pretrain_model:
+     model_dict = pretrain_model["vggt_model"]
+     vggt_model.load_state_dict(model_dict, strict=False)
+ else:
+     vggt_model.load_state_dict(pretrain_model, strict=True)


+ # @torch.inference_mode()
  @spaces.GPU(duration=240)
  def vggt_demo(
      input_video,
      input_image,
+     conf_thres=3.0,
+     frame_filter="all",
+     mask_black_bg=False,
  ):
      start_time = time.time()
      gc.collect()
      torch.cuda.empty_cache()

+     # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
      target_dir = f"input_images_{timestamp}"
      if os.path.exists(target_dir):
          shutil.rmtree(target_dir)

      if input_image is not None:
          input_image = sorted(input_image)
          for file_name in input_image:
              shutil.copy(file_name, target_dir_images)
      elif input_video is not None:

          if count % frame_interval == 0:
              cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame)
+             video_frame_num+=1
      else:
+         return None, "Uploading not finished or Incorrect input format", None, None

+     all_files = sorted(os.listdir(target_dir_images))
+     all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+     # Update frame_filter choices
+     frame_filter_choices = ["All"] + all_files

      print(f"Files have been copied to {target_dir_images}")
      cfg.SCENE_DIR = target_dir

+     print("Running demo_fn")
+     with torch.no_grad():
+         predictions = demo_fn(cfg, vggt_model)
+     predictions["pred_extrinsic_list"] = None
+     print("Saving predictions")

+     prediction_save_path = f"{target_dir}/predictions.npz"
+     np.savez(prediction_save_path, **predictions)

+     glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

+     glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
+     glbscene.export(file_obj=glbfile)

      del predictions
      gc.collect()
      torch.cuda.empty_cache()

      execution_time = end_time - start_time
      print(f"Execution time: {execution_time} seconds")

+     # Return None for the 3D vggt_model (since we're using viser) and the viser URL
      # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
      # print(viser_url) # Debug print
+     log = "Success. Waiting for visualization."
+     return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)


+ def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
+     loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
+     # predictions = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
+     # predictions["arr_0"]
+     # for key in predictions.files: print(key)
+     predictions = {key: loaded[key] for key in loaded.keys()}

+     glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"

+     if not os.path.exists(glbfile):
+         glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
+         glbscene.export(file_obj=glbfile)
+     return glbfile, "Updating Visualization", target_dir


      gr.Markdown("""
      # 🏛️ VGGT: Visual Geometry Grounded Transformer

+     <div style="font-size: 16px; line-height: 1.5;">
+     <p><strong>Alpha version</strong> (under active development)</p>
+
+     <p>Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.</p>
+
+     <h3>Usage Tips:</h3>
+     <ol>
+         <li>After reconstruction, you can fine-tune the visualization by adjusting the confidence threshold or selecting specific frames to display, then click "Update Visualization".</li>
+         <li>Performance note: While the model itself processes quickly (~0.2 seconds), initial setup and visualization may take longer. First-time use requires downloading model weights, and rendering dense point clouds can be resource-intensive.</li>
+         <li>Known limitation: The model currently exhibits inconsistent behavior with videos centered around human subjects. This issue is being addressed in upcoming updates.</li>
+     </ol>
      </div>
      """)

      with gr.Column(scale=1):
          input_video = gr.Video(label="Upload Video", interactive=True)
          input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
+
      with gr.Column(scale=3):
+         with gr.Column():
+             gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)**")
+             reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
+             # reconstruction_output = gr.Model3D(label="3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)", height=520, zoom_speed=0.5, pan_speed=0.5)
+
+         # Move these controls to a new row above the log output
+         with gr.Row():
+             conf_thres = gr.Slider(minimum=0.1, maximum=20.0, value=3.0, step=0.1, label="Conf Thres")
+             frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
+             mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
+
          log_output = gr.Textbox(label="Log")
+         # Add a hidden textbox for target_dir
+         target_dir_output = gr.Textbox(label="Target Dir", visible=False)

      with gr.Row():
          submit_btn = gr.Button("Reconstruct", scale=1)
+         revisual_btn = gr.Button("Update Visualization", scale=1)
+         clear_btn = gr.ClearButton([input_video, input_images, reconstruction_output, log_output, target_dir_output], scale=1) #Modified reconstruction_output

      examples = [
+         [counter_video, counter_images, 1.5, "All", False],
+         [flower_video, flower_images, 1.5, "All", False],
+         [kitchen_video, kitchen_images, 3, "All", False],
+         [fern_video, fern_images, 1.5, "All", False],
          # [person_video, person_images],
          # [statue_video, statue_images],
          # [drums_video, drums_images],
+         # [horns_video, horns_images, 1.5, "All", False],
          # [apple_video, apple_images],
          # [bonsai_video, bonsai_images],
      ]

      gr.Examples(examples=examples,
+                 inputs=[input_video, input_images, conf_thres, frame_filter, mask_black_bg],
+                 outputs=[reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter
+                 fn=vggt_demo, # Use our wrapper function
                  cache_examples=False,
                  examples_per_page=50,
      )

      submit_btn.click(
+         vggt_demo, # Use the same wrapper function
+         [input_video, input_images, conf_thres, frame_filter, mask_black_bg],
+         [reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter to outputs
          # concurrency_limit=1
      )
+
+     revisual_btn.click(
+         update_visualization,
+         [target_dir_output, conf_thres, frame_filter, mask_black_bg],
+         [reconstruction_output, log_output, target_dir_output],
+     )

  # demo.launch(debug=True, share=True)
  # demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False)
  # demo.queue(max_size=20).launch(show_error=True, share=True)
  demo.queue(max_size=20).launch(show_error=True) #, share=True, server_port=7888, server_name="0.0.0.0")
+ # share=True
  # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
  ########################################################################################################################
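The new app.py wires vggt_demo to a GLB file rendered in gr.Model3D and adds update_visualization, which re-reads the cached predictions.npz instead of rerunning the model. A rough sketch of that flow outside the UI, assuming the two callbacks can be imported without triggering the launch() call at the bottom of app.py; the image paths below are placeholders:

# Illustrative only; app.py as committed also builds and launches the Gradio UI at import time.
from app import vggt_demo, update_visualization

images = ["examples/kitchen/00.png", "examples/kitchen/01.png"]  # placeholder inputs
glbfile, log, target_dir, frame_dropdown = vggt_demo(
    None, images, conf_thres=3.0, frame_filter="All", mask_black_bg=False
)

# Re-filter the cached predictions.npz without rerunning the model.
glbfile, log, target_dir = update_visualization(
    target_dir, conf_thres=5.0, frame_filter="All", mask_black_bg=True
)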
demo_hf.py CHANGED
@@ -11,26 +11,29 @@ from viser_fn import viser_wrapper


  # @hydra.main(config_path="config", config_name="base")
- def demo_fn(cfg: DictConfig) -> None:
-     print(cfg)
-     model = instantiate(cfg, _recursive_=False)

      if not torch.cuda.is_available():
          raise ValueError("CUDA is not available. Check your environment.")

-     device = "cuda"
      model = model.to(device)

-     _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

-     # Reload model
-     pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

-     if "model" in pretrain_model:
-         model_dict = pretrain_model["model"]
-         model.load_state_dict(model_dict, strict=False)
-     else:
-         model.load_state_dict(pretrain_model, strict=True)


      # batch = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/batch.pth")
 
  # @hydra.main(config_path="config", config_name="base")
+ def demo_fn(cfg: DictConfig, model) -> None:
+     print(cfg.SCENE_DIR)

      if not torch.cuda.is_available():
          raise ValueError("CUDA is not available. Check your environment.")

+     if torch.cuda.is_available():
+         device = "cuda"
+     else:
+         device = "cpu"
+
      model = model.to(device)

+     # _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

+     # # Reload model
+     # pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

+     # if "model" in pretrain_model:
+     #     model_dict = pretrain_model["model"]
+     #     model.load_state_dict(model_dict, strict=False)
+     # else:
+     #     model.load_state_dict(pretrain_model, strict=True)


      # batch = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/batch.pth")
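With this change demo_fn no longer builds the network or downloads weights; app.py instantiates the model once and passes it in. A minimal sketch of the new calling convention, assuming the repo's config/base.yaml and the checkpoint URL used elsewhere in this commit; the scene directory is a placeholder:

import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate
from demo_hf import demo_fn

cfg = OmegaConf.load("config/base.yaml")
cfg.SCENE_DIR = "input_images_example"  # placeholder scene directory
model = instantiate(cfg, _recursive_=False)

# Load the alpha checkpoint once, as app.py now does at start-up.
state = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
)
model.load_state_dict(state["model"] if "model" in state else state,
                      strict="model" not in state)

with torch.no_grad():
    predictions = demo_fn(cfg, model)  # weights are no longer loaded inside demo_fn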
gradio_util.py CHANGED
@@ -1,56 +1,22 @@
- try:
-     import os
-
-     import trimesh
-     import open3d as o3d
-
-     import gradio as gr
-     import numpy as np
-     import matplotlib
-     from scipy.spatial.transform import Rotation
-
-     print("Successfully imported the packages for Gradio visualization")
- except:
-     print(
-         f"Failed to import packages for Gradio visualization. Please disable gradio visualization"
-     )
-
-
- def visualize_by_gradio(glbfile):
-     """
-     Set up and launch a Gradio interface to visualize a GLB file.
-
-     Args:
-         glbfile (str): Path to the GLB file to be visualized.
-     """
-
-     def load_glb_file(glb_path):
-         # Check if the file exists and return the path or error message
-         if os.path.exists(glb_path):
-             return glb_path, "3D Model Loaded Successfully"
-         else:
-             return None, "File not found"
-
-     # Load the GLB file initially to check if it's valid
-     initial_model, log_message = load_glb_file(glbfile)
-
-     # Create the Gradio interface
-     with gr.Blocks() as demo:
-         gr.Markdown("# GLB File Viewer")
-
-         # 3D Model viewer component
-         model_viewer = gr.Model3D(
-             label="3D Model Viewer", height=600, value=initial_model
-         )
-
-         # Textbox for log output
-         log_output = gr.Textbox(label="Log", lines=2, value=log_message)
-
-         # Launch the Gradio interface
-         demo.launch(share=True)


- def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
      """
      Converts VGG SFM predictions to a 3D scene represented as a GLB.

@@ -61,27 +27,51 @@ def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
          trimesh.Scene: A 3D scene object.
      """
      # Convert predictions to numpy arrays
-     vertices_3d = predictions["points3D"].cpu().numpy()
-     colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
-         np.uint8
-     )
-
-
-     if True:
-         pcd = o3d.geometry.PointCloud()
-         pcd.points = o3d.utility.Vector3dVector(vertices_3d)
-         pcd.colors = o3d.utility.Vector3dVector(colors_rgb)
-
-         cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
-         filtered_pcd = pcd.select_by_index(ind)
-
-         print(f"Filter out {len(vertices_3d) - len(filtered_pcd.points)} 3D points")
-         vertices_3d = np.asarray(filtered_pcd.points)
-         colors_rgb = np.asarray(filtered_pcd.colors).astype(np.uint8)

-     camera_matrices = predictions["extrinsics_opencv"].cpu().numpy()

      # Calculate the 5th and 95th percentiles along each axis
      lower_percentile = np.percentile(vertices_3d, 5, axis=0)
@@ -122,39 +112,10 @@ def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
      # Align scene to the observation of the first camera
      scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

      return scene_3d


- def apply_scene_alignment(
-     scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
- ) -> trimesh.Scene:
-     """
-     Aligns the 3D scene based on the extrinsics of the first camera.
-
-     Args:
-         scene_3d (trimesh.Scene): The 3D scene to be aligned.
-         extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
-
-     Returns:
-         trimesh.Scene: Aligned 3D scene.
-     """
-     # Set transformations for scene alignment
-     opengl_conversion_matrix = get_opengl_conversion_matrix()
-
-     # Rotation matrix for alignment (180 degrees around the y-axis)
-     align_rotation = np.eye(4)
-     align_rotation[:3, :3] = Rotation.from_euler(
-         "y", 180, degrees=True
-     ).as_matrix()
-
-     # Apply transformation
-     initial_transformation = (
-         np.linalg.inv(extrinsics_matrices[0])
-         @ opengl_conversion_matrix
-         @ align_rotation
-     )
-     scene_3d.apply_transform(initial_transformation)
-     return scene_3d


  def integrate_camera_into_scene(
@@ -215,40 +176,57 @@ def integrate_camera_into_scene(
      scene.add_geometry(camera_mesh)


- def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
      """
-     Computes the faces for the camera mesh.

      Args:
-         cone_shape (trimesh.Trimesh): The shape of the camera cone.

      Returns:
-         np.ndarray: Array of faces for the camera mesh.
      """
-     # Create pseudo cameras
-     faces_list = []
-     num_vertices_cone = len(cone_shape.vertices)

-     for face in cone_shape.faces:
-         if 0 in face:
-             continue
-         v1, v2, v3 = face
-         v1_offset, v2_offset, v3_offset = face + num_vertices_cone
-         v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

-         faces_list.extend(
-             [
-                 (v1, v2, v2_offset),
-                 (v1, v1_offset, v3),
-                 (v3_offset, v2, v3),
-                 (v1, v2, v2_offset_2),
-                 (v1, v1_offset_2, v3),
-                 (v3_offset_2, v2, v3),
-             ]
-         )

-     faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
-     return np.array(faces_list)


  def transform_points(
@@ -280,18 +258,38 @@ def transform_points(
      return result


- def get_opengl_conversion_matrix() -> np.ndarray:
      """
-     Constructs and returns the OpenGL conversion matrix.

      Returns:
-         numpy.ndarray: A 4x4 OpenGL conversion matrix.
      """
-     # Create an identity matrix
-     matrix = np.identity(4)

-     # Flip the y and z axes
-     matrix[1, 1] = -1
-     matrix[2, 2] = -1

-     return matrix
+ import os

+ import trimesh
+ # import open3d as o3d

+ import gradio as gr
+ import numpy as np
+ import matplotlib
+ from scipy.spatial.transform import Rotation

+ # except:
+ #     print(
+ #         f"Failed to import packages for Gradio visualization. Please disable gradio visualization"
+ #     )


+ def demo_predictions_to_glb(predictions, conf_thres=3.0, filter_by_frames="all", mask_black_bg=False) -> trimesh.Scene:
      """
      Converts VGG SFM predictions to a 3D scene represented as a GLB.

          trimesh.Scene: A 3D scene object.
      """
      # Convert predictions to numpy arrays
+     # pred_extrinsic_list', 'pred_world_points', 'pred_world_points_conf', 'images', 'last_pred_extrinsic

+     print("Building GLB scene")
+     selected_frame_idx = None
+     if filter_by_frames != "all":
+         try:
+             # Extract the index part before the colon
+             selected_frame_idx = int(filter_by_frames.split(":")[0])
+         except (ValueError, IndexError):
+             # If parsing fails, default to using all frames
+             pass
+
+     pred_world_points = predictions["pred_world_points"][0]  # remove batch dimension
+     pred_world_points_conf = predictions["pred_world_points_conf"][0]
+     images = predictions["images"][0]
+     last_pred_extrinsic = predictions["last_pred_extrinsic"][0]
+
+     if selected_frame_idx is not None:
+         pred_world_points = pred_world_points[selected_frame_idx][None]
+         pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
+         images = images[selected_frame_idx][None]
+         last_pred_extrinsic = last_pred_extrinsic[selected_frame_idx][None]
+
+     vertices_3d = pred_world_points.reshape(-1, 3)
+     colors_rgb = np.transpose(images, (0, 2, 3, 1))  # images.permute(0, 3, 1, 2)
+     colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
+     camera_matrices = last_pred_extrinsic
+
+     conf = pred_world_points_conf.reshape(-1)
+     conf_mask = conf > conf_thres

+     if mask_black_bg:
+         black_bg_mask = colors_rgb.sum(axis=1) >= 16
+         conf_mask = conf_mask & black_bg_mask
+
+     vertices_3d = vertices_3d[conf_mask]
+     colors_rgb = colors_rgb[conf_mask]

+     # vertices_3d = predictions["points3D"].cpu().numpy()
+     # colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
+     #     np.uint8
+     # )
+     # camera_matrices = predictions["extrinsics_opencv"].cpu().numpy()

      # Calculate the 5th and 95th percentiles along each axis
      lower_percentile = np.percentile(vertices_3d, 5, axis=0)

      # Align scene to the observation of the first camera
      scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

+     print("GLB Scene built")
      return scene_3d


  def integrate_camera_into_scene(
      scene.add_geometry(camera_mesh)


+ def apply_scene_alignment(
+     scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
+ ) -> trimesh.Scene:
      """
+     Aligns the 3D scene based on the extrinsics of the first camera.

      Args:
+         scene_3d (trimesh.Scene): The 3D scene to be aligned.
+         extrinsics_matrices (np.ndarray): Camera extrinsic matrices.

      Returns:
+         trimesh.Scene: Aligned 3D scene.
      """
+     # Set transformations for scene alignment
+     opengl_conversion_matrix = get_opengl_conversion_matrix()

+     # Rotation matrix for alignment (180 degrees around the y-axis)
+     align_rotation = np.eye(4)
+     align_rotation[:3, :3] = Rotation.from_euler(
+         "y", 180, degrees=True
+     ).as_matrix()

+     # Apply transformation
+     initial_transformation = (
+         np.linalg.inv(extrinsics_matrices[0])
+         @ opengl_conversion_matrix
+         @ align_rotation
+     )
+     scene_3d.apply_transform(initial_transformation)
+     return scene_3d


+ def get_opengl_conversion_matrix() -> np.ndarray:
+     """
+     Constructs and returns the OpenGL conversion matrix.
+
+     Returns:
+         numpy.ndarray: A 4x4 OpenGL conversion matrix.
+     """
+     # Create an identity matrix
+     matrix = np.identity(4)
+
+     # Flip the y and z axes
+     matrix[1, 1] = -1
+     matrix[2, 2] = -1
+
+     return matrix


  def transform_points(
      return result


+ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
      """
+     Computes the faces for the camera mesh.
+
+     Args:
+         cone_shape (trimesh.Trimesh): The shape of the camera cone.

      Returns:
+         np.ndarray: Array of faces for the camera mesh.
      """
+     # Create pseudo cameras
+     faces_list = []
+     num_vertices_cone = len(cone_shape.vertices)

+     for face in cone_shape.faces:
+         if 0 in face:
+             continue
+         v1, v2, v3 = face
+         v1_offset, v2_offset, v3_offset = face + num_vertices_cone
+         v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

+         faces_list.extend(
+             [
+                 (v1, v2, v2_offset),
+                 (v1, v1_offset, v3),
+                 (v3_offset, v2, v3),
+                 (v1, v2, v2_offset_2),
+                 (v1, v1_offset_2, v3),
+                 (v3_offset_2, v2, v3),
+             ]
+         )
+
+     faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
+     return np.array(faces_list)
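For reference, a small sketch of how the new demo_predictions_to_glb entry point can be driven from a saved run, mirroring update_visualization in app.py; the .npz path is a placeholder, and the frame string follows the "index: filename" format that app.py builds for the dropdown:

import numpy as np
from gradio_util import demo_predictions_to_glb

# Placeholder path: predictions.npz is what vggt_demo saves for each reconstruction.
loaded = np.load("input_images_example/predictions.npz", allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}

# Keep only frame 0's points above confidence 5 and drop near-black pixels.
scene = demo_predictions_to_glb(predictions, conf_thres=5.0,
                                filter_by_frames="0: 000000.png",
                                mask_black_bg=True)
scene.export(file_obj="glbscene_example.glb")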