TimeForge / mast3r.py
Ryukijano's picture
Update mast3r.py
8a533f5 verified
raw
history blame
6.7 kB
# mast3r.py
import torch
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.misc import (
transpose_to_landscape,
)
from mini_dust3r.model import load_model
from pathlib import Path
import uuid
from utils import create_image_grid # Corrected import
import rerun as rr
import rerun.blueprint as rrb
import os
# Default inference device: use CUDA when torch reports it as available,
# otherwise fall back to the CPU. torch device strings are case-sensitive and
# lowercase, so the fallback must be spelled "cpu" ("CPU" is rejected by torch).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """``AsymmetricCroCo3DStereo`` subclass with descriptor-related settings.

    On top of the parent model it stores ``desc_mode``, ``two_confs`` and
    ``desc_conf_mode``, and its ``set_downstream_head`` additionally accepts
    the ``'catmlp+dpt'`` head type with a ``'pts3d+desc<N>'`` output mode.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Store the descriptor settings, then run the parent constructor.

        Args:
            desc_mode: Descriptor mode. Note that the default ``('norm')`` is
                the plain string ``'norm'``, not a one-element tuple.
            two_confs: Stored as-is on the instance.
            desc_conf_mode: Descriptor confidence mode; when left as ``None``
                it is replaced by ``conf_mode`` in ``set_downstream_head``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # The attributes are set *before* the parent constructor runs.
        # NOTE(review): presumably the parent __init__ calls
        # set_downstream_head, which reads desc_conf_mode — confirm upstream.
        self.desc_conf_mode = desc_conf_mode
        self.two_confs = two_confs
        self.desc_mode = desc_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the parent loader.

        Args:
            pretrained_model_name_or_path: Path of an existing file, or any
                identifier understood by the parent ``from_pretrained``.
            **kw: Forwarded to the parent ``from_pretrained`` only; they are
                not used when loading from a local file.

        Returns:
            The loaded model. Local files are loaded with ``device='cpu'``.
        """
        # Anything that is not an existing file is delegated to the parent.
        if not os.path.isfile(pretrained_model_name_or_path):
            return super().from_pretrained(pretrained_model_name_or_path, **kw)
        return load_model(pretrained_model_name_or_path, device='cpu')

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Record the head configuration and create both prediction heads.

        Args:
            output_mode: ``'pts3d'`` or ``'pts3d+desc<N>'`` where ``<N>`` is
                parsed as the integer local feature dimension.
            head_type: ``'linear'``, ``'dpt'`` or ``'catmlp+dpt'``.
            landscape_only: Passed as ``activate`` to ``transpose_to_landscape``.
            depth_mode: Stored on the instance.
            conf_mode: Stored on the instance; its truthiness decides whether
                the heads are built with ``has_conf``.
            patch_size: Both entries of ``img_size`` must be multiples of it.
            img_size: Indexable with at least two entries.
            **kw: Accepted and ignored.

        Raises:
            AssertionError: If ``img_size`` is not a multiple of
                ``patch_size``, or if the ``'catmlp+dpt'`` head is requested
                while ``dec_depth`` is not greater than 9.
            NotImplementedError: For an unsupported head/output combination.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f'{img_size=} must be multiple of {patch_size=}'

        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # Descriptor confidence falls back to the general confidence mode.
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode

        # Head implementations are imported lazily, at head-allocation time.
        from mini_dust3r.heads.linear_head import LinearPts3d
        from mini_dust3r.heads.dpt_head import create_dpt_head
        from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess

        def build_head(head_type, output_mode, net, has_conf=False):
            """Build one prediction head for the decoder of ``net``."""
            if head_type == 'linear' and output_mode == 'pts3d':
                return LinearPts3d(net, has_conf)
            if head_type == 'dpt' and output_mode == 'pts3d':
                return create_dpt_head(net, has_conf=has_conf)
            if not (head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc')):
                raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

            # Everything after the 10-character 'pts3d+desc' prefix is the
            # local feature dimension.
            desc_dim = int(output_mode[10:])
            assert net.dec_depth > 9
            depth = net.dec_depth
            feat_dim = 256
            enc_dim = net.enc_embed_dim
            dec_dim = net.dec_embed_dim
            return Cat_MLP_LocalFeatures_DPT_Pts3d(
                net,
                local_feat_dim=desc_dim,
                has_conf=has_conf,
                # 3 output channels, plus one when has_conf is truthy.
                num_channels=3 + has_conf,
                feature_dim=feat_dim,
                last_dim=feat_dim // 2,
                # Hooks at index 0, at 2/4 and 3/4 of the decoder depth, and
                # at the full depth.
                hooks_idx=[0, depth * 2 // 4, depth * 3 // 4, depth],
                dim_tokens=[enc_dim, dec_dim, dec_dim, dec_dim],
                postprocess=postprocess,
                depth_mode=net.depth_mode,
                conf_mode=net.conf_mode,
                head_type='regression',
            )

        wants_conf = bool(conf_mode)
        self.downstream_head1 = build_head(head_type, output_mode, self, has_conf=wants_conf)
        self.downstream_head2 = build_head(head_type, output_mode, self, has_conf=wants_conf)

        # Wrap each head with transpose_to_landscape, enabled by landscape_only.
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
class MASt3R:
    """Load a pretrained ``AsymmetricMASt3R`` and export results as ``.rrd`` files.

    The model is loaded once in the constructor; ``generate_point_cloud`` runs
    inference on one or more images and saves a Rerun recording to disk.
    """

    def __init__(self, device="cuda", model_id="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"):
        """Load the model identified by ``model_id`` and move it to ``device``.

        Args:
            device: Device the model is moved to and inference is run on.
            model_id: Local checkpoint path or pretrained model identifier,
                passed to ``AsymmetricMASt3R.from_pretrained``.
        """
        self.device = device
        self.model = AsymmetricMASt3R.from_pretrained(model_id).to(self.device)

    def create_blueprint(self, image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
        """Build the Rerun viewer layout for a recording rooted at ``log_path``.

        Args:
            image_name_list: Input images; only its length is used here.
            log_path: Entity path used as origin of the 3D view and as prefix
                of the per-camera 2D views.

        Returns:
            A blueprint with a single 3D view when there are more than four
            images, otherwise a 3D view next to one 2D view per image.
        """
        # Don't show the 2D views when there are more than 4 images, so as
        # not to clutter the viewer.
        if len(image_name_list) > 4:
            return rrb.Blueprint(
                rrb.Horizontal(
                    rrb.Spatial3DView(origin=f"{log_path}"),
                ),
                collapse_panels=True,
            )

        camera_views = [
            rrb.Spatial2DView(
                origin=f"{log_path}/camera_{i}/pinhole/",
                contents=[
                    "+ $origin/**",
                ],
            )
            for i in range(len(image_name_list))
        ]
        return rrb.Blueprint(
            rrb.Horizontal(
                contents=[
                    rrb.Spatial3DView(origin=f"{log_path}"),
                    rrb.Vertical(contents=camera_views),
                ],
                # The 3D view gets three shares of the width, the 2D column one.
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )

    def generate_point_cloud(self, image_name_list):
        """Run inference on the given image(s) and save a Rerun recording.

        Args:
            image_name_list: A single image path, or a list of image paths.

        Returns:
            The POSIX path (``str``) of the saved ``.rrd`` file under
            ``/tmp/gradio``.

        Raises:
            TypeError: If ``image_name_list`` is neither a ``str`` nor a
                ``list``.
        """
        # A single path is accepted as shorthand for a one-element list.
        if isinstance(image_name_list, str):
            image_name_list = [image_name_list]
        elif not isinstance(image_name_list, list):
            raise TypeError(
                f"Input must be a list of strings or a string, got: {type(image_name_list)}"
            )

        # Each call gets its own recording id and output file.
        uuid_str = str(uuid.uuid4())
        filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
        rr.init(f"{uuid_str}")
        log_path = Path("world")

        # Run inference on the device the model was moved to in __init__, so
        # that model and inputs cannot end up on different devices.
        optimized_results: OptimizedResult = inferece_dust3r(
            image_dir_or_list=image_name_list,
            model=self.model,
            device=self.device,
            batch_size=1,
        )

        blueprint: rrb.Blueprint = self.create_blueprint(image_name_list, log_path)
        rr.send_blueprint(blueprint)
        rr.set_time_sequence("sequence", 0)
        log_optimized_result(optimized_results, log_path)

        # The output directory is not guaranteed to exist yet.
        filename.parent.mkdir(parents=True, exist_ok=True)
        rr.save(filename.as_posix())
        return filename.as_posix()
if __name__ == "__main__":
    # Example run on a single bundled image. The device is chosen explicitly
    # because MASt3R() on its own defaults to "cuda", which fails on hosts
    # without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mast3r = MASt3R(device=device)
    images = ["examples/single_image/bench1.png"]
    point_cloud = mast3r.generate_point_cloud(images)
    # Prints the path of the saved .rrd recording.
    print(point_cloud)