Ryukijano commited on
Commit
f7427dd
Β·
verified Β·
1 Parent(s): 7a18ab3

Create mast3r.py

Browse files
Files changed (1) hide show
  1. mast3r.py +148 -0
mast3r.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # timeforge/mast3r.py
2
+ import torch
3
+ from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
4
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
5
+ from mini_dust3r.utils.misc import (
6
+ transpose_to_landscape,
7
+ )
8
+ from mini_dust3r.model import load_model
9
+ from pathlib import Path
10
+ import uuid
11
+ from timeforge.utils import create_image_grid
12
+ import rerun as rr
13
+ import rerun.blueprint as rrb
14
+ import os
15
+
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "CPU"
17
+ class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
18
+ def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
19
+ self.desc_mode = desc_mode
20
+ self.two_confs = two_confs
21
+ self.desc_conf_mode = desc_conf_mode
22
+ super().__init__(**kwargs)
23
+
24
+ @classmethod
25
+ def from_pretrained(cls, pretrained_model_name_or_path, **kw):
26
+ if os.path.isfile(pretrained_model_name_or_path):
27
+ return load_model(pretrained_model_name_or_path, device='cpu')
28
+ else:
29
+ return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
30
+
31
+ def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
32
+ assert img_size[0] % patch_size == 0 and img_size[
33
+ 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
34
+ self.output_mode = output_mode
35
+ self.head_type = head_type
36
+ self.depth_mode = depth_mode
37
+ self.conf_mode = conf_mode
38
+ if self.desc_conf_mode is None:
39
+ self.desc_conf_mode = conf_mode
40
+ # allocate heads
41
+ from mini_dust3r.heads.linear_head import LinearPts3d
42
+ from mini_dust3r.heads.dpt_head import create_dpt_head
43
+ from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess
44
+
45
+ def head_factory(head_type, output_mode, net, has_conf=False):
46
+ """" build a prediction head for the decoder
47
+ """
48
+ if head_type == 'linear' and output_mode == 'pts3d':
49
+ return LinearPts3d(net, has_conf)
50
+ elif head_type == 'dpt' and output_mode == 'pts3d':
51
+ return create_dpt_head(net, has_conf=has_conf)
52
+ if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
53
+ local_feat_dim = int(output_mode[10:])
54
+ assert net.dec_depth > 9
55
+ l2 = net.dec_depth
56
+ feature_dim = 256
57
+ last_dim = feature_dim // 2
58
+ out_nchan = 3
59
+ ed = net.enc_embed_dim
60
+ dd = net.dec_embed_dim
61
+ return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
62
+ num_channels=out_nchan + has_conf,
63
+ feature_dim=feature_dim,
64
+ last_dim=last_dim,
65
+ hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
66
+ dim_tokens=[ed, dd, dd, dd],
67
+ postprocess=postprocess,
68
+ depth_mode=net.depth_mode,
69
+ conf_mode=net.conf_mode,
70
+ head_type='regression')
71
+ else:
72
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
73
+ self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
74
+ self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
75
+ # magic wrapper
76
+ self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
77
+ self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
78
+
79
+ class MASt3R:
80
+ def __init__(self, device="cuda", model_id="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"):
81
+ self.device = device
82
+ self.model = AsymmetricMASt3R.from_pretrained(model_id).to(self.device)
83
+
84
+ def create_blueprint(self, image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
85
+ # dont show 2d views if there are more than 4 images as to not clutter the view
86
+ if len(image_name_list) > 4:
87
+ blueprint = rrb.Blueprint(
88
+ rrb.Horizontal(
89
+ rrb.Spatial3DView(origin=f"{log_path}"),
90
+ ),
91
+ collapse_panels=True,
92
+ )
93
+ else:
94
+ blueprint = rrb.Blueprint(
95
+ rrb.Horizontal(
96
+ contents=[
97
+ rrb.Spatial3DView(origin=f"{log_path}"),
98
+ rrb.Vertical(
99
+ contents=[
100
+ rrb.Spatial2DView(
101
+ origin=f"{log_path}/camera_{i}/pinhole/",
102
+ contents=[
103
+ "+ $origin/**",
104
+ ],
105
+ )
106
+ for i in range(len(image_name_list))
107
+ ]
108
+ ),
109
+ ],
110
+ column_shares=[3, 1],
111
+ ),
112
+ collapse_panels=True,
113
+ )
114
+ return blueprint
115
+
116
+
117
+ def generate_point_cloud(self, image_name_list):
118
+ if not isinstance(image_name_list, list) and not isinstance(image_name_list, str):
119
+ raise Exception(
120
+ f"Input must be a list of strings or a string, got: {type(image_name_list)}"
121
+ )
122
+ uuid_str = str(uuid.uuid4())
123
+ filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
124
+ rr.init(f"{uuid_str}")
125
+ log_path = Path("world")
126
+
127
+ if isinstance(image_name_list, str):
128
+ image_name_list = [image_name_list]
129
+
130
+ optimized_results: OptimizedResult = inferece_dust3r(
131
+ image_dir_or_list=image_name_list,
132
+ model=self.model,
133
+ device=DEVICE,
134
+ batch_size=1,
135
+ )
136
+
137
+ blueprint: rrb.Blueprint = self.create_blueprint(image_name_list, log_path)
138
+ rr.send_blueprint(blueprint)
139
+
140
+ rr.set_time_sequence("sequence", 0)
141
+ log_optimized_result(optimized_results, log_path)
142
+ rr.save(filename.as_posix())
143
+ return filename.as_posix()
144
+ if __name__ == "__main__":
145
+ mast3r = MASt3R()
146
+ images = ["examples/single_image/bench1.png"]
147
+ point_cloud = mast3r.generate_point_cloud(images)
148
+ print(point_cloud)