# mast3r.py
"""MASt3R network definition and a small inference wrapper.

Defines ``AsymmetricMASt3R`` (an ``AsymmetricCroCo3DStereo`` subclass with
configurable downstream heads) and ``MASt3R``, a convenience class that runs
inference on one or more images and saves the logged result as a Rerun
``.rrd`` recording.
"""
import os
import uuid
from pathlib import Path

import rerun as rr
import rerun.blueprint as rrb
import torch
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo, load_model
from mini_dust3r.utils.misc import (
    transpose_to_landscape,
)

from utils import create_image_grid  # Corrected import

# torch device strings are case-sensitive: the CPU fallback must be "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """``AsymmetricCroCo3DStereo`` variant that stores descriptor settings.

    The descriptor-related options are stored on the instance *before* the
    parent constructor runs, so they are already available if the parent
    constructor invokes ``set_downstream_head``.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Store descriptor options, then defer to the parent constructor.

        Args:
            desc_mode: Descriptor mode. NOTE: the default ``('norm')`` is the
                plain string ``'norm'`` (not a 1-tuple); kept as-is.
            two_confs: Stored on the instance; not read in this module.
            desc_conf_mode: Descriptor confidence mode. When ``None`` it is
                replaced by ``conf_mode`` in ``set_downstream_head``.
            **kwargs: Forwarded unchanged to ``AsymmetricCroCo3DStereo``.
        """
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or by name.

        If ``pretrained_model_name_or_path`` is an existing file it is loaded
        with ``load_model`` on the CPU (``**kw`` is ignored in that case);
        otherwise the call is delegated to the parent ``from_pretrained``.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode,
                            patch_size, img_size, **kw):
        """Create the two prediction heads and their landscape wrappers.

        Args:
            output_mode: ``'pts3d'`` or ``'pts3d+desc<N>'`` where ``<N>`` is
                the local feature dimension parsed from the string.
            head_type: ``'linear'``, ``'dpt'`` or ``'catmlp+dpt'``.
            landscape_only: Passed as ``activate`` to
                ``transpose_to_landscape``.
            depth_mode: Stored as ``self.depth_mode``.
            conf_mode: Stored as ``self.conf_mode``; its truthiness decides
                whether heads are built with a confidence output.
            patch_size: Patch size; both image dimensions must be multiples.
            img_size: Two-element image size.
            **kw: Accepted and ignored.

        Raises:
            AssertionError: If ``img_size`` is not a multiple of
                ``patch_size`` (or the decoder depth is too small for the
                ``catmlp+dpt`` head).
            NotImplementedError: For an unsupported head/output combination.
        """
        assert img_size[0] % patch_size == 0 and img_size[
            1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode

        # allocate heads (imported lazily, as in the original code)
        from mini_dust3r.heads.linear_head import LinearPts3d
        from mini_dust3r.heads.dpt_head import create_dpt_head
        from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess

        def head_factory(head_type, output_mode, net, has_conf=False):
            """Build a prediction head for the decoder."""
            if head_type == 'linear' and output_mode == 'pts3d':
                return LinearPts3d(net, has_conf)
            elif head_type == 'dpt' and output_mode == 'pts3d':
                return create_dpt_head(net, has_conf=has_conf)
            elif head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
                # Everything after the 'pts3d+desc' prefix is the feature size.
                local_feat_dim = int(output_mode[10:])
                assert net.dec_depth > 9
                l2 = net.dec_depth
                feature_dim = 256
                last_dim = feature_dim // 2
                out_nchan = 3
                ed = net.enc_embed_dim
                dd = net.dec_embed_dim
                return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
                                                       num_channels=out_nchan + has_conf,
                                                       feature_dim=feature_dim,
                                                       last_dim=last_dim,
                                                       hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                                                       dim_tokens=[ed, dd, dd, dd],
                                                       postprocess=postprocess,
                                                       depth_mode=net.depth_mode,
                                                       conf_mode=net.conf_mode,
                                                       head_type='regression')
            else:
                raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))

        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)


class MASt3R:
    """Convenience wrapper: load the model once, then produce ``.rrd`` files."""

    def __init__(self, device=DEVICE, model_id="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"):
        """Load ``model_id`` and move it to ``device``.

        Args:
            device: Target device string. Defaults to ``DEVICE`` ("cuda" when
                available, otherwise "cpu") so construction also works on
                CPU-only hosts.
            model_id: Model name or local checkpoint path passed to
                ``AsymmetricMASt3R.from_pretrained``.
        """
        self.device = device
        self.model = AsymmetricMASt3R.from_pretrained(model_id).to(self.device)

    def create_blueprint(self, image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
        """Build the Rerun viewer layout for the given images.

        With more than four images only a 3D view rooted at ``log_path`` is
        shown; otherwise a column of per-camera 2D views is added next to it.
        """
        # dont show 2d views if there are more than 4 images as to not clutter the view
        if len(image_name_list) > 4:
            return rrb.Blueprint(
                rrb.Horizontal(
                    rrb.Spatial3DView(origin=f"{log_path}"),
                ),
                collapse_panels=True,
            )

        camera_views = [
            rrb.Spatial2DView(
                origin=f"{log_path}/camera_{i}/pinhole/",
                contents=[
                    "+ $origin/**",
                ],
            )
            for i in range(len(image_name_list))
        ]
        return rrb.Blueprint(
            rrb.Horizontal(
                contents=[
                    rrb.Spatial3DView(origin=f"{log_path}"),
                    rrb.Vertical(contents=camera_views),
                ],
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )

    def generate_point_cloud(self, image_name_list) -> str:
        """Run inference on the image(s) and save a Rerun recording.

        Args:
            image_name_list: A single image path or a list of image paths.

        Returns:
            POSIX path of the saved ``.rrd`` file under ``/tmp/gradio/``.

        Raises:
            TypeError: If the input is neither a list nor a string.
        """
        if not isinstance(image_name_list, (list, str)):
            # TypeError is still an Exception, so existing handlers keep working.
            raise TypeError(
                f"Input must be a list of strings or a string, got: {type(image_name_list)}"
            )

        uuid_str = str(uuid.uuid4())
        filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
        rr.init(f"{uuid_str}")
        log_path = Path("world")

        if isinstance(image_name_list, str):
            image_name_list = [image_name_list]

        # Run on the device the model was moved to, not the module-level default.
        optimized_results: OptimizedResult = inferece_dust3r(
            image_dir_or_list=image_name_list,
            model=self.model,
            device=self.device,
            batch_size=1,
        )

        blueprint: rrb.Blueprint = self.create_blueprint(image_name_list, log_path)
        rr.send_blueprint(blueprint)
        rr.set_time_sequence("sequence", 0)
        log_optimized_result(optimized_results, log_path)

        # Make sure the output directory exists before writing the recording.
        filename.parent.mkdir(parents=True, exist_ok=True)
        rr.save(filename.as_posix())
        return filename.as_posix()


if __name__ == "__main__":
    mast3r = MASt3R()
    images = ["examples/single_image/bench1.png"]
    point_cloud = mast3r.generate_point_cloud(images)
    print(point_cloud)