splatt3r / demo.py
brandonsmart's picture
Initial commit
5ed9923
raw
history blame
No virus
7.09 kB
#!/usr/bin/env python3
# The MASt3R Gradio demo, modified for predicting 3D Gaussian Splats
# --- Original License ---
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import functools
import os
import sys
import tempfile
import gradio
import torch
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')
from dust3r.utils.image import load_images
from mast3r.utils.misc import hash_md5
import main
import utils.export as export
def get_reconstructed_scene(outdir, model, device, silent, image_size, ios_mode, filelist):
    """Predict 3D Gaussian splats from one or two images and save them as a .ply.

    Args:
        outdir: Directory in which ``gaussians.ply`` is written.
        model: Callable invoked as ``model(view1, view2)``; its result is
            unpacked into two predictions.
        device: Device the image tensors are moved to before inference.
        silent: When true, ``load_images`` is called with ``verbose=False``.
        image_size: Forwarded to ``load_images`` as ``size``.
        ios_mode: When true, every entry of ``filelist`` is indexed with
            ``[0]`` to obtain the image path (the demo uses a gallery input
            in this mode).
        filelist: The uploaded images, one or two entries. A single image is
            paired with itself. ``None`` is treated as an empty upload.

    Returns:
        Path of the written ``gaussians.ply`` file.

    Raises:
        ValueError: If the number of provided images is not one or two.
    """
    entries = list(filelist) if filelist else []
    if ios_mode:
        entries = [entry[0] for entry in entries]
    if len(entries) == 1:
        # A single image is used for both views.
        entries = entries * 2
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(entries) != 2:
        raise ValueError("Please provide one or two images")

    imgs = load_images(entries, size=image_size, verbose=not silent)
    for img in imgs:
        img['img'] = img['img'].to(device)
        img['original_img'] = img['original_img'].to(device)
        img['true_shape'] = torch.from_numpy(img['true_shape'])

    # Pure inference: no autograd graph is needed for exporting the result.
    with torch.no_grad():
        pred1, pred2 = model(imgs[0], imgs[1])

    # NOTE(review): the output name is fixed, so concurrent requests sharing
    # `outdir` overwrite each other's file - confirm whether that matters.
    plyfile = os.path.join(outdir, 'gaussians.ply')
    export.save_as_ply(pred1, pred2, plyfile)
    return plyfile
if __name__ == '__main__':
    # The checkpoint path is the only command-line argument; exit with a usage
    # message rather than an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <weights_path>")
    weights_path = sys.argv[1]

    # Demo configuration.
    image_size = 512
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    server_name = '127.0.0.1'
    server_port = None
    share = True
    silent = False
    ios_mode = True  # selects the Gallery input widget instead of File

    model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
    chkpt_tag = hash_md5(weights_path)

    # Example inputs and their corresponding precalculated outputs. Each row
    # is [image 1, image 2, precomputed .ply], for scannet++_1..8 followed by
    # in_the_wild_1..8.
    examples = [
        [f"{stem}_img_1.jpg", f"{stem}_img_2.jpg", f"{stem}.ply"]
        for stem in (
            f"assets/demo_examples/{dataset}_{index}"
            for dataset in ("scannet++", "in_the_wild")
            for index in range(1, 9)
        )
    ]

    # The server is launched inside this context so the temporary directory
    # exists for as long as the demo is running.
    with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
        # Created per checkpoint; not referenced again in this block.
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)

        # Bind everything except the uploaded files, which the UI supplies.
        recon_fun = functools.partial(
            get_reconstructed_scene, tmpdirname, model, device, silent, image_size, ios_mode
        )

        if not ios_mode:
            # The File input needs its own example column: the image pair.
            for example in examples:
                example.insert(2, (example[0], example[1]))

        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        # NOTE(review): the widget nesting below was reconstructed from a copy
        # of the file that had lost its indentation - confirm the Column scope.
        with gradio.Blocks(css=css, title="Splatt3R Demo") as demo:
            gradio.HTML('<h2 style="text-align: center;">Splatt3R Demo</h2>')

            with gradio.Column():
                gradio.Markdown('''
Please upload exactly one or two images below to be used for reconstruction.
If non-square images are uploaded, they will be cropped to squares for reconstruction.
''')
                if ios_mode:
                    inputfiles = gradio.Gallery(type="filepath")
                else:
                    inputfiles = gradio.File(file_count="multiple")
                run_btn = gradio.Button("Run")

                gradio.Markdown('''
## Output
Below we show the generated 3D Gaussian Splat.
There may be a short delay as the reconstruction needs to be downloaded before rendering.
The arrow in the top right of the window below can be used to download the .ply for rendering with other viewers,
such as [here](https://projects.markkellogg.org/threejs/demo_gaussian_splats_3d.php?art=1&cu=0,-1,0&cp=0,1,0&cla=1,0,0&aa=false&2d=false&sh=0) or [here](https://playcanvas.com/supersplat/editor)
''')
                outmodel = gradio.Model3D(
                    clear_color=[1.0, 1.0, 1.0, 0.0],
                )
                run_btn.click(fn=recon_fun, inputs=[inputfiles], outputs=[outmodel])

                gradio.Markdown('''
## Examples
A gallery of examples generated from ScanNet++ and from 'in the wild' images taken with a mobile phone.
''')
                # Hidden image components that receive the example thumbnails.
                snapshot_1 = gradio.Image(None, visible=False)
                snapshot_2 = gradio.Image(None, visible=False)
                if ios_mode:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, outmodel],
                        examples_per_page=5
                    )
                else:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, inputfiles, outmodel],
                        examples_per_page=5
                    )

        demo.launch(share=share, server_name=server_name, server_port=server_port)