Add demo file. Change sdk to gradio. Add wild-gaussian-splatting submodule
- .gitmodules +4 -0
- README.md +2 -2
- app.py +38 -0
- demo/__init__.py +0 -0
- demo/gs_demo.py +148 -0
- demo/gs_train.py +289 -0
- demo/mast3r_demo.py +382 -0
- requirements.txt +1 -0
- wild-gaussian-splatting +1 -0
.gitmodules ADDED
@@ -0,0 +1,4 @@
+[submodule "wild-gaussian-splatting"]
+	path = wild-gaussian-splatting
+	url = https://github.com/ostapagon/wild-gaussian-splatting.git
+	branch = mast3r_3dgs
README.md CHANGED
@@ -1,9 +1,9 @@
 ---
-title:
+title: MASt3r+3DGS
 emoji: 😻
 colorFrom: gray
 colorTo: indigo
-sdk:
+sdk: gradio
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,38 @@
+import sys
+sys.path.append('wild-gaussian-splatting/mast3r/')
+sys.path.append('demo/')
+
+import os
+import tempfile
+import gradio as gr
+from mast3r.demo import get_args_parser
+from mast3r.utils.misc import hash_md5
+from mast3r_demo import mast3r_demo_tab
+from gs_demo import gs_demo_tab
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+
+    if args.server_name is not None:
+        server_name = args.server_name
+    else:
+        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
+
+    weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
+    chkpt_tag = hash_md5(weights_path)
+
+    with tempfile.TemporaryDirectory(suffix='demo') as tmpdirname:
+        cache_path = os.path.join(tmpdirname, chkpt_tag)
+        os.makedirs(cache_path, exist_ok=True)
+
+        with gr.Blocks() as demo:
+            with gr.Tabs():
+                with gr.Tab("MASt3R Demo"):
+                    mast3r_demo_tab(cache_path, weights_path, args.device)
+                with gr.Tab("Gaussian Splatting Demo"):
+                    gs_demo_tab(cache_path)
+
+        demo.launch(server_name=server_name, server_port=args.server_port)
+
+# python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
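Note: the comment on the last line still invokes demo.py even though the entry point is now app.py. The same flags can be exercised programmatically; a minimal sketch, assuming get_args_parser() accepts the --weights/--device/--server_port flags shown in that comment (checkpoint path taken from the comment, not verified here):

import sys
sys.path.append('wild-gaussian-splatting/mast3r/')
from mast3r.demo import get_args_parser

parser = get_args_parser()
args = parser.parse_args([
    "--weights", "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
    "--device", "cuda",
    "--server_port", "3334",
])
print(args.device, args.server_port)  # sanity check before wiring into the Blocks app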
demo/__init__.py ADDED
(empty file)
demo/gs_demo.py ADDED
@@ -0,0 +1,148 @@
+import gradio as gr
+from gs_train import train
+import os
+
+DATASET_DIR = "colmap_data"
+
+def get_dataset_folders(datasets_path):
+    try:
+        return [f for f in os.listdir(datasets_path) if os.path.isdir(os.path.join(datasets_path, f))]
+    except FileNotFoundError:
+        return []
+
+def gs_demo_tab(cache_path):
+    datasets_path = "/app/data/scenes/"
+    # dataset_path = os.path.join(cache_path, DATASET_DIR)
+    def start_training(selected_folder, *args):
+        selected_data_path = os.path.join(datasets_path, selected_folder)
+        return train(selected_data_path, *args)
+
+    def get_context():
+        return gr.Blocks(delete_cache=(True, True))
+
+    with get_context() as gs_demo:
+        gr.Markdown("""
+        <style>
+        .fixed-size-video video {
+            max-height: 400px !important;
+            height: 400px !important;
+            object-fit: contain;
+        }
+        </style>
+        """)
+        gr.Markdown("# Gaussian Splatting Training Demo")
+
+        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
+        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value="")
+
+        def update_dataset_dropdown():
+            print("update_dataset_dropdown, cache_path", cache_path)
+            # Update the dataset folders list
+            dataset_folders = get_dataset_folders(datasets_path)
+            print("dataset_folders", dataset_folders)
+            # Only set a default value if there are folders available
+            default_value = dataset_folders[0] if dataset_folders else None
+            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)
+
+        # Set the update function to be called when the refresh button is clicked
+        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)
+
+        with gr.Accordion("Model Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    sh_degree = gr.Number(label="SH Degree", value=3)
+                    model_path = gr.Textbox(label="Model Path", value="")
+                    images = gr.Textbox(label="Images", value="images")
+                    resolution = gr.Number(label="Resolution", value=-1)
+                    white_background = gr.Checkbox(label="White Background", value=True)
+                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
+                    eval = gr.Checkbox(label="Eval", value=False)
+
+        with gr.Accordion("Pipeline Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
+                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
+                    debug = gr.Checkbox(label="Debug", value=False)
+
+        with gr.Accordion("Optimization Parameters", open=False):
+            with gr.Row():
+                with gr.Column():
+                    iterations = gr.Number(label="Iterations", value=1000)
+                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
+                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
+                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
+                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000)
+                with gr.Column():
+                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
+                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
+                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
+                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
+                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
+                with gr.Column():
+                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
+                    densification_interval = gr.Number(label="Densification Interval", value=100)
+                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
+                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
+                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
+                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
+                    random_background = gr.Checkbox(label="Random Background", value=False)
+
+        start_button = gr.Button("Start Training")
+
+        # Add state variable to store model path
+        model_path_state = gr.State()
+
+        # Add video output and load model button with fixed scale
+        video_output = gr.Video(
+            label="Training Progress",
+            height=400,      # Fixed height
+            width="100%",    # Full width of container
+            autoplay=False,  # Prevent autoplay
+            show_label=True,
+            container=True,
+            elem_classes="fixed-size-video"  # Add custom class for potential CSS
+        )
+        load_model_button = gr.Button("Load 3D Model", interactive=False)
+        output = gr.Model3D(label="3D Model Output", visible=False)
+
+        def handle_training_complete(selected_folder, *args):
+            # Construct the full path to the selected dataset
+            selected_data_path = os.path.join(datasets_path, selected_folder)
+            # Call the training function with the full path
+            video_path, model_path = train(selected_data_path, *args)
+            # Then return all required outputs
+            return [
+                video_path,  # video output
+                gr.Button(value="Load 3D Model", interactive=True),  # Return new button with updated properties
+                gr.Model3D(visible=False),  # keep 3D model hidden
+                model_path  # store model path in state
+            ]
+
+        def load_model(model_path):
+            if not model_path:
+                return gr.Model3D(visible=False)
+            return gr.Model3D(value=model_path, visible=True)
+
+        # Connect the start training button
+        start_button.click(
+            fn=handle_training_complete,
+            inputs=[
+                dataset_dropdown, sh_degree, model_path, images, resolution, white_background, data_device, eval,
+                convert_SHs_python, compute_cov3D_python, debug,
+                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
+                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
+                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
+                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
+            ],
+            outputs=[video_output, load_model_button, output, model_path_state]
+        )
+
+        # Connect the load model button
+        load_model_button.click(
+            fn=load_model,
+            inputs=[model_path_state],
+            outputs=output
+        )
+    return gs_demo
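get_dataset_folders drives the dataset dropdown: it lists only subdirectories and returns an empty list when the root is missing. A quick sanity-check sketch (the scratch paths are hypothetical):

import os
import tempfile

root = tempfile.mkdtemp()                           # hypothetical scratch root
os.makedirs(os.path.join(root, "turtle"))           # looks like a dataset folder
open(os.path.join(root, "notes.txt"), "w").close()  # stray file, should be ignored

print(get_dataset_folders(root))            # -> ['turtle']
print(get_dataset_folders("/no/such/dir"))  # -> [] (FileNotFoundError is swallowed)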
demo/gs_train.py ADDED
@@ -0,0 +1,289 @@
+import sys
+import os
+import torch
+from random import randint
+import uuid
+from tqdm.auto import tqdm
+import gradio as gr
+import importlib.util
+
+# Add the path to the gaussian-splatting repository
+gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
+sys.path.append(gaussian_splatting_path)
+
+# Import necessary modules from the gaussian-splatting directory
+from utils.loss_utils import l1_loss, ssim
+from gaussian_renderer import render, network_gui
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+from utils.image_utils import psnr
+
+# Dynamically import the train module from the gaussian-splatting directory
+train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
+gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
+train_spec.loader.exec_module(gaussian_splatting_train)
+
+# Import the necessary functions from the dynamically loaded module
+prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
+training_report = gaussian_splatting_train.training_report
+
+from dataclasses import dataclass, field
+
+@dataclass
+class PipelineParams:
+    convert_SHs_python: bool = False
+    compute_cov3D_python: bool = False
+    debug: bool = False
+
+@dataclass
+class OptimizationParams:
+    iterations: int = 7000
+    position_lr_init: float = 0.00016
+    position_lr_final: float = 0.0000016
+    position_lr_delay_mult: float = 0.01
+    position_lr_max_steps: int = 30_000
+    feature_lr: float = 0.0025
+    opacity_lr: float = 0.05
+    scaling_lr: float = 0.005
+    rotation_lr: float = 0.001
+    percent_dense: float = 0.01
+    lambda_dssim: float = 0.2
+    densification_interval: int = 100
+    opacity_reset_interval: int = 3000
+    densify_from_iter: int = 500
+    densify_until_iter: int = 15_000
+    densify_grad_threshold: float = 0.0002
+    random_background: bool = False
+
+@dataclass
+class ModelParams:
+    sh_degree: int = 3
+    source_path: str = "../data/scenes/turtle/"  # Default path, adjust as needed
+    model_path: str = ""
+    images: str = "images"
+    resolution: int = -1
+    white_background: bool = True
+    data_device: str = "cuda"
+    eval: bool = False
+
+@dataclass
+class TrainingArgs:
+    ip: str = "0.0.0.0"
+    port: int = 6007
+    debug_from: int = -1
+    detect_anomaly: bool = False
+    test_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
+    save_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
+    quiet: bool = False
+    checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
+    start_checkpoint: str = None
+
+def train(
+    data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
+    convert_SHs_python, compute_cov3D_python, debug,
+    iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
+    position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
+    percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
+    densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
+):
+    print(data_source_path)
+    # Create instances of the parameter dataclasses
+    dataset = ModelParams(
+        sh_degree=sh_degree,
+        source_path=data_source_path,
+        model_path=model_path,
+        images=images,
+        resolution=resolution,
+        white_background=white_background,
+        data_device=data_device,
+        eval=eval
+    )
+
+    pipe = PipelineParams(
+        convert_SHs_python=convert_SHs_python,
+        compute_cov3D_python=compute_cov3D_python,
+        debug=debug
+    )
+
+    opt = OptimizationParams(
+        iterations=iterations,
+        position_lr_init=position_lr_init,
+        position_lr_final=position_lr_final,
+        position_lr_delay_mult=position_lr_delay_mult,
+        position_lr_max_steps=position_lr_max_steps,
+        feature_lr=feature_lr,
+        opacity_lr=opacity_lr,
+        scaling_lr=scaling_lr,
+        rotation_lr=rotation_lr,
+        percent_dense=percent_dense,
+        lambda_dssim=lambda_dssim,
+        densification_interval=densification_interval,
+        opacity_reset_interval=opacity_reset_interval,
+        densify_from_iter=densify_from_iter,
+        densify_until_iter=densify_until_iter,
+        densify_grad_threshold=densify_grad_threshold,
+        random_background=random_background
+    )
+
+    args = TrainingArgs()
+
+    testing_iterations = args.test_iterations
+    saving_iterations = args.save_iterations
+    checkpoint_iterations = args.checkpoint_iterations
+    debug_from = args.debug_from
+
+    tb_writer = prepare_output_and_logger(dataset)
+
+    gaussians = GaussianModel(dataset.sh_degree)
+    scene = Scene(dataset, gaussians)
+    gaussians.training_setup(opt)
+
+    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+    iter_start = torch.cuda.Event(enable_timing=True)
+    iter_end = torch.cuda.Event(enable_timing=True)
+
+    viewpoint_stack = None
+    ema_loss_for_log = 0.0
+    first_iter = 0
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+
+    point_cloud_path = ""
+    progress = gr.Progress()  # Initialize the progress bar
+    for iteration in range(first_iter, opt.iterations + 1):
+        iter_start.record()
+        gaussians.update_learning_rate(iteration)
+
+        # Every 1000 its we increase the levels of SH up to a maximum degree
+        if iteration % 1000 == 0:
+            gaussians.oneupSHdegree()
+
+        # Pick a random Camera
+        if not viewpoint_stack:
+            viewpoint_stack = scene.getTrainCameras().copy()
+        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+        # Render
+        if (iteration - 1) == debug_from:
+            pipe.debug = True
+        bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+        # Loss
+        gt_image = viewpoint_cam.original_image.cuda()
+        Ll1 = l1_loss(image, gt_image)
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+        loss.backward()
+        iter_end.record()
+
+        with torch.no_grad():
+            # Progress bar
+            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+            if iteration % 10 == 0:
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(10)
+                progress(iteration / opt.iterations)  # Update Gradio progress bar
+            if iteration == opt.iterations:
+                progress_bar.close()
+
+            # Log and save
+            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+            if iteration == opt.iterations:
+                point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
+                print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
+                scene.save(iteration)
+
+            # Densification
+            if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+                    gaussians.reset_opacity()
+
+            # Optimizer step
+            if iteration < opt.iterations:
+                gaussians.optimizer.step()
+                gaussians.optimizer.zero_grad(set_to_none=True)
+
+            if iteration == opt.iterations:
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
+
+    from os import makedirs
+    from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
+    import torchvision
+    import subprocess
+
+    @torch.no_grad()
+    def render_path(dataset: ModelParams, iteration: int, pipeline: PipelineParams, render_resize_method='crop'):
+        """
+        render_resize_method: crop, pad
+        """
+        gaussians = GaussianModel(dataset.sh_degree)
+        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+        iteration = scene.loaded_iter
+
+        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+        model_path = dataset.model_path
+        name = "render"
+
+        views = scene.getRenderCameras()
+
+        render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+
+        makedirs(render_path, exist_ok=True)
+
+        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+            if render_resize_method == 'crop':
+                image_size = 256
+            elif render_resize_method == 'pad':
+                image_size = max(view.image_width, view.image_height)
+            else:
+                raise NotImplementedError
+            view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
+            focal_length_x = fov2focal(view.FoVx, view.image_width)
+            focal_length_y = fov2focal(view.FoVy, view.image_height)
+            view.image_width = image_size
+            view.image_height = image_size
+            view.FoVx = focal2fov(focal_length_x, image_size)
+            view.FoVy = focal2fov(focal_length_y, image_size)
+            view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0, 1).cuda().float()
+            view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
+
+            render_pkg = render(view, gaussians, pipeline, background)
+            rendering = render_pkg["render"]
+            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+
+        # Use ffmpeg to output video
+        renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
+        subprocess.run(["ffmpeg", "-y",
+                        "-framerate", "24",
+                        "-i", os.path.join(render_path, "%05d.png"),
+                        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
+                        "-c:v", "libx264",
+                        "-pix_fmt", "yuv420p",
+                        "-crf", "23",
+                        renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                       )
+        return renders_path
+
+    renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
+
+    return renders_path, point_cloud_path
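Since train takes its 28 arguments positionally, in the same order as the Gradio inputs list in gs_demo.py, it can also be driven without the UI. A minimal sketch, assuming a CUDA device and an already-prepared COLMAP scene at a hypothetical path; the values simply mirror the dataclass defaults above:

# Hypothetical direct invocation of train(), bypassing Gradio.
mp, pp, op = ModelParams(), PipelineParams(), OptimizationParams()
video_path, ply_path = train(
    "/app/data/scenes/turtle",  # hypothetical COLMAP scene directory
    mp.sh_degree, mp.model_path, mp.images, mp.resolution,
    mp.white_background, mp.data_device, mp.eval,
    pp.convert_SHs_python, pp.compute_cov3D_python, pp.debug,
    op.iterations, op.position_lr_init, op.position_lr_final, op.position_lr_delay_mult,
    op.position_lr_max_steps, op.feature_lr, op.opacity_lr, op.scaling_lr, op.rotation_lr,
    op.percent_dense, op.lambda_dssim, op.densification_interval, op.opacity_reset_interval,
    op.densify_from_iter, op.densify_until_iter, op.densify_grad_threshold, op.random_background,
)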
demo/mast3r_demo.py ADDED
@@ -0,0 +1,382 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# sparse gradio demo functions
+# --------------------------------------------------------
+import sys
+
+import math
+import gradio
+import os
+import numpy as np
+import functools
+import trimesh
+import copy
+from scipy.spatial.transform import Rotation
+import tempfile
+import shutil
+
+from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
+from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
+
+from mast3r.model import AsymmetricMASt3R
+from dust3r.image_pairs import make_pairs
+from dust3r.utils.image import load_images
+from dust3r.utils.device import to_numpy
+from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
+from dust3r.demo import get_args_parser as dust3r_get_args_parser
+
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
+from src.colmap_dataset_utils import (
+    inv,
+    init_filestructure,
+    save_images_masks,
+    save_cameras,
+    save_imagestxt,
+    save_pointcloud,
+    save_pointcloud_with_normals
+)
+
+import matplotlib.pyplot as pl
+
+import torch
+
+
+class SparseGAState():
+    def __init__(self, sparse_ga, cache_dir=None, outfile_name=None):
+        self.sparse_ga = sparse_ga
+        self.cache_dir = cache_dir
+        self.outfile_name = outfile_name
+
+    def __del__(self):
+        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
+            shutil.rmtree(self.cache_dir)
+        self.cache_dir = None
+        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
+            os.remove(self.outfile_name)
+        self.outfile_name = None
+
+
+def get_args_parser():
+    parser = dust3r_get_args_parser()
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--gradio_delete_cache', default=None, type=int,
+                        help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
+
+    actions = parser._actions
+    for action in actions:
+        if action.dest == 'model_name':
+            action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
+    # change defaults
+    parser.prog = 'mast3r demo'
+    return parser
+
+
+def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
+                                 cam_color=None, as_pointcloud=False,
+                                 transparent_cams=False, silent=False):
+    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+    pts3d = to_numpy(pts3d)
+    imgs = to_numpy(imgs)
+    focals = to_numpy(focals)
+    cams2world = to_numpy(cams2world)
+
+    scene = trimesh.Scene()
+
+    # full pointcloud
+    if as_pointcloud:
+        pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
+        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
+        valid_msk = np.isfinite(pts.sum(axis=1))
+        pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
+        scene.add_geometry(pct)
+    else:
+        meshes = []
+        for i in range(len(imgs)):
+            pts3d_i = pts3d[i].reshape(imgs[i].shape)
+            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
+            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
+        mesh = trimesh.Trimesh(**cat_meshes(meshes))
+        scene.add_geometry(mesh)
+
+    # add each camera
+    for i, pose_c2w in enumerate(cams2world):
+        if isinstance(cam_color, list):
+            camera_edge_color = cam_color[i]
+        else:
+            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+        add_scene_cam(scene, pose_c2w, camera_edge_color,
+                      None if transparent_cams else imgs[i], focals[i],
+                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)
+
+    rot = np.eye(4)
+    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
+    if not silent:
+        print('(exporting 3D scene to', outfile, ')')
+    scene.export(file_obj=outfile)
+    return outfile
+
+
+def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
+                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
+    """
+    extract 3D_model (glb file) from a reconstructed scene
+    """
+    if scene_state is None:
+        return None
+    outfile = scene_state.outfile_name
+    if outfile is None:
+        return None
+
+    # get optimized values from scene
+    scene = scene_state.sparse_ga
+    rgbimg = scene.imgs
+    focals = scene.get_focals().cpu()
+    cams2world = scene.get_im_poses().cpu()
+
+    # 3D pointcloud from depthmap, poses and intrinsics
+    if TSDF_thresh > 0:
+        tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
+        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
+    else:
+        pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
+
+    torch.save(confs, '/app/data/confs.pt')
+    msk = to_numpy([c > min_conf_thr for c in confs])
+    return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
+                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
+
+def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
+    cam2world = scene.get_im_poses().detach().cpu().numpy()
+    world2cam = inv(cam2world)
+    principal_points = scene.get_principal_points().detach().cpu().numpy()
+    focals = scene.get_focals().detach().cpu().numpy()[..., None]
+    imgs = np.array(scene.imgs)
+
+    pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
+    pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d]
+
+    masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
+
+    mask_images = True
+
+    save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
+    save_images_masks(imgs, masks, images_path, masks_path, mask_images)
+    save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
+    save_imagestxt(world2cam, sparse_path)
+    save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
+    return save_path
+
+def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
+                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
+                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
+                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
+    """
+    from a list of images, run mast3r inference, sparse global aligner.
+    then run get_3D_model_from_scene
+    """
+    imgs = load_images(filelist, size=image_size, verbose=not silent)
+    if len(imgs) == 1:
+        imgs = [imgs[0], copy.deepcopy(imgs[0])]
+        imgs[1]['idx'] = 1
+        filelist = [filelist[0], filelist[0] + '_2']
+
+    scene_graph_params = [scenegraph_type]
+    if scenegraph_type in ["swin", "logwin"]:
+        scene_graph_params.append(str(winsize))
+    elif scenegraph_type == "oneref":
+        scene_graph_params.append(str(refid))
+    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
+        scene_graph_params.append('noncyclic')
+    scene_graph = '-'.join(scene_graph_params)
+    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
+    if optim_level == 'coarse':
+        niter2 = 0
+
+    base_cache_dir = os.path.join(outdir, 'cache')
+    os.makedirs(base_cache_dir, exist_ok=True)
+    def get_next_dir(base_dir):
+        run_counter = 0
+        while True:
+            run_cache_dir = os.path.join(base_dir, f"run_{run_counter}")
+            if not os.path.exists(run_cache_dir):
+                os.makedirs(run_cache_dir)
+                break
+            run_counter += 1
+        return run_cache_dir
+
+    cache_dir = get_next_dir(base_cache_dir)
+    scene = sparse_global_alignment(filelist, pairs, cache_dir,
+                                    model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
+                                    opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
+                                    matching_conf_thr=matching_conf_thr, **kw)
+
+    base_colmapdata_dir = os.path.join(outdir, 'colmap_data')
+    os.makedirs(base_colmapdata_dir, exist_ok=True)
+    colmap_data_dir = get_next_dir(base_colmapdata_dir)
+    save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
+
+    if current_scene_state is not None and \
+            current_scene_state.outfile_name is not None:
+        outfile_name = current_scene_state.outfile_name
+    else:
+        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
+
+    scene_state = SparseGAState(scene, cache_dir, outfile_name)
+    outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
+                                      clean_depth, transparent_cams, cam_size, TSDF_thresh)
+    print(f"colmap_data_dir: {colmap_data_dir}")
+    print(f"outfile_name: {outfile_name}")
+    print(f"cache_dir: {cache_dir}")
+    return scene_state, outfile
+
+
+def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
+    num_files = len(inputfiles) if inputfiles is not None else 1
+    show_win_controls = scenegraph_type in ["swin", "logwin"]
+    show_winsize = scenegraph_type in ["swin", "logwin"]
+    show_cyclic = scenegraph_type in ["swin", "logwin"]
+    max_winsize, min_winsize = 1, 1
+    if scenegraph_type == "swin":
+        if win_cyclic:
+            max_winsize = max(1, math.ceil((num_files - 1) / 2))
+        else:
+            max_winsize = num_files - 1
+    elif scenegraph_type == "logwin":
+        if win_cyclic:
+            half_size = math.ceil((num_files - 1) / 2)
+            max_winsize = max(1, math.ceil(math.log(half_size, 2)))
+        else:
+            max_winsize = max(1, math.ceil(math.log(num_files, 2)))
+    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
+    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
+    win_col = gradio.Column(visible=show_win_controls)
+    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
+    return win_col, winsize, win_cyclic, refid
+
+
+def mast3r_demo_tab(cache_path, weights_path, device, silent=False):
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
+
+    if not silent:
+        print('Outputting stuff in', cache_path)
+
+    recon_fun = functools.partial(get_reconstructed_scene, cache_path, model, device,
+                                  silent)
+    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
+
+    def get_context():
+        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
+        title = "MASt3R Demo"
+        return gradio.Blocks(css=css, title=title, delete_cache=(True, True))
+
+    with get_context() as demo:
+        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
+        scene = gradio.State(None)
+        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
+        with gradio.Column():
+            inputfiles = gradio.File(file_count="multiple")
+            with gradio.Row():
+                with gradio.Column():
+                    with gradio.Row():
+                        lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
+                        niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
+                                               label="num_iterations", info="For coarse alignment!")
+                        lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
+                        niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
+                                               label="num_iterations", info="For refinement!")
+                        optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
+                                                      value='refine+depth', label="OptLevel",
+                                                      info="Optimization level")
+                    image_size = gradio.Dropdown(choices=[512, 224], label="Image Size", value=512)
+                    with gradio.Row():
+                        matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
+                                                          minimum=0., maximum=30., step=0.1,
+                                                          info="Before Fallback to Regr3D!")
+                        shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
+                                                            info="Only optimize one set of intrinsics for all views")
+                        scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
+                                                           ("swin: sliding window", "swin"),
+                                                           ("logwin: sliding window with long range", "logwin"),
+                                                           ("oneref: match one image with all", "oneref")],
+                                                          value='complete', label="Scenegraph",
+                                                          info="Define how to make pairs",
+                                                          interactive=True)
+                        with gradio.Column(visible=False) as win_col:
+                            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
+                                                    minimum=1, maximum=1, step=1)
+                            win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
+                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
+                                              minimum=0, maximum=0, step=1, visible=False)
+            run_btn = gradio.Button("Run")
+
+            with gradio.Row():
+                # adjust the confidence threshold
+                min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
+                # adjust the camera size in the output pointcloud
+                cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
+                TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
+            with gradio.Row():
+                as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
+                # two post process implemented
+                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
+                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
+                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
+
+            outmodel = gradio.Model3D()
+
+            # events
+            scenegraph_type.change(set_scenegraph_options,
+                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                                   outputs=[win_col, winsize, win_cyclic, refid])
+            inputfiles.change(set_scenegraph_options,
+                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                              outputs=[win_col, winsize, win_cyclic, refid])
+            win_cyclic.change(set_scenegraph_options,
+                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                              outputs=[win_col, winsize, win_cyclic, refid])
+            run_btn.click(fn=recon_fun,
+                          inputs=[image_size, scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
+                                  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
+                                  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
+                          outputs=[scene, outmodel])
+            min_conf_thr.release(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                 outputs=outmodel)
+            cam_size.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                            outputs=outmodel)
+            TSDF_thresh.change(fn=model_from_scene_fun,
+                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                               outputs=outmodel)
+            as_pointcloud.change(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                 outputs=outmodel)
+            mask_sky.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                            outputs=outmodel)
+            clean_depth.change(fn=model_from_scene_fun,
+                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                               outputs=outmodel)
+            transparent_cams.change(model_from_scene_fun,
+                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                            clean_depth, transparent_cams, cam_size, TSDF_thresh],
+                                    outputs=outmodel)
+
+    return demo
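For intuition, the scene_graph string that get_reconstructed_scene hands to make_pairs is just the hyphen-joined parameter list. A standalone sketch of the construction (build_scene_graph is a hypothetical helper mirroring the inline logic above):

def build_scene_graph(scenegraph_type, winsize=1, refid=0, win_cyclic=False):
    # Mirrors the inline logic in get_reconstructed_scene.
    params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        params.append(str(winsize))
    elif scenegraph_type == "oneref":
        params.append(str(refid))
    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        params.append('noncyclic')
    return '-'.join(params)

print(build_scene_graph("complete"))         # 'complete' (all image pairs)
print(build_scene_graph("swin", winsize=3))  # 'swin-3-noncyclic'
print(build_scene_graph("oneref", refid=0))  # 'oneref-0'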
requirements.txt ADDED
@@ -0,0 +1 @@
+-e wild-gaussian-splatting
wild-gaussian-splatting ADDED
@@ -0,0 +1 @@
+Subproject commit fe8a9f389cdc583864f34a9e3ae32899c674229a