ostapagon committed
Commit 6db5fd9 · 1 Parent(s): 036b7d1

Add demo file. Change sdk to gradio. Add wild-gaussian-splatting submodule

.gitmodules ADDED
@@ -0,0 +1,4 @@
+[submodule "wild-gaussian-splatting"]
+	path = wild-gaussian-splatting
+	url = https://github.com/ostapagon/wild-gaussian-splatting.git
+	branch = mast3r_3dgs
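The submodule is pinned to the mast3r_3dgs branch, so a fresh checkout needs the submodule fetched before app.py can import from it. A minimal sketch of triggering that from Python (the plain `git submodule update --init --recursive` CLI call is equivalent; running from the Space's checkout root is an assumption):

import subprocess

# Fetch wild-gaussian-splatting (and its nested submodules) at the pinned commit
subprocess.run(
    ["git", "submodule", "update", "--init", "--recursive"],
    cwd=".",  # assumed: the root of this Space's checkout
    check=True,
)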
README.md CHANGED
@@ -1,9 +1,9 @@
 ---
-title: Mast3r 3dgs
+title: MASt3r+3DGS
 emoji: 😻
 colorFrom: gray
 colorTo: indigo
-sdk: docker
+sdk: gradio
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,38 @@
+import sys
+sys.path.append('wild-gaussian-splatting/mast3r/')
+sys.path.append('demo/')
+
+import os
+import tempfile
+import gradio as gr
+from mast3r.demo import get_args_parser
+from mast3r.utils.misc import hash_md5
+from mast3r_demo import mast3r_demo_tab
+from gs_demo import gs_demo_tab
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+
+    if args.server_name is not None:
+        server_name = args.server_name
+    else:
+        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
+
+    weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
+    chkpt_tag = hash_md5(weights_path)
+
+    with tempfile.TemporaryDirectory(suffix='demo') as tmpdirname:
+        cache_path = os.path.join(tmpdirname, chkpt_tag)
+        os.makedirs(cache_path, exist_ok=True)
+
+        with gr.Blocks() as demo:
+            with gr.Tabs():
+                with gr.Tab("MASt3R Demo"):
+                    mast3r_demo_tab(cache_path, weights_path, args.device)
+                with gr.Tab("Gaussian Splatting Demo"):
+                    gs_demo_tab(cache_path)
+
+        demo.launch(server_name=server_name, server_port=args.server_port)
+
+# python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
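The two tab builders only need an open Gradio context to attach their components to, so the tab layout can be smoke-tested without downloading the MASt3R checkpoint. A minimal sketch with placeholder tabs (stub_tab is hypothetical; the real builders are mast3r_demo_tab and gs_demo_tab):

import gradio as gr

def stub_tab(name):
    # Stand-in for mast3r_demo_tab / gs_demo_tab, which build their UI into the enclosing Tab
    gr.Markdown(f"{name} placeholder")

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("MASt3R Demo"):
            stub_tab("MASt3R")
        with gr.Tab("Gaussian Splatting Demo"):
            stub_tab("3DGS")

demo.launch(server_name="127.0.0.1", server_port=7860)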
demo/__init__.py ADDED
File without changes
demo/gs_demo.py ADDED
@@ -0,0 +1,148 @@
+import gradio as gr
+from gs_train import train
+import os
+
+DATASET_DIR = "colmap_data"
+
+def get_dataset_folders(datasets_path):
+    try:
+        return [f for f in os.listdir(datasets_path) if os.path.isdir(os.path.join(datasets_path, f))]
+    except FileNotFoundError:
+        return []
+
+def gs_demo_tab(cache_path):
+    datasets_path = "/app/data/scenes/"
+    # dataset_path = os.path.join(cache_path, DATASET_DIR)
+    def start_training(selected_folder, *args):
+        selected_data_path = os.path.join(datasets_path, selected_folder)
+        return train(selected_data_path, *args)
+
+    def get_context():
+        return gr.Blocks(delete_cache=(True, True))
+
+    with get_context() as gs_demo:
+        gr.Markdown("""
+        <style>
+        .fixed-size-video video {
+            max-height: 400px !important;
+            height: 400px !important;
+            object-fit: contain;
+        }
+        </style>
+        """)
+        gr.Markdown("# Gaussian Splatting Training Demo")
+
+        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
+        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
+
+        def update_dataset_dropdown():
+            print("update_dataset_dropdown, cache_path", cache_path)
+            # Update the dataset folders list
+            dataset_folders = get_dataset_folders(datasets_path)
+            print("dataset_folders", dataset_folders)
+            # Only set a default value if there are folders available
+            default_value = dataset_folders[0] if dataset_folders else None
+            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)
+
+        # Update the dropdown when the refresh button is clicked
+        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)
+
+        with gr.Accordion("Model Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    sh_degree = gr.Number(label="SH Degree", value=3)
+                    model_path = gr.Textbox(label="Model Path", value="")
+                    images = gr.Textbox(label="Images", value="images")
+                    resolution = gr.Number(label="Resolution", value=-1)
+                    white_background = gr.Checkbox(label="White Background", value=True)
+                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
+                    eval = gr.Checkbox(label="Eval", value=False)
+
+        with gr.Accordion("Pipeline Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
+                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
+                    debug = gr.Checkbox(label="Debug", value=False)
+
+        with gr.Accordion("Optimization Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    iterations = gr.Number(label="Iterations", value=1000)
+                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
+                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
+                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
+                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000)
+                with gr.Column():
+                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
+                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
+                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
+                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
+                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
+                with gr.Column():
+                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
+                    densification_interval = gr.Number(label="Densification Interval", value=100)
+                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
+                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
+                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
+                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
+                    random_background = gr.Checkbox(label="Random Background", value=False)
+
+        start_button = gr.Button("Start Training")
+
+        # State variable to store the trained model path
+        model_path_state = gr.State()
+
+        # Video output and load-model button with fixed scale
+        video_output = gr.Video(
+            label="Training Progress",
+            height=400,  # Fixed height
+            width="100%",  # Full width of container
+            autoplay=False,  # Prevent autoplay
+            show_label=True,
+            container=True,
+            elem_classes="fixed-size-video"  # Custom class targeted by the CSS above
+        )
+        load_model_button = gr.Button("Load 3D Model", interactive=False)
+        output = gr.Model3D(label="3D Model Output", visible=False)
+
+        def handle_training_complete(selected_folder, *args):
+            # Construct the full path to the selected dataset
+            selected_data_path = os.path.join(datasets_path, selected_folder)
+            # Call the training function with the full path
+            video_path, model_path = train(selected_data_path, *args)
+            # Then return all required outputs
+            return [
+                video_path,  # video output
+                gr.Button(value="Load 3D Model", interactive=True),  # enable the load button
+                gr.Model3D(visible=False),  # keep 3D model hidden
+                model_path  # store model path in state
+            ]
+
+        def load_model(model_path):
+            if not model_path:
+                return gr.Model3D(visible=False)
+            return gr.Model3D(value=model_path, visible=True)
+
+        # Connect the start training button
+        start_button.click(
+            fn=handle_training_complete,
+            inputs=[
+                dataset_dropdown, sh_degree, model_path, images, resolution, white_background, data_device, eval,
+                convert_SHs_python, compute_cov3D_python, debug,
+                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
+                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
+                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
+                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
+            ],
+            outputs=[video_output, load_model_button, output, model_path_state]
+        )
+
+        # Connect the load model button
+        load_model_button.click(
+            fn=load_model,
+            inputs=[model_path_state],
+            outputs=output
+        )
+    return gs_demo
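The refresh flow above relies on a Gradio 4.x idiom: a callback can return a new gr.Dropdown(...) to replace the choices and value of an existing component. A self-contained sketch of just that mechanism (the folder names are placeholders standing in for get_dataset_folders):

import gradio as gr

def refresh_choices():
    folders = ["scene_a", "scene_b"]  # placeholder for get_dataset_folders(datasets_path)
    return gr.Dropdown(choices=folders, value=folders[0] if folders else None)

with gr.Blocks() as ui:
    dropdown = gr.Dropdown(label="Select Dataset", choices=[])
    gr.Button("Refresh").click(fn=refresh_choices, inputs=None, outputs=dropdown)

ui.launch()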
demo/gs_train.py ADDED
@@ -0,0 +1,289 @@
+import sys
+import os
+import torch
+from random import randint
+import uuid
+from tqdm.auto import tqdm
+import gradio as gr
+import importlib.util
+
+# Add the path to the gaussian-splatting repository
+gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
+sys.path.append(gaussian_splatting_path)
+
+# Import necessary modules from the gaussian-splatting directory
+from utils.loss_utils import l1_loss, ssim
+from gaussian_renderer import render, network_gui
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+from utils.image_utils import psnr
+
+# Dynamically import the train module from the gaussian-splatting directory
+train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
+gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
+train_spec.loader.exec_module(gaussian_splatting_train)
+
+# Import the necessary functions from the dynamically loaded module
+prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
+training_report = gaussian_splatting_train.training_report
+
+from dataclasses import dataclass, field
+
+@dataclass
+class PipelineParams:
+    convert_SHs_python: bool = False
+    compute_cov3D_python: bool = False
+    debug: bool = False
+
+@dataclass
+class OptimizationParams:
+    iterations: int = 7000
+    position_lr_init: float = 0.00016
+    position_lr_final: float = 0.0000016
+    position_lr_delay_mult: float = 0.01
+    position_lr_max_steps: int = 30_000
+    feature_lr: float = 0.0025
+    opacity_lr: float = 0.05
+    scaling_lr: float = 0.005
+    rotation_lr: float = 0.001
+    percent_dense: float = 0.01
+    lambda_dssim: float = 0.2
+    densification_interval: int = 100
+    opacity_reset_interval: int = 3000
+    densify_from_iter: int = 500
+    densify_until_iter: int = 15_000
+    densify_grad_threshold: float = 0.0002
+    random_background: bool = False
+
+@dataclass
+class ModelParams:
+    sh_degree: int = 3
+    source_path: str = "../data/scenes/turtle/"  # Default path, adjust as needed
+    model_path: str = ""
+    images: str = "images"
+    resolution: int = -1
+    white_background: bool = True
+    data_device: str = "cuda"
+    eval: bool = False
+
+@dataclass
+class TrainingArgs:
+    ip: str = "0.0.0.0"
+    port: int = 6007
+    debug_from: int = -1
+    detect_anomaly: bool = False
+    test_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
+    save_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
+    quiet: bool = False
+    checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
+    start_checkpoint: str = None
+
+def train(
+    data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
+    convert_SHs_python, compute_cov3D_python, debug,
+    iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
+    position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
+    percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
+    densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
+):
+    print(data_source_path)
+    # Create instances of the parameter dataclasses
+    dataset = ModelParams(
+        sh_degree=sh_degree,
+        source_path=data_source_path,
+        model_path=model_path,
+        images=images,
+        resolution=resolution,
+        white_background=white_background,
+        data_device=data_device,
+        eval=eval
+    )
+
+    pipe = PipelineParams(
+        convert_SHs_python=convert_SHs_python,
+        compute_cov3D_python=compute_cov3D_python,
+        debug=debug
+    )
+
+    opt = OptimizationParams(
+        iterations=iterations,
+        position_lr_init=position_lr_init,
+        position_lr_final=position_lr_final,
+        position_lr_delay_mult=position_lr_delay_mult,
+        position_lr_max_steps=position_lr_max_steps,
+        feature_lr=feature_lr,
+        opacity_lr=opacity_lr,
+        scaling_lr=scaling_lr,
+        rotation_lr=rotation_lr,
+        percent_dense=percent_dense,
+        lambda_dssim=lambda_dssim,
+        densification_interval=densification_interval,
+        opacity_reset_interval=opacity_reset_interval,
+        densify_from_iter=densify_from_iter,
+        densify_until_iter=densify_until_iter,
+        densify_grad_threshold=densify_grad_threshold,
+        random_background=random_background
+    )
+
+    args = TrainingArgs()
+
+    testing_iterations = args.test_iterations
+    saving_iterations = args.save_iterations
+    checkpoint_iterations = args.checkpoint_iterations
+    debug_from = args.debug_from
+
+    tb_writer = prepare_output_and_logger(dataset)
+
+    gaussians = GaussianModel(dataset.sh_degree)
+    scene = Scene(dataset, gaussians)
+    gaussians.training_setup(opt)
+
+    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+    iter_start = torch.cuda.Event(enable_timing=True)
+    iter_end = torch.cuda.Event(enable_timing=True)
+
+    viewpoint_stack = None
+    ema_loss_for_log = 0.0
+    first_iter = 0
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+
+    point_cloud_path = ""
+    progress = gr.Progress()  # Initialize the Gradio progress bar
+    for iteration in range(first_iter, opt.iterations + 1):
+        iter_start.record()
+        gaussians.update_learning_rate(iteration)
+
+        # Every 1000 iterations we increase the levels of SH up to a maximum degree
+        if iteration % 1000 == 0:
+            gaussians.oneupSHdegree()
+
+        # Pick a random camera
+        if not viewpoint_stack:
+            viewpoint_stack = scene.getTrainCameras().copy()
+        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+        # Render
+        if (iteration - 1) == debug_from:
+            pipe.debug = True
+        bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+        # Loss
+        gt_image = viewpoint_cam.original_image.cuda()
+        Ll1 = l1_loss(image, gt_image)
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+        loss.backward()
+        iter_end.record()
+
+        with torch.no_grad():
+            # Progress bar
+            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+            if iteration % 10 == 0:
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(10)
+                progress(iteration / opt.iterations)  # Update Gradio progress bar
+            if iteration == opt.iterations:
+                progress_bar.close()
+
+            # Log and save
+            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+            if (iteration == opt.iterations):
+                point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
+                print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
+                scene.save(iteration)
+
+            # Densification
+            if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+                    gaussians.reset_opacity()
+
+            # Optimizer step
+            if iteration < opt.iterations:
+                gaussians.optimizer.step()
+                gaussians.optimizer.zero_grad(set_to_none=True)
+
+            if (iteration == opt.iterations):
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
+
+    from os import makedirs
+    from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
+    import torchvision
+    import subprocess
+
+    @torch.no_grad()
+    def render_path(dataset: ModelParams, iteration: int, pipeline: PipelineParams, render_resize_method='crop'):
+        """
+        render_resize_method: crop, pad
+        """
+        gaussians = GaussianModel(dataset.sh_degree)
+        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+        iteration = scene.loaded_iter
+
+        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+        model_path = dataset.model_path
+        name = "render"
+
+        views = scene.getRenderCameras()
+
+        render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+        makedirs(render_path, exist_ok=True)
+
+        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+            if render_resize_method == 'crop':
+                image_size = 256
+            elif render_resize_method == 'pad':
+                image_size = max(view.image_width, view.image_height)
+            else:
+                raise NotImplementedError
+            view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
+            focal_length_x = fov2focal(view.FoVx, view.image_width)
+            focal_length_y = fov2focal(view.FoVy, view.image_height)
+            view.image_width = image_size
+            view.image_height = image_size
+            view.FoVx = focal2fov(focal_length_x, image_size)
+            view.FoVy = focal2fov(focal_length_y, image_size)
+            view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0, 1).cuda().float()
+            view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
+
+            render_pkg = render(view, gaussians, pipeline, background)
+            rendering = render_pkg["render"]
+            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+
+        # Use ffmpeg to assemble the rendered frames into a video
+        renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
+        subprocess.run(["ffmpeg", "-y",
+                        "-framerate", "24",
+                        "-i", os.path.join(render_path, "%05d.png"),
+                        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
+                        "-c:v", "libx264",
+                        "-pix_fmt", "yuv420p",
+                        "-crf", "23",
+                        renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                       )
+        return renders_path
+
+    renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
+
+    return renders_path, point_cloud_path
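train() takes its 28 arguments positionally, in exactly the order the Gradio inputs list passes them. A hypothetical invocation with the UI defaults (the scene path assumes the Docker layout used by this Space):

from gs_train import train

video_path, ply_path = train(
    "/app/data/scenes/turtle",                   # data_source_path (assumed example scene)
    3, "", "images", -1, True, "cuda", False,    # ModelParams: sh_degree .. eval
    False, False, False,                         # PipelineParams: convert_SHs / compute_cov3D / debug
    1000, 0.00016, 0.0000016, 0.01, 30000,       # iterations + position LR schedule
    0.0025, 0.05, 0.005, 0.001, 0.01,            # feature/opacity/scaling/rotation LR, percent_dense
    0.2, 100, 3000, 500, 15000, 0.0002, False,   # lambda_dssim .. random_background
)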
demo/mast3r_demo.py ADDED
@@ -0,0 +1,382 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# sparse gradio demo functions
+# --------------------------------------------------------
+import sys
+
+import math
+import gradio
+import os
+import numpy as np
+import functools
+import trimesh
+import copy
+from scipy.spatial.transform import Rotation
+import tempfile
+import shutil
+
+from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
+from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
+
+from mast3r.model import AsymmetricMASt3R
+from dust3r.image_pairs import make_pairs
+from dust3r.utils.image import load_images
+from dust3r.utils.device import to_numpy
+from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
+from dust3r.demo import get_args_parser as dust3r_get_args_parser
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
+from src.colmap_dataset_utils import (
+    inv,
+    init_filestructure,
+    save_images_masks,
+    save_cameras,
+    save_imagestxt,
+    save_pointcloud,
+    save_pointcloud_with_normals
+)
+
+import matplotlib.pyplot as pl
+
+import torch
+
+
+class SparseGAState:
+    def __init__(self, sparse_ga, cache_dir=None, outfile_name=None):
+        self.sparse_ga = sparse_ga
+        self.cache_dir = cache_dir
+        self.outfile_name = outfile_name
+
+    def __del__(self):
+        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
+            shutil.rmtree(self.cache_dir)
+        self.cache_dir = None
+        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
+            os.remove(self.outfile_name)
+        self.outfile_name = None
+
+
+def get_args_parser():
+    parser = dust3r_get_args_parser()
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--gradio_delete_cache', default=None, type=int,
+                        help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
+
+    actions = parser._actions
+    for action in actions:
+        if action.dest == 'model_name':
+            action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
+    # change defaults
+    parser.prog = 'mast3r demo'
+    return parser
+
+
+def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
+                                 cam_color=None, as_pointcloud=False,
+                                 transparent_cams=False, silent=False):
+    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+    pts3d = to_numpy(pts3d)
+    imgs = to_numpy(imgs)
+    focals = to_numpy(focals)
+    cams2world = to_numpy(cams2world)
+
+    scene = trimesh.Scene()
+
+    # full pointcloud
+    if as_pointcloud:
+        pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
+        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
+        valid_msk = np.isfinite(pts.sum(axis=1))
+        pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
+        scene.add_geometry(pct)
+    else:
+        meshes = []
+        for i in range(len(imgs)):
+            pts3d_i = pts3d[i].reshape(imgs[i].shape)
+            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
+            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
+        mesh = trimesh.Trimesh(**cat_meshes(meshes))
+        scene.add_geometry(mesh)
+
+    # add each camera
+    for i, pose_c2w in enumerate(cams2world):
+        if isinstance(cam_color, list):
+            camera_edge_color = cam_color[i]
+        else:
+            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+        add_scene_cam(scene, pose_c2w, camera_edge_color,
+                      None if transparent_cams else imgs[i], focals[i],
+                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)
+
+    rot = np.eye(4)
+    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
+    if not silent:
+        print('(exporting 3D scene to', outfile, ')')
+    scene.export(file_obj=outfile)
+    return outfile
+
+
+def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
+                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
+    """
+    extract 3D_model (glb file) from a reconstructed scene
+    """
+    if scene_state is None:
+        return None
+    outfile = scene_state.outfile_name
+    if outfile is None:
+        return None
+
+    # get optimized values from scene
+    scene = scene_state.sparse_ga
+    rgbimg = scene.imgs
+    focals = scene.get_focals().cpu()
+    cams2world = scene.get_im_poses().cpu()
+
+    # 3D pointcloud from depthmap, poses and intrinsics
+    if TSDF_thresh > 0:
+        tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
+        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
+    else:
+        pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
+
+    torch.save(confs, '/app/data/confs.pt')
+    msk = to_numpy([c > min_conf_thr for c in confs])
+    return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
+                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
+
+
+def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
+    cam2world = scene.get_im_poses().detach().cpu().numpy()
+    world2cam = inv(cam2world)
+    principal_points = scene.get_principal_points().detach().cpu().numpy()
+    focals = scene.get_focals().detach().cpu().numpy()[..., None]
+    imgs = np.array(scene.imgs)
+
+    pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
+    pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d]
+
+    masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
+
+    mask_images = True
+
+    save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
+    save_images_masks(imgs, masks, images_path, masks_path, mask_images)
+    save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
+    save_imagestxt(world2cam, sparse_path)
+    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
+    return save_path
+
+
+def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
+                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
+                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
+                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
+    """
+    from a list of images, run mast3r inference, sparse global aligner.
+    then run get_3D_model_from_scene
+    """
+    imgs = load_images(filelist, size=image_size, verbose=not silent)
+    if len(imgs) == 1:
+        imgs = [imgs[0], copy.deepcopy(imgs[0])]
+        imgs[1]['idx'] = 1
+        filelist = [filelist[0], filelist[0] + '_2']
+
+    scene_graph_params = [scenegraph_type]
+    if scenegraph_type in ["swin", "logwin"]:
+        scene_graph_params.append(str(winsize))
+    elif scenegraph_type == "oneref":
+        scene_graph_params.append(str(refid))
+    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
+        scene_graph_params.append('noncyclic')
+    scene_graph = '-'.join(scene_graph_params)
+    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
+    if optim_level == 'coarse':
+        niter2 = 0
+
+    base_cache_dir = os.path.join(outdir, 'cache')
+    os.makedirs(base_cache_dir, exist_ok=True)
+
+    def get_next_dir(base_dir):
+        run_counter = 0
+        while True:
+            run_cache_dir = os.path.join(base_dir, f"run_{run_counter}")
+            if not os.path.exists(run_cache_dir):
+                os.makedirs(run_cache_dir)
+                break
+            run_counter += 1
+        return run_cache_dir
+
+    cache_dir = get_next_dir(base_cache_dir)
+    scene = sparse_global_alignment(filelist, pairs, cache_dir,
+                                    model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
+                                    opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
+                                    matching_conf_thr=matching_conf_thr, **kw)
+
+    base_colmapdata_dir = os.path.join(outdir, 'colmap_data')
+    os.makedirs(base_colmapdata_dir, exist_ok=True)
+    colmap_data_dir = get_next_dir(base_colmapdata_dir)
+    save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
+
+    if current_scene_state is not None and \
+            current_scene_state.outfile_name is not None:
+        outfile_name = current_scene_state.outfile_name
+    else:
+        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
+
+    scene_state = SparseGAState(scene, cache_dir, outfile_name)
+    outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
+                                      clean_depth, transparent_cams, cam_size, TSDF_thresh)
+    print(f"colmap_data_dir: {colmap_data_dir}")
+    print(f"outfile_name: {outfile_name}")
+    print(f"cache_dir: {cache_dir}")
+    return scene_state, outfile
+
+
+def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
+    num_files = len(inputfiles) if inputfiles is not None else 1
+    show_win_controls = scenegraph_type in ["swin", "logwin"]
+    show_winsize = scenegraph_type in ["swin", "logwin"]
+    show_cyclic = scenegraph_type in ["swin", "logwin"]
+    max_winsize, min_winsize = 1, 1
+    if scenegraph_type == "swin":
+        if win_cyclic:
+            max_winsize = max(1, math.ceil((num_files - 1) / 2))
+        else:
+            max_winsize = num_files - 1
+    elif scenegraph_type == "logwin":
+        if win_cyclic:
+            half_size = math.ceil((num_files - 1) / 2)
+            max_winsize = max(1, math.ceil(math.log(half_size, 2)))
+        else:
+            max_winsize = max(1, math.ceil(math.log(num_files, 2)))
+    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
+    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
+    win_col = gradio.Column(visible=show_win_controls)
+    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
+    return win_col, winsize, win_cyclic, refid
+
+
+def mast3r_demo_tab(cache_path, weights_path, device, silent=False):
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
+
+    if not silent:
+        print('Outputting stuff in', cache_path)
+
+    recon_fun = functools.partial(get_reconstructed_scene, cache_path, model, device, silent)
+    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
+
+    def get_context():
+        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
+        title = "MASt3R Demo"
+        return gradio.Blocks(css=css, title=title, delete_cache=(True, True))
+
+    with get_context() as demo:
+        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
+        scene = gradio.State(None)
+        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
+        with gradio.Column():
+            inputfiles = gradio.File(file_count="multiple")
+            with gradio.Row():
+                with gradio.Column():
+                    with gradio.Row():
+                        lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
+                        niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
+                                               label="num_iterations", info="For coarse alignment!")
+                        lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
+                        niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
+                                               label="num_iterations", info="For refinement!")
+                        optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
+                                                      value='refine+depth', label="OptLevel",
+                                                      info="Optimization level")
+                    image_size = gradio.Dropdown(choices=[512, 224], label="Image Size", value=512)
+                    with gradio.Row():
+                        matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
+                                                          minimum=0., maximum=30., step=0.1,
+                                                          info="Before Fallback to Regr3D!")
+                        shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                            info="Only optimize one set of intrinsics for all views")
+                        scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
+                                                           ("swin: sliding window", "swin"),
+                                                           ("logwin: sliding window with long range", "logwin"),
+                                                           ("oneref: match one image with all", "oneref")],
+                                                          value='complete', label="Scenegraph",
+                                                          info="Define how to make pairs",
+                                                          interactive=True)
+                        with gradio.Column(visible=False) as win_col:
+                            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
+                                                    minimum=1, maximum=1, step=1)
+                            win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
+                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
+                                              minimum=0, maximum=0, step=1, visible=False)
+            run_btn = gradio.Button("Run")
+
+            with gradio.Row():
+                # adjust the confidence threshold
+                min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
+                # adjust the camera size in the output pointcloud
+                cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
+                TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
+            with gradio.Row():
+                as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
+                # two post process implemented
+                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
+                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
+                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
+
+            outmodel = gradio.Model3D()
+
+            # events
+            scenegraph_type.change(set_scenegraph_options,
+                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                                   outputs=[win_col, winsize, win_cyclic, refid])
+            inputfiles.change(set_scenegraph_options,
+                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                              outputs=[win_col, winsize, win_cyclic, refid])
+            win_cyclic.change(set_scenegraph_options,
+                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                              outputs=[win_col, winsize, win_cyclic, refid])
+            run_btn.click(fn=recon_fun,
+                          inputs=[image_size, scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
+                                  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
+                                  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
+                          outputs=[scene, outmodel])
+            min_conf_thr.release(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                 outputs=outmodel)
+            cam_size.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                            outputs=outmodel)
+            TSDF_thresh.change(fn=model_from_scene_fun,
+                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                               outputs=outmodel)
+            as_pointcloud.change(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                 outputs=outmodel)
+            mask_sky.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                            outputs=outmodel)
+            clean_depth.change(fn=model_from_scene_fun,
+                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                               outputs=outmodel)
+            transparent_cams.change(model_from_scene_fun,
+                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                            clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                    outputs=outmodel)
+
+    return demo
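The pairing strategy handed to make_pairs is encoded as a dash-joined string assembled inside get_reconstructed_scene. A small sketch that mirrors that assembly, with the expected outputs shown as comments:

def build_scene_graph(scenegraph_type, winsize=1, refid=0, win_cyclic=False):
    # Mirrors the scene_graph_params logic in get_reconstructed_scene()
    params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        params.append(str(winsize))
    elif scenegraph_type == "oneref":
        params.append(str(refid))
    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        params.append('noncyclic')
    return '-'.join(params)

print(build_scene_graph("complete"))           # complete
print(build_scene_graph("swin", winsize=3))    # swin-3-noncyclic
print(build_scene_graph("oneref", refid=2))    # oneref-2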
requirements.txt ADDED
@@ -0,0 +1 @@
+-e wild-gaussian-splatting
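The editable install (-e) makes pip build the submodule in place, which assumes wild-gaussian-splatting ships a setup.py or pyproject.toml at its root; Hugging Face Spaces run this requirements file at build time. A sketch of triggering the same install from Python, should that ever be needed:

import subprocess
import sys

# Same effect as `pip install -e wild-gaussian-splatting` run from the repo root
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "wild-gaussian-splatting"], check=True)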
wild-gaussian-splatting ADDED
@@ -0,0 +1 @@
+Subproject commit fe8a9f389cdc583864f34a9e3ae32899c674229a