File size: 6,701 Bytes
f7427dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# timeforge/mast3r.py
import torch
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.misc import (
    transpose_to_landscape,
)
from mini_dust3r.model import load_model
from pathlib import Path
import uuid
from timeforge.utils import create_image_grid
import rerun as rr
import rerun.blueprint as rrb
import os

# Default torch device for inference: CUDA when available, otherwise the CPU.
# torch device type strings are lowercase ("cpu"); the uppercase spelling is
# not accepted by torch.device().
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """AsymmetricCroCo3DStereo subclass that can also build a 'catmlp+dpt' head.

    On top of the parent's 'linear' and 'dpt' point-map heads, this class can
    attach a ``Cat_MLP_LocalFeatures_DPT_Pts3d`` head when the output mode is
    of the form ``'pts3d+desc<N>'``.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Record the descriptor options, then run the parent constructor.

        Note that the default ``('norm')`` is the plain string ``'norm'``
        (parentheses without a comma do not make a tuple).
        """
        # Stored before super().__init__(); set_downstream_head() reads
        # self.desc_conf_mode. NOTE(review): presumably the parent constructor
        # calls set_downstream_head() -- confirm against mini_dust3r.
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the parent loader.

        If the argument is an existing file it is handed to ``load_model``
        (on the CPU, and ``**kw`` is ignored); otherwise the call is forwarded
        to the parent class's ``from_pretrained`` together with ``**kw``.
        """
        if not os.path.isfile(pretrained_model_name_or_path):
            return super().from_pretrained(pretrained_model_name_or_path, **kw)
        return load_model(pretrained_model_name_or_path, device='cpu')

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Create the two prediction heads and their landscape wrappers.

        Stores ``output_mode``, ``head_type``, ``depth_mode`` and ``conf_mode``
        on the instance, defaults ``desc_conf_mode`` to ``conf_mode`` when it
        was left as ``None``, and then builds ``downstream_head1/2`` and the
        wrapped ``head1/2``.

        Raises:
            AssertionError: if either of the first two entries of ``img_size``
                is not a multiple of ``patch_size``.
            NotImplementedError: for an unsupported head/output combination.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f'{img_size=} must be multiple of {patch_size=}'

        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            # No dedicated descriptor-confidence mode given: reuse conf_mode.
            self.desc_conf_mode = conf_mode

        # Head implementations are imported lazily, at head-allocation time.
        from mini_dust3r.heads.linear_head import LinearPts3d
        from mini_dust3r.heads.dpt_head import create_dpt_head
        from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess

        def head_factory(head_type, output_mode, net, has_conf=False):
            """Build one prediction head for the decoder of ``net``."""
            # Plain point-map heads.
            if output_mode == 'pts3d':
                if head_type == 'linear':
                    return LinearPts3d(net, has_conf)
                if head_type == 'dpt':
                    return create_dpt_head(net, has_conf=has_conf)

            # Anything that is not the descriptor head is unsupported.
            if head_type != 'catmlp+dpt' or not output_mode.startswith('pts3d+desc'):
                raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

            # The descriptor width is the integer suffix after 'pts3d+desc'.
            desc_dim = int(output_mode[len('pts3d+desc'):])
            assert net.dec_depth > 9
            depth = net.dec_depth
            feat_dim = 256
            enc_dim = net.enc_embed_dim
            dec_dim = net.dec_embed_dim
            return Cat_MLP_LocalFeatures_DPT_Pts3d(
                net,
                local_feat_dim=desc_dim,
                has_conf=has_conf,
                # Three output channels, plus one more when has_conf is true.
                num_channels=3 + has_conf,
                feature_dim=feat_dim,
                last_dim=feat_dim // 2,
                hooks_idx=[0, depth * 2 // 4, depth * 3 // 4, depth],
                dim_tokens=[enc_dim, dec_dim, dec_dim, dec_dim],
                postprocess=postprocess,
                depth_mode=net.depth_mode,
                conf_mode=net.conf_mode,
                head_type='regression',
            )

        with_conf = bool(conf_mode)
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=with_conf)
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=with_conf)

        # Wrap both heads with transpose_to_landscape (enabled by landscape_only).
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)


class MASt3R:
    """Wrapper that runs the pretrained model on images and saves a Rerun recording."""

    def __init__(self, device="cuda", model_id="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"):
        """Load the pretrained model and move it to ``device``.

        Args:
            device: Device the model is moved to and on which inference runs.
            model_id: Identifier (or local checkpoint path) passed to
                ``AsymmetricMASt3R.from_pretrained``.
        """
        self.device = device
        self.model = AsymmetricMASt3R.from_pretrained(model_id).to(self.device)

    def create_blueprint(self, image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
        """Build the Rerun viewer layout for a reconstruction.

        Args:
            image_name_list: The input images; only their count is used.
            log_path: Entity path used as the origin of the 3D view.

        Returns:
            A blueprint with a single 3D view when there are more than four
            images; otherwise a 3D view next to a vertical stack of 2D views,
            one per image, rooted at ``{log_path}/camera_{i}/pinhole/``.
        """
        view_3d = rrb.Spatial3DView(origin=f"{log_path}")

        # Don't show 2D views if there are more than 4 images, so as not to
        # clutter the view.
        if len(image_name_list) > 4:
            return rrb.Blueprint(
                rrb.Horizontal(view_3d),
                collapse_panels=True,
            )

        views_2d = [
            rrb.Spatial2DView(
                origin=f"{log_path}/camera_{i}/pinhole/",
                contents=[
                    "+ $origin/**",
                ],
            )
            for i in range(len(image_name_list))
        ]
        return rrb.Blueprint(
            rrb.Horizontal(
                contents=[
                    view_3d,
                    rrb.Vertical(contents=views_2d),
                ],
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )

    def generate_point_cloud(self, image_name_list):
        """Run inference on the given image(s) and save the result as an ``.rrd`` file.

        Args:
            image_name_list: A single image path, or a list of image paths.

        Returns:
            The POSIX path of the saved recording, ``/tmp/gradio/<uuid>.rrd``.

        Raises:
            TypeError: if ``image_name_list`` is neither a ``list`` nor a ``str``.
        """
        if not isinstance(image_name_list, (list, str)):
            # TypeError is still an Exception, so existing handlers keep working.
            raise TypeError(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
            )
        if isinstance(image_name_list, str):
            image_name_list = [image_name_list]

        # A fresh UUID names both the Rerun recording and the output file.
        uuid_str = str(uuid.uuid4())
        filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
        # Make sure the output directory exists before the recording is saved.
        filename.parent.mkdir(parents=True, exist_ok=True)
        rr.init(f"{uuid_str}")
        log_path = Path("world")

        optimized_results: OptimizedResult = inferece_dust3r(
            image_dir_or_list=image_name_list,
            model=self.model,
            # Run on the device the model was moved to in __init__, rather
            # than on an unrelated module-level default.
            device=self.device,
            batch_size=1,
        )

        blueprint: rrb.Blueprint = self.create_blueprint(image_name_list, log_path)
        rr.send_blueprint(blueprint)

        rr.set_time_sequence("sequence", 0)
        log_optimized_result(optimized_results, log_path)
        rr.save(filename.as_posix())
        return filename.as_posix()
if __name__ == "__main__":
    # Manual smoke run: reconstruct the bundled example image and print the
    # path of the recording that was written.
    demo_images = ["examples/single_image/bench1.png"]
    rrd_path = MASt3R().generate_point_cloud(demo_images)
    print(rrd_path)