Spaces:
zino36
/
Runtime error

File size: 7,674 Bytes
6151e97
 
 
 
627e15a
6151e97
2a90110
6151e97
 
 
 
 
7494687
6d2638a
0c19783
6151e97
 
 
 
 
0c19783
6151e97
 
 
627e15a
 
 
6151e97
 
 
18276d5
 
6151e97
627e15a
6151e97
627e15a
e83109b
6151e97
 
 
6d2638a
7494687
 
 
 
 
0c19783
 
 
7494687
0c19783
 
 
 
7494687
 
314f8f9
0c19783
 
 
 
 
 
 
 
 
 
 
 
7494687
0c19783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d74503
d8503f1
 
0c19783
 
 
 
2f74cb2
71e14c1
d8503f1
a6b79e4
d8503f1
7ec4c89
 
7494687
ac6fc75
d8503f1
7494687
2f74cb2
0c19783
 
7494687
 
 
 
 
 
0c19783
 
 
7494687
 
2f74cb2
 
 
d8503f1
1d74503
2f74cb2
 
 
 
 
 
 
 
 
 
 
1812a4b
2f74cb2
 
7b04511
2f74cb2
d8503f1
2f74cb2
 
 
7494687
0c19783
 
 
 
 
 
d8503f1
6d2638a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# mast3r demo
# --------------------------------------------------------
import spaces
import os
import sys
import os.path as path
import torch
import tempfile
import gradio
import shutil
import math

# Make the bundled ./mast3r checkout that sits next to this file importable,
# searched before any other entry on sys.path.
HERE_PATH = path.normpath(path.split(__file__)[0])  # noqa
MASt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, 'mast3r'))  # noqa
sys.path[:0] = [MASt3R_REPO_PATH]  # noqa

from mast3r.demo import get_reconstructed_scene
from mast3r.model import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.demo import set_print_with_timestamp

import matplotlib.pyplot as pl
pl.ion()  # switch matplotlib to interactive mode

# for gpu >= Ampere and pytorch >= 1.12
# Allow TF32 tensor-core math for float32 matmuls (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
batch_size = 1  # NOTE(review): not referenced anywhere else in this file
set_print_with_timestamp()  # dust3r helper; presumably prefixes print() output with timestamps - confirm

# Checkpoint id (also linked in the page text below); the model is loaded once at import
# time and this single global instance is used by every reconstruction request.
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when CUDA is unavailable
model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
chkpt_tag = hash_md5(weights_path)  # NOTE(review): not referenced anywhere else in this file

# Scratch directory passed to get_reconstructed_scene as its first (output dir) argument;
# removed with shutil.rmtree once demo.launch() returns.
tmpdirname = tempfile.mkdtemp(suffix='_mast3r_gradio_demo')
image_size = 512  # forwarded as the image_size argument of get_reconstructed_scene
silent = True  # forwarded as the silent flag of get_reconstructed_scene
# Forwarded both to get_reconstructed_scene and to gradio.Blocks(delete_cache=(frequency, age)),
# which gradio interprets in seconds.
gradio_delete_cache = 7200


class FileState:
    """Own the path of a generated output file and delete that file on finalisation.

    ``local_get_reconstructed_scene`` returns an instance into a ``gradio.State``, so the
    file is kept for as long as that state keeps the instance alive.
    """

    def __init__(self, outfile_name=None):
        # Path of the file to delete; None means there is nothing to clean up.
        self.outfile_name = outfile_name

    def __del__(self):
        # Best-effort cleanup. Detach the path first so a repeated call is a no-op even if
        # removal fails (getattr also covers an instance whose __init__ never ran).
        outfile_name = getattr(self, 'outfile_name', None)
        self.outfile_name = None
        if outfile_name is not None and os.path.isfile(outfile_name):
            try:
                os.remove(outfile_name)
            except FileNotFoundError:
                # The file can vanish between isfile() and remove(): it presumably lives
                # under the shared temp/cache dirs that gradio's delete_cache and the final
                # rmtree also clean. An exception escaping __del__ would only be printed
                # as "Exception ignored", so treat "already gone" as success.
                pass


def _scene_graph_params(num_images):
    """Return ``(scenegraph_type, winsize)`` for a collection of ``num_images`` images.

    Fewer than 5 images use the 'complete' graph with winsize 1; otherwise the 'logwin'
    graph with winsize = ceil(log2(ceil((num_images - 1) / 2))), clamped to [1, 5].
    """
    if num_images < 5:
        return 'complete', 1
    half_size = math.ceil((num_images - 1) / 2)
    max_winsize = max(1, math.ceil(math.log(half_size, 2)))
    return 'logwin', min(5, max_winsize)


