|
|
|
import torch |
|
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result |
|
from mini_dust3r.model import AsymmetricCroCo3DStereo |
|
from mini_dust3r.utils.misc import ( |
|
transpose_to_landscape, |
|
) |
|
from mini_dust3r.model import load_model |
|
from pathlib import Path |
|
import uuid |
|
from utils import create_image_grid |
|
import rerun as rr |
|
import rerun.blueprint as rrb |
|
import os |
|
|
|
# Default compute device for inference: CUDA when available, otherwise CPU.
# torch device strings are case-sensitive and lowercase; "CPU" is rejected
# by torch.device(), so the fallback must be spelled "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """``AsymmetricCroCo3DStereo`` subclass that stores descriptor settings and
    can build a ``'catmlp+dpt'`` head for ``'pts3d+desc<N>'`` output modes.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Record the descriptor options, then run the parent constructor.

        Args:
            desc_mode: Stored as ``self.desc_mode``.
            two_confs: Stored as ``self.two_confs``.
            desc_conf_mode: Stored as ``self.desc_conf_mode``; when left as
                ``None`` it is replaced by ``conf_mode`` in
                ``set_downstream_head``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # These are assigned before super().__init__(); presumably the parent
        # constructor calls set_downstream_head, which reads desc_conf_mode
        # -- TODO confirm against AsymmetricCroCo3DStereo.
        self.desc_conf_mode = desc_conf_mode
        self.two_confs = two_confs
        self.desc_mode = desc_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the parent loader.

        If the argument is an existing file it is passed to ``load_model``
        with ``device='cpu'`` (``**kw`` is not used on that path); otherwise
        the call is delegated to the parent class's ``from_pretrained``.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        return super().from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Create both prediction heads and their landscape wrappers.

        Stores the mode settings on the instance, builds ``downstream_head1``
        and ``downstream_head2`` for the requested ``head_type`` /
        ``output_mode`` pair, and wraps each with ``transpose_to_landscape``
        (``activate=landscape_only``) as ``head1`` / ``head2``.

        Raises:
            AssertionError: If either of the first two entries of ``img_size``
                is not a multiple of ``patch_size``.
            NotImplementedError: If the ``head_type`` / ``output_mode``
                combination is not supported.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f'{img_size=} must be multiple of {patch_size=}'

        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # Fall back to the general confidence mode when none was given for descriptors.
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode

        from mini_dust3r.heads.linear_head import LinearPts3d
        from mini_dust3r.heads.dpt_head import create_dpt_head
        from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess

        def head_factory(head_type, output_mode, net, has_conf=False):
            """Build a prediction head for the decoder."""
            if output_mode == 'pts3d':
                if head_type == 'linear':
                    return LinearPts3d(net, has_conf)
                if head_type == 'dpt':
                    return create_dpt_head(net, has_conf=has_conf)

            if not (head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc')):
                raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

            # Whatever follows the 'pts3d+desc' prefix is the local feature dimension.
            local_feat_dim = int(output_mode[len('pts3d+desc'):])
            assert net.dec_depth > 9
            depth = net.dec_depth
            feat_dim = 256
            point_channels = 3
            enc_dim = net.enc_embed_dim
            dec_dim = net.dec_embed_dim
            return Cat_MLP_LocalFeatures_DPT_Pts3d(
                net,
                local_feat_dim=local_feat_dim,
                has_conf=has_conf,
                # One extra output channel when a confidence map is predicted.
                num_channels=point_channels + has_conf,
                feature_dim=feat_dim,
                last_dim=feat_dim // 2,
                # Hook indices: 0, half, three quarters and the full decoder depth.
                hooks_idx=[0, depth * 2 // 4, depth * 3 // 4, depth],
                dim_tokens=[enc_dim, dec_dim, dec_dim, dec_dim],
                postprocess=postprocess,
                depth_mode=net.depth_mode,
                conf_mode=net.conf_mode,
                head_type='regression',
            )

        with_conf = bool(conf_mode)
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=with_conf)
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=with_conf)

        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
|
|
|
class MASt3R:
    """Loads an ``AsymmetricMASt3R`` model and writes reconstructions to Rerun ``.rrd`` files."""

    def __init__(self, device="cuda", model_id="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"):
        """Load the pretrained model and move it to ``device``.

        Args:
            device: Device the model is moved to; also used for inference in
                ``generate_point_cloud``.
            model_id: Identifier or checkpoint path handed to
                ``AsymmetricMASt3R.from_pretrained``.
        """
        self.device = device
        self.model = AsymmetricMASt3R.from_pretrained(model_id).to(self.device)

    def create_blueprint(self, image_name_list: list[str], log_path: Path) -> "rrb.Blueprint":
        """Build the Rerun blueprint for a reconstruction.

        With more than four images only a 3D view rooted at ``log_path`` is
        shown. Otherwise the 3D view sits next to a vertical stack holding one
        2D view per image, rooted at ``{log_path}/camera_{i}/pinhole/``, with
        the 3D view given three quarters of the width.

        Args:
            image_name_list: Input images; only the count is used here.
            log_path: Entity path under which the results are logged.

        Returns:
            The blueprint, with panels collapsed.
        """
        view_3d = rrb.Spatial3DView(origin=f"{log_path}")

        if len(image_name_list) > 4:
            return rrb.Blueprint(
                rrb.Horizontal(view_3d),
                collapse_panels=True,
            )

        camera_views = [
            rrb.Spatial2DView(
                origin=f"{log_path}/camera_{i}/pinhole/",
                contents=["+ $origin/**"],
            )
            for i in range(len(image_name_list))
        ]
        return rrb.Blueprint(
            rrb.Horizontal(
                contents=[view_3d, rrb.Vertical(contents=camera_views)],
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )

    def generate_point_cloud(self, image_name_list):
        """Run reconstruction on the given image(s) and save a Rerun recording.

        Args:
            image_name_list: A single image path, or a list of image paths.

        Returns:
            POSIX path (``str``) of the ``.rrd`` file written under
            ``/tmp/gradio``.

        Raises:
            TypeError: If ``image_name_list`` is neither a ``list`` nor a ``str``.
        """
        # Validate before doing any work. TypeError is still an Exception, so
        # callers that catch the previous bare Exception keep working.
        if not isinstance(image_name_list, (list, str)):
            raise TypeError(
                f"Input must be a list of strings or a string, got: {type(image_name_list)}"
            )
        if isinstance(image_name_list, str):
            image_name_list = [image_name_list]

        uuid_str = str(uuid.uuid4())
        filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
        # The target directory is not guaranteed to exist outside a Gradio app.
        filename.parent.mkdir(parents=True, exist_ok=True)
        rr.init(f"{uuid_str}")
        log_path = Path("world")

        optimized_results: OptimizedResult = inferece_dust3r(
            image_dir_or_list=image_name_list,
            model=self.model,
            # Run inference on the device the model was actually moved to,
            # not on an independently chosen module-level default.
            device=self.device,
            batch_size=1,
        )

        blueprint: rrb.Blueprint = self.create_blueprint(image_name_list, log_path)
        rr.send_blueprint(blueprint)

        rr.set_time_sequence("sequence", 0)
        log_optimized_result(optimized_results, log_path)
        rr.save(filename.as_posix())
        return filename.as_posix()
|
if __name__ == "__main__":
    # Smoke run: reconstruct from one example image and print the path of
    # the recording returned by generate_point_cloud.
    example_images = ["examples/single_image/bench1.png"]
    reconstructor = MASt3R()
    recording_path = reconstructor.generate_point_cloud(example_images)
    print(recording_path)