@spaces.GPU(duration=180)
def local_get_reconstructed_scene(filelist, min_conf_thr, matching_conf_thr,
                                  as_pointcloud, cam_size,
                                  shared_intrinsics, **kw):
    """Reconstruct a 3D scene from the uploaded images with the fixed demo settings.

    Args:
        filelist: images from the ``gradio.File`` input; may be None or empty when
            nothing was uploaded.
        min_conf_thr, matching_conf_thr, as_pointcloud, cam_size, shared_intrinsics:
            UI values forwarded unchanged to ``get_reconstructed_scene``.
        **kw: extra keyword arguments forwarded to ``get_reconstructed_scene``.

    Returns:
        ``(filestate, outfile)``: a FileState that owns the scene's output file and the
        ``outfile`` returned by ``get_reconstructed_scene`` (shown in the Model3D output).

    Raises:
        gradio.Error: if no image was provided.
    """
    if not filelist:
        # Without this guard, len(None) below surfaces as an opaque TypeError in the UI.
        raise gradio.Error("Please upload at least one image before clicking Run.")

    # Fixed two-stage optimisation schedule (lr1/niter1, then lr2/niter2) for the demo.
    lr1 = 0.07
    niter1 = 500
    lr2 = 0.014
    niter2 = 200
    optim_level = 'refine'
    mask_sky, clean_depth, transparent_cams = False, True, False
    scenegraph_type, winsize = _scene_graph_params(len(filelist))
    refid = 0
    win_cyclic = False
    scene_state, outfile = get_reconstructed_scene(tmpdirname, gradio_delete_cache, model, device, silent, image_size, None,
                                                   filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                                                   as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
                                                   win_cyclic, refid, TSDF_thresh=0, shared_intrinsics=shared_intrinsics, **kw)
    # Hand ownership of the output file to a FileState kept in the session state, and clear
    # scene_state's copy (presumably scene_state would otherwise delete the file when
    # released - confirm in mast3r.demo).
    filestate = FileState(scene_state.outfile_name)
    scene_state.outfile_name = None
    del scene_state
    return filestate, outfile


def run_example(snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, filelist, **kw):
    """Adapter for ``gradio.Examples``: map an example row onto the reconstruction call.

    The example row is ordered like the Examples ``inputs`` list; ``snapshot`` (the hidden
    image component's value) is ignored, and the remaining values are reordered to match
    ``local_get_reconstructed_scene``'s positional parameters.
    """
    reordered = (filelist, min_conf_thr, matching_conf_thr,
                 as_pointcloud, cam_size, shared_intrinsics)
    return local_get_reconstructed_scene(*reordered, **kw)

# Make the app use the full page width.
css = """.gradio-container {margin: 0 !important; min-width: 100%;}"""
title = "MASt3R Demo"
# delete_cache=(frequency, age) in seconds: gradio periodically purges cached files older than `age`.
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    # Per-session holder for the FileState returned by the reconstruction callbacks.
    filestate = gradio.State(None)
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with MASt3R</h2>')
    gradio.HTML('<p>Upload one or multiple images (wait for them to be fully uploaded before hitting the run button). '
                'We tested with up to 18 images before running into the allocation timeout - set at 3 minutes but your mileage may vary. '
                'At the very bottom of this page, you will find an example. If you click on it, it will pull the 3D reconstruction from 7 images of the small Naver Labs Europe tower from cache. '
                'If you want to try larger image collections, you can find the more complete version of this demo that you can run locally '
                'and more details about the method at <a href="https://github.com/naver/mast3r">github.com/naver/mast3r</a>. '
                'The checkpoint used in this demo is available at <a href="https://huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric">huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric</a>.</p>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple", file_types=['image'])
        # Hidden input that receives the first value of the example row below;
        # run_example ignores it.
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=2.,
                                              minimum=0., maximum=30., step=0.1,
                                              info="Before Fallback to Regr3D!")
            # adjust the confidence threshold
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
            # adjust the camera size in the output pointcloud
            cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
        with gradio.Row():
            as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
            shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                info="Only optimize one set of intrinsics for all views")
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()

        # The example row follows the `inputs` order: snapshot, matching_conf_thr, min_conf_thr,
        # cam_size, as_pointcloud, shared_intrinsics, inputfiles (7 NLE tower images).
        # cache_examples="lazy": the result is computed and cached the first time the example is used.
        examples = gradio.Examples(
            examples=[
                [
                    os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg'),
                    0.0, 1.5, 0.2, True, False,
                    [os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/01D90321-69C8-439F-B0B0-E87E7634741C-83120-000041DAE419D7AE.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/1AD85EF5-B651-4291-A5C0-7BDB7D966384-83120-000041DADF639E09.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/28EDBB63-B9F9-42FB-AC86-4852A33ED71B-83120-000041DAF22407A1.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/91E9B685-7A7D-42D7-B933-23A800EE4129-83120-000041DAE12C8176.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/2679C386-1DC0-4443-81B5-93D7EDE4AB37-83120-000041DADB2EA917.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/CDBBD885-54C3-4EB4-9181-226059A60EE0-83120-000041DAE0C3D612.jpg'),
                     os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg')]
                ]
            ],
            inputs=[snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, inputfiles],
            outputs=[filestate, outmodel],
            fn=run_example,
            cache_examples="lazy",
        )

        # events
        run_btn.click(fn=local_get_reconstructed_scene,
                      inputs=[inputfiles, min_conf_thr, matching_conf_thr,
                              as_pointcloud,
                              cam_size, shared_intrinsics],
                      outputs=[filestate, outmodel])

try:
    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
finally:
    # Always remove the scratch directory, even when launch() raises (e.g. the port is
    # already in use); ignore_errors because files inside it may already have been
    # deleted by FileState or by gradio's cache cleanup.
    shutil.rmtree(tmpdirname, ignore_errors=True)