diff --git a/app.py b/app.py index 07938cf89dfdb31378296518585571ddcb025c4b..ef98df51d32956d9242e8cac3837980fc45731e2 100644 --- a/app.py +++ b/app.py @@ -4,68 +4,178 @@ import torch import numpy as np import gradio as gr +import trimesh +import sys +import os +sys.path.append('vggsfm_code/') +import shutil -def parse_video(video_file): - vs = cv2.VideoCapture(video_file) - - frames = [] - while True: - (gotit, frame) = vs.read() - if frame is not None: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frames.append(frame) - if not gotit: - break +from vggsfm_code.hf_demo import demo_fn +from omegaconf import DictConfig, OmegaConf +from viz_utils.viz_fn import add_camera - return np.stack(frames) +# +from scipy.spatial.transform import Rotation +import PIL +import spaces @spaces.GPU -def cotracker_demo( +def vggsfm_demo( + input_image, input_video, - grid_size: int = 10, - tracks_leave_trace: bool = False, + query_frame_num, + max_query_pts + # grid_size: int = 10, ): - load_video = parse_video(input_video) - load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float() + cfg_file = "vggsfm_code/cfgs/demo.yaml" + cfg = OmegaConf.load(cfg_file) + + max_input_image = 20 + + target_dir = f"input_images" + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + + os.makedirs(target_dir) + target_dir_images = target_dir + "/images" + os.makedirs(target_dir_images) + + if input_image is not None: + if len(input_image)<3: + return None, "Please input at least three frames" + + input_image = sorted(input_image) + input_image = input_image[:max_input_image] + + # Copy files to the new directory + for file_name in input_image: + shutil.copy(file_name, target_dir_images) + elif input_video is not None: + vs = cv2.VideoCapture(input_video) + + fps = vs.get(cv2.CAP_PROP_FPS) + + frame_rate = 1 + frame_interval = int(fps * frame_rate) + + video_frame_num = 0 + count = 0 + + while video_frame_num<=max_input_image: + (gotit, frame) = vs.read() + count +=1 + + if count % frame_interval == 0: + cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame) + video_frame_num+=1 + if not gotit: + break + if video_frame_num<3: + return None, "Please input at least three frames" + else: + return None, "Input format incorrect" + + cfg.query_frame_num = query_frame_num + cfg.max_query_pts = max_query_pts + print(f"Files have been copied to {target_dir_images}") + cfg.SCENE_DIR = target_dir + + predictions = demo_fn(cfg) + + glbfile = vggsfm_predictions_to_glb(predictions) + + + print(input_image) + print(input_video) + return glbfile, "Success" - import time - def current_milli_time(): - return round(time.time() * 1000) - filename = str(current_milli_time()) + +def vggsfm_predictions_to_glb(predictions): + # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py + points3D = predictions["points3D"].cpu().numpy() + points3D_rgb = predictions["points3D_rgb"].cpu().numpy() + points3D_rgb = (points3D_rgb*255).astype(np.uint8) - return os.path.join( - os.path.dirname(__file__), "results", f"{filename}.mp4" - ) + extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy() + intrinsics_opencv = predictions["intrinsics_opencv"].cpu().numpy() + raw_image_paths = predictions["raw_image_paths"] + images = predictions["images"].permute(0,2,3,1).cpu().numpy() + images = (images*255).astype(np.uint8) + + glbscene = trimesh.Scene() + point_cloud = trimesh.PointCloud(points3D, colors=points3D_rgb) + glbscene.add_geometry(point_cloud) + + + camera_edge_colors = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204), + (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)] + + frame_num = len(extrinsics_opencv) + extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4)) + extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv + extrinsics_opencv_4x4[:, 3, 3] = 1 + for idx in range(frame_num): + cam_from_world = extrinsics_opencv_4x4[idx] + cam_to_world = np.linalg.inv(cam_from_world) + cur_cam_color = camera_edge_colors[idx % len(camera_edge_colors)] + cur_focal = intrinsics_opencv[idx, 0, 0] + # cur_image_path = raw_image_paths[idx] + # cur_image = np.array(PIL.Image.open(cur_image_path)) + # add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=cur_image.shape[1::-1], + # focal=None,screen_width=0.3) + add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=(1024,1024), + focal=None,screen_width=0.35) -app = gr.Interface( - title="🎨 CoTracker: It is Better to Track Together", - description="
Welcome to CoTracker! This space demonstrates point (pixel) tracking in videos. \ - Points are sampled on a regular grid and are tracked jointly.
\ -To get started, simply upload your .mp4 video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length 2-7 seconds.
\ -For more details, check out our GitHub Repo ⭐
\ -Welcome to VGGSfM!", + fn=vggsfm_demo, + inputs=[ + gr.File(file_count="multiple", label="Input Images", interactive=True), + gr.Video(label="Input video", interactive=True), + gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of query images"), + gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="Number of query points"), + ], + outputs=[gr.Model3D(label="Reconstruction"), gr.Textbox(label="Log")], + cache_examples=True, + allow_flagging=False, + ) + demo.queue(max_size=20, concurrency_count=1).launch(debug=True) + + # demo.launch(debug=True, share=True) +else: + import glob + files = glob.glob(f'vggsfm_code/examples/cake/images/*', recursive=True) + vggsfm_demo(files, None, None) + + +# demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True) diff --git a/debug_demo.py b/debug_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f01818a7a72a1a541a5c7afd16577437fb6787b4 --- /dev/null +++ b/debug_demo.py @@ -0,0 +1,31 @@ +import gradio as gr + +def greet(name, intensity): + return "Hello, " + name + "!" * int(intensity) + +demo = gr.Interface( + fn=greet, + inputs=["text", "slider"], + outputs=["text"], +) + +demo.launch(share=True) + + +import sys +import os + +sys.path.append('vggsfm_code/') + +from vggsfm_code.hf_demo import demo_fn +from omegaconf import DictConfig, OmegaConf + +cfg_file = "vggsfm_code/cfgs/demo.yaml" +cfg = OmegaConf.load(cfg_file) +cfg.SCENE_DIR = "vggsfm_code/examples/cake" + +import pdb;pdb.set_trace() + +demo_fn(cfg) + + diff --git a/glbscene.glb b/glbscene.glb new file mode 100644 index 0000000000000000000000000000000000000000..fe844b11af0e1488fb3553020084e3c1e8294cf8 Binary files /dev/null and b/glbscene.glb differ diff --git a/requirements.txt b/requirements.txt index aee576be6d2837fb01156c7f352bf0ade0eb3dc5..49e0b91cda3c22a7ac0b016a76e512ed3f917291 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ git+https://github.com/cvg/LightGlue.git#egg=LightGlue numpy==1.26.3 pycolmap==0.6.1 https://huggingface.co./facebook/VGGSfM/resolve/main/poselib-2.0.2-cp310-cp310-linux_x86_64.whl - +trimesh diff --git a/vggsfm/.gitignore b/vggsfm_code/.gitignore similarity index 100% rename from vggsfm/.gitignore rename to vggsfm_code/.gitignore diff --git a/vggsfm/CHANGELOG.txt b/vggsfm_code/CHANGELOG.txt similarity index 100% rename from vggsfm/CHANGELOG.txt rename to vggsfm_code/CHANGELOG.txt diff --git a/vggsfm/CODE_OF_CONDUCT.md b/vggsfm_code/CODE_OF_CONDUCT.md similarity index 100% rename from vggsfm/CODE_OF_CONDUCT.md rename to vggsfm_code/CODE_OF_CONDUCT.md diff --git a/vggsfm/CONTRIBUTING.md b/vggsfm_code/CONTRIBUTING.md similarity index 100% rename from vggsfm/CONTRIBUTING.md rename to vggsfm_code/CONTRIBUTING.md diff --git a/vggsfm/LICENSE.txt b/vggsfm_code/LICENSE.txt similarity index 100% rename from vggsfm/LICENSE.txt rename to vggsfm_code/LICENSE.txt diff --git a/vggsfm/README.md b/vggsfm_code/README.md similarity index 100% rename from vggsfm/README.md rename to vggsfm_code/README.md diff --git a/vggsfm/assets/ui.png b/vggsfm_code/assets/ui.png similarity index 100% rename from vggsfm/assets/ui.png rename to vggsfm_code/assets/ui.png diff --git a/vggsfm/cfgs/demo.yaml b/vggsfm_code/cfgs/demo.yaml similarity index 100% rename from vggsfm/cfgs/demo.yaml rename to vggsfm_code/cfgs/demo.yaml diff --git a/vggsfm/demo.py b/vggsfm_code/demo.py similarity index 100% rename from vggsfm/demo.py rename to vggsfm_code/demo.py diff --git a/vggsfm/examples/apple/images/frame000007.jpg b/vggsfm_code/examples/apple/images/frame000007.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000007.jpg rename to vggsfm_code/examples/apple/images/frame000007.jpg diff --git a/vggsfm/examples/apple/images/frame000012.jpg b/vggsfm_code/examples/apple/images/frame000012.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000012.jpg rename to vggsfm_code/examples/apple/images/frame000012.jpg diff --git a/vggsfm/examples/apple/images/frame000017.jpg b/vggsfm_code/examples/apple/images/frame000017.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000017.jpg rename to vggsfm_code/examples/apple/images/frame000017.jpg diff --git a/vggsfm/examples/apple/images/frame000019.jpg b/vggsfm_code/examples/apple/images/frame000019.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000019.jpg rename to vggsfm_code/examples/apple/images/frame000019.jpg diff --git a/vggsfm/examples/apple/images/frame000024.jpg b/vggsfm_code/examples/apple/images/frame000024.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000024.jpg rename to vggsfm_code/examples/apple/images/frame000024.jpg diff --git a/vggsfm/examples/apple/images/frame000025.jpg b/vggsfm_code/examples/apple/images/frame000025.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000025.jpg rename to vggsfm_code/examples/apple/images/frame000025.jpg diff --git a/vggsfm/examples/apple/images/frame000043.jpg b/vggsfm_code/examples/apple/images/frame000043.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000043.jpg rename to vggsfm_code/examples/apple/images/frame000043.jpg diff --git a/vggsfm/examples/apple/images/frame000052.jpg b/vggsfm_code/examples/apple/images/frame000052.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000052.jpg rename to vggsfm_code/examples/apple/images/frame000052.jpg diff --git a/vggsfm/examples/apple/images/frame000070.jpg b/vggsfm_code/examples/apple/images/frame000070.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000070.jpg rename to vggsfm_code/examples/apple/images/frame000070.jpg diff --git a/vggsfm/examples/apple/images/frame000077.jpg b/vggsfm_code/examples/apple/images/frame000077.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000077.jpg rename to vggsfm_code/examples/apple/images/frame000077.jpg diff --git a/vggsfm/examples/apple/images/frame000085.jpg b/vggsfm_code/examples/apple/images/frame000085.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000085.jpg rename to vggsfm_code/examples/apple/images/frame000085.jpg diff --git a/vggsfm/examples/apple/images/frame000096.jpg b/vggsfm_code/examples/apple/images/frame000096.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000096.jpg rename to vggsfm_code/examples/apple/images/frame000096.jpg diff --git a/vggsfm/examples/apple/images/frame000128.jpg b/vggsfm_code/examples/apple/images/frame000128.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000128.jpg rename to vggsfm_code/examples/apple/images/frame000128.jpg diff --git a/vggsfm/examples/apple/images/frame000145.jpg b/vggsfm_code/examples/apple/images/frame000145.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000145.jpg rename to vggsfm_code/examples/apple/images/frame000145.jpg diff --git a/vggsfm/examples/apple/images/frame000160.jpg b/vggsfm_code/examples/apple/images/frame000160.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000160.jpg rename to vggsfm_code/examples/apple/images/frame000160.jpg diff --git a/vggsfm/examples/apple/images/frame000162.jpg b/vggsfm_code/examples/apple/images/frame000162.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000162.jpg rename to vggsfm_code/examples/apple/images/frame000162.jpg diff --git a/vggsfm/examples/apple/images/frame000168.jpg b/vggsfm_code/examples/apple/images/frame000168.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000168.jpg rename to vggsfm_code/examples/apple/images/frame000168.jpg diff --git a/vggsfm/examples/apple/images/frame000172.jpg b/vggsfm_code/examples/apple/images/frame000172.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000172.jpg rename to vggsfm_code/examples/apple/images/frame000172.jpg diff --git a/vggsfm/examples/apple/images/frame000191.jpg b/vggsfm_code/examples/apple/images/frame000191.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000191.jpg rename to vggsfm_code/examples/apple/images/frame000191.jpg diff --git a/vggsfm/examples/apple/images/frame000200.jpg b/vggsfm_code/examples/apple/images/frame000200.jpg similarity index 100% rename from vggsfm/examples/apple/images/frame000200.jpg rename to vggsfm_code/examples/apple/images/frame000200.jpg diff --git a/vggsfm/examples/british_museum/images/29057984_287139632.jpg b/vggsfm_code/examples/british_museum/images/29057984_287139632.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/29057984_287139632.jpg rename to vggsfm_code/examples/british_museum/images/29057984_287139632.jpg diff --git a/vggsfm/examples/british_museum/images/32630292_7166579210.jpg b/vggsfm_code/examples/british_museum/images/32630292_7166579210.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/32630292_7166579210.jpg rename to vggsfm_code/examples/british_museum/images/32630292_7166579210.jpg diff --git a/vggsfm/examples/british_museum/images/45839934_4117745134.jpg b/vggsfm_code/examples/british_museum/images/45839934_4117745134.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/45839934_4117745134.jpg rename to vggsfm_code/examples/british_museum/images/45839934_4117745134.jpg diff --git a/vggsfm/examples/british_museum/images/51004432_567773767.jpg b/vggsfm_code/examples/british_museum/images/51004432_567773767.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/51004432_567773767.jpg rename to vggsfm_code/examples/british_museum/images/51004432_567773767.jpg diff --git a/vggsfm/examples/british_museum/images/62620282_3728576515.jpg b/vggsfm_code/examples/british_museum/images/62620282_3728576515.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/62620282_3728576515.jpg rename to vggsfm_code/examples/british_museum/images/62620282_3728576515.jpg diff --git a/vggsfm/examples/british_museum/images/71931631_7212707886.jpg b/vggsfm_code/examples/british_museum/images/71931631_7212707886.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/71931631_7212707886.jpg rename to vggsfm_code/examples/british_museum/images/71931631_7212707886.jpg diff --git a/vggsfm/examples/british_museum/images/78600497_407639599.jpg b/vggsfm_code/examples/british_museum/images/78600497_407639599.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/78600497_407639599.jpg rename to vggsfm_code/examples/british_museum/images/78600497_407639599.jpg diff --git a/vggsfm/examples/british_museum/images/80340357_5029510336.jpg b/vggsfm_code/examples/british_museum/images/80340357_5029510336.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/80340357_5029510336.jpg rename to vggsfm_code/examples/british_museum/images/80340357_5029510336.jpg diff --git a/vggsfm/examples/british_museum/images/81272348_2712949069.jpg b/vggsfm_code/examples/british_museum/images/81272348_2712949069.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/81272348_2712949069.jpg rename to vggsfm_code/examples/british_museum/images/81272348_2712949069.jpg diff --git a/vggsfm/examples/british_museum/images/93266801_2335569192.jpg b/vggsfm_code/examples/british_museum/images/93266801_2335569192.jpg similarity index 100% rename from vggsfm/examples/british_museum/images/93266801_2335569192.jpg rename to vggsfm_code/examples/british_museum/images/93266801_2335569192.jpg diff --git a/vggsfm/examples/cake/images/frame000020.jpg b/vggsfm_code/examples/cake/images/frame000020.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000020.jpg rename to vggsfm_code/examples/cake/images/frame000020.jpg diff --git a/vggsfm/examples/cake/images/frame000069.jpg b/vggsfm_code/examples/cake/images/frame000069.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000069.jpg rename to vggsfm_code/examples/cake/images/frame000069.jpg diff --git a/vggsfm/examples/cake/images/frame000096.jpg b/vggsfm_code/examples/cake/images/frame000096.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000096.jpg rename to vggsfm_code/examples/cake/images/frame000096.jpg diff --git a/vggsfm/examples/cake/images/frame000112.jpg b/vggsfm_code/examples/cake/images/frame000112.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000112.jpg rename to vggsfm_code/examples/cake/images/frame000112.jpg diff --git a/vggsfm/examples/cake/images/frame000146.jpg b/vggsfm_code/examples/cake/images/frame000146.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000146.jpg rename to vggsfm_code/examples/cake/images/frame000146.jpg diff --git a/vggsfm/examples/cake/images/frame000149.jpg b/vggsfm_code/examples/cake/images/frame000149.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000149.jpg rename to vggsfm_code/examples/cake/images/frame000149.jpg diff --git a/vggsfm/examples/cake/images/frame000166.jpg b/vggsfm_code/examples/cake/images/frame000166.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000166.jpg rename to vggsfm_code/examples/cake/images/frame000166.jpg diff --git a/vggsfm/examples/cake/images/frame000169.jpg b/vggsfm_code/examples/cake/images/frame000169.jpg similarity index 100% rename from vggsfm/examples/cake/images/frame000169.jpg rename to vggsfm_code/examples/cake/images/frame000169.jpg diff --git a/vggsfm_code/hf_demo.py b/vggsfm_code/hf_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..a36d9cec308a0083c6afcb11b34bf3b159e868ee --- /dev/null +++ b/vggsfm_code/hf_demo.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.cuda.amp import autocast +import hydra + +from omegaconf import DictConfig, OmegaConf +from hydra.utils import instantiate + +from lightglue import LightGlue, SuperPoint, SIFT, ALIKED + +import pycolmap + +from visdom import Visdom + + +from vggsfm.datasets.demo_loader import DemoLoader +from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras + +try: + import poselib + from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib + + print("Poselib is available") +except: + print("Poselib is not installed. Please disable use_poselib") + +from vggsfm.utils.utils import ( + set_seed_and_print, + farthest_point_sampling, + calculate_index_mappings, + switch_tensor_order, +) + + +def demo_fn(cfg): + OmegaConf.set_struct(cfg, False) + + # Print configuration + print("Model Config:", OmegaConf.to_yaml(cfg)) + + torch.backends.cudnn.enabled = False + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + # Set seed + seed_all_random_engines(cfg.seed) + + # Model instantiation + model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = model.to(device) + + # Prepare test dataset + test_dataset = DemoLoader( + SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg + ) + + # if cfg.resume_ckpt: + _VGGSFM_URL = "https://huggingface.co./facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin" + + # Reload model + checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL) + model.load_state_dict(checkpoint, strict=True) + print(f"Successfully resumed from {_VGGSFM_URL}") + + + sequence_list = test_dataset.sequence_list + + for seq_name in sequence_list: + print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50) + + # Load the data + batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True) + + # Send to GPU + images = batch["image"].to(device) + crop_params = batch["crop_params"].to(device) + + + # Unsqueeze to have batch size = 1 + images = images.unsqueeze(0) + crop_params = crop_params.unsqueeze(0) + + batch_size = len(images) + + with torch.no_grad(): + # Run the model + assert cfg.mixed_precision in ("None", "bf16", "fp16") + if cfg.mixed_precision == "None": + dtype = torch.float32 + elif cfg.mixed_precision == "bf16": + dtype = torch.bfloat16 + elif cfg.mixed_precision == "fp16": + dtype = torch.float16 + else: + raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now") + + predictions = run_one_scene( + model, + images, + crop_params=crop_params, + query_frame_num=cfg.query_frame_num, + image_paths=image_paths, + dtype=dtype, + cfg=cfg, + ) + + pred_cameras_PT3D = predictions["pred_cameras_PT3D"] + + return predictions + + +def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None): + """ + images have been normalized to the range [0, 1] instead of [0, 255] + """ + batch_num, frame_num, image_dim, height, width = images.shape + device = images.device + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + + predictions = {} + extra_dict = {} + + camera_predictor = model.camera_predictor + track_predictor = model.track_predictor + triangulator = model.triangulator + + # Find the query frames + # First use DINO to find the most common frame among all the input frames + # i.e., the one has highest (average) cosine similarity to all others + # Then use farthest_point_sampling to find the next ones + # The number of query frames is determined by query_frame_num + + with autocast(dtype=dtype): + query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num) + + raw_image_paths = image_paths + image_paths = [os.path.basename(imgpath) for imgpath in image_paths] + + if cfg.center_order: + # The code below switchs the first frame (frame 0) to the most common frame + center_frame_index = query_frame_indexes[0] + center_order = calculate_index_mappings(center_frame_index, frame_num, device=device) + + images, crop_params = switch_tensor_order([images, crop_params], center_order, dim=1) + reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0] + + image_paths = [image_paths[i] for i in center_order.cpu().numpy().tolist()] + + # Also update query_frame_indexes: + query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes] + query_frame_indexes[0] = 0 + + # only pick query_frame_num + query_frame_indexes = query_frame_indexes[:query_frame_num] + + # Prepare image feature maps for tracker + fmaps_for_tracker = track_predictor.process_images_to_fmaps(images) + + # Predict tracks + with autocast(dtype=dtype): + pred_track, pred_vis, pred_score = predict_tracks( + cfg.query_method, + cfg.max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + query_frame_indexes, + frame_num, + device, + cfg, + ) + + if cfg.comple_nonvis: + pred_track, pred_vis, pred_score = comple_nonvis_frames( + track_predictor, + images, + fmaps_for_tracker, + frame_num, + device, + pred_track, + pred_vis, + pred_score, + 200, + cfg=cfg, + ) + + torch.cuda.empty_cache() + + # If necessary, force all the predictions at the padding areas as non-visible + if crop_params is not None: + boundaries = crop_params[:, :, -4:-2].abs().to(device) + boundaries = torch.cat([boundaries, reshaped_image.shape[-1] - boundaries], dim=-1) + hvis = torch.logical_and( + pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4] + ) + wvis = torch.logical_and( + pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3] + ) + force_vis = torch.logical_and(hvis, wvis) + pred_vis = pred_vis * force_vis.float() + + # TODO: plot 2D matches + if cfg.use_poselib: + estimate_preliminary_cameras_fn = estimate_preliminary_cameras_poselib + else: + estimate_preliminary_cameras_fn = estimate_preliminary_cameras + + # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches + # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC + # All the operations are batched and differentiable (if necessary) + # except when you enable use_poselib to save GPU memory + _, preliminary_dict = estimate_preliminary_cameras_fn( + pred_track, + pred_vis, + width, + height, + tracks_score=pred_score, + max_error=cfg.fmat_thres, + loopresidual=True, + # max_ransac_iters=cfg.max_ransac_iters, + ) + + pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num) + + pred_cameras = pose_predictions["pred_cameras"] + + # Conduct Triangulation and Bundle Adjustment + ( + BA_cameras_PT3D, + extrinsics_opencv, + intrinsics_opencv, + points3D, + points3D_rgb, + reconstruction, + valid_frame_mask, + ) = triangulator( + pred_cameras, + pred_track, + pred_vis, + images, + preliminary_dict, + image_paths=image_paths, + crop_params=crop_params, + pred_score=pred_score, + fmat_thres=cfg.fmat_thres, + BA_iters=cfg.BA_iters, + max_reproj_error = cfg.max_reproj_error, + init_max_reproj_error=cfg.init_max_reproj_error, + cfg=cfg, + ) + + if cfg.center_order: + # NOTE we changed the image order previously, now we need to switch it back + BA_cameras_PT3D = BA_cameras_PT3D[center_order] + extrinsics_opencv = extrinsics_opencv[center_order] + intrinsics_opencv = intrinsics_opencv[center_order] + + if cfg.filter_invalid_frame: + raw_image_paths = np.array(raw_image_paths)[valid_frame_mask.cpu().numpy().tolist()].tolist() + images = images[0][valid_frame_mask] + + predictions["pred_cameras_PT3D"] = BA_cameras_PT3D + predictions["extrinsics_opencv"] = extrinsics_opencv + predictions["intrinsics_opencv"] = intrinsics_opencv + predictions["points3D"] = points3D + predictions["points3D_rgb"] = points3D_rgb + predictions["reconstruction"] = reconstruction + predictions["images"] = images + predictions["raw_image_paths"] = raw_image_paths + return predictions + + +def predict_tracks( + query_method, + max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + query_frame_indexes, + frame_num, + device, + cfg=None, +): + pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for query_index in query_frame_indexes: + print(f"Predicting tracks with query_index = {query_index}") + + # Find query_points at the query frame + query_points = get_query_points(images[:, query_index], query_method, max_query_pts) + + # Switch so that query_index frame stays at the first frame + # This largely simplifies the code structure of tracker + new_order = calculate_index_mappings(query_index, frame_num, device=device) + images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], new_order) + + # Feed into track predictor + fine_pred_track, _, pred_vis, pred_score = track_predictor(images_feed, query_points, fmaps=fmaps_feed) + + # Switch back the predictions + fine_pred_track, pred_vis, pred_score = switch_tensor_order([fine_pred_track, pred_vis, pred_score], new_order) + + # Append predictions for different queries + pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + pred_track = torch.cat(pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + pred_score = torch.cat(pred_score_list, dim=2) + + return pred_track, pred_vis, pred_score + + +def comple_nonvis_frames( + track_predictor, + images, + fmaps_for_tracker, + frame_num, + device, + pred_track, + pred_vis, + pred_score, + min_vis=500, + cfg=None, +): + # if a frame has too few visible inlier, use it as a query + non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist() + last_query = -1 + while len(non_vis_frames) > 0: + print("Processing non visible frames") + print(non_vis_frames) + if non_vis_frames[0] == last_query: + print("The non vis frame still does not has enough 2D matches") + pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks( + "sp+sift+aliked", + cfg.max_query_pts // 2, + track_predictor, + images, + fmaps_for_tracker, + non_vis_frames, + frame_num, + device, + cfg, + ) + # concat predictions + pred_track = torch.cat([pred_track, pred_track_comple], dim=2) + pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2) + pred_score = torch.cat([pred_score, pred_score_comple], dim=2) + break + + non_vis_query_list = [non_vis_frames[0]] + last_query = non_vis_frames[0] + pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks( + cfg.query_method, + cfg.max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + non_vis_query_list, + frame_num, + device, + cfg, + ) + + # concat predictions + pred_track = torch.cat([pred_track, pred_track_comple], dim=2) + pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2) + pred_score = torch.cat([pred_score, pred_score_comple], dim=2) + non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist() + return pred_track, pred_vis, pred_score + + +def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336): + # Downsample image to image_size x image_size + # because we found it is unnecessary to use high resolution + rgbs = F.interpolate(reshaped_image, (image_size, image_size), mode="bilinear", align_corners=True) + rgbs = camera_predictor._resnet_normalize_image(rgbs) + + # Get the image features (patch level) + frame_feat = camera_predictor.backbone(rgbs, is_training=True) + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similiarty matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + similarity_matrix = similarity_matrix.mean(dim=0) + distance_matrix = 100 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(-100) + + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling + # Starting from the most_common_frame_index, + # try to find the farthest frame, + # then the farthest to the last found frame + # (frames are not allowed to be found twice) + fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) + + return fps_idx + + +def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005): + # Run superpoint and sift on the target frame + # Feel free to modify for your own + + methods = query_method.split("+") + pred_points = [] + + for method in methods: + if "sp" in method: + extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval() + elif "sift" in method: + extractor = SIFT(max_num_keypoints=max_query_num).cuda().eval() + elif "aliked" in method: + extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval() + else: + raise NotImplementedError(f"query method {method} is not supprted now") + + query_points = extractor.extract(query_image)["keypoints"] + pred_points.append(query_points) + + query_points = torch.cat(pred_points, dim=1) + + if query_points.shape[1] > max_query_num: + random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num] + query_points = query_points[:, random_point_indices, :] + + return query_points + + +def seed_all_random_engines(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) diff --git a/vggsfm/install.sh b/vggsfm_code/install.sh similarity index 100% rename from vggsfm/install.sh rename to vggsfm_code/install.sh diff --git a/vggsfm/minipytorch3d/__init__.py b/vggsfm_code/minipytorch3d/__init__.py similarity index 100% rename from vggsfm/minipytorch3d/__init__.py rename to vggsfm_code/minipytorch3d/__init__.py diff --git a/vggsfm/minipytorch3d/cameras.py b/vggsfm_code/minipytorch3d/cameras.py similarity index 100% rename from vggsfm/minipytorch3d/cameras.py rename to vggsfm_code/minipytorch3d/cameras.py diff --git a/vggsfm/minipytorch3d/device_utils.py b/vggsfm_code/minipytorch3d/device_utils.py similarity index 100% rename from vggsfm/minipytorch3d/device_utils.py rename to vggsfm_code/minipytorch3d/device_utils.py diff --git a/vggsfm/minipytorch3d/harmonic_embedding.py b/vggsfm_code/minipytorch3d/harmonic_embedding.py similarity index 100% rename from vggsfm/minipytorch3d/harmonic_embedding.py rename to vggsfm_code/minipytorch3d/harmonic_embedding.py diff --git a/vggsfm/minipytorch3d/renderer_utils.py b/vggsfm_code/minipytorch3d/renderer_utils.py similarity index 100% rename from vggsfm/minipytorch3d/renderer_utils.py rename to vggsfm_code/minipytorch3d/renderer_utils.py diff --git a/vggsfm/minipytorch3d/rotation_conversions.py b/vggsfm_code/minipytorch3d/rotation_conversions.py similarity index 100% rename from vggsfm/minipytorch3d/rotation_conversions.py rename to vggsfm_code/minipytorch3d/rotation_conversions.py diff --git a/vggsfm/minipytorch3d/transform3d.py b/vggsfm_code/minipytorch3d/transform3d.py similarity index 100% rename from vggsfm/minipytorch3d/transform3d.py rename to vggsfm_code/minipytorch3d/transform3d.py diff --git a/vggsfm/vggsfm/datasets/camera_transform.py b/vggsfm_code/vggsfm/datasets/camera_transform.py similarity index 100% rename from vggsfm/vggsfm/datasets/camera_transform.py rename to vggsfm_code/vggsfm/datasets/camera_transform.py diff --git a/vggsfm/vggsfm/datasets/demo_loader.py b/vggsfm_code/vggsfm/datasets/demo_loader.py similarity index 100% rename from vggsfm/vggsfm/datasets/demo_loader.py rename to vggsfm_code/vggsfm/datasets/demo_loader.py diff --git a/vggsfm/vggsfm/datasets/imc.py b/vggsfm_code/vggsfm/datasets/imc.py similarity index 100% rename from vggsfm/vggsfm/datasets/imc.py rename to vggsfm_code/vggsfm/datasets/imc.py diff --git a/vggsfm/vggsfm/datasets/imc_helper.py b/vggsfm_code/vggsfm/datasets/imc_helper.py similarity index 100% rename from vggsfm/vggsfm/datasets/imc_helper.py rename to vggsfm_code/vggsfm/datasets/imc_helper.py diff --git a/vggsfm/vggsfm/models/__init__.py b/vggsfm_code/vggsfm/models/__init__.py similarity index 100% rename from vggsfm/vggsfm/models/__init__.py rename to vggsfm_code/vggsfm/models/__init__.py diff --git a/vggsfm/vggsfm/models/camera_predictor.py b/vggsfm_code/vggsfm/models/camera_predictor.py similarity index 100% rename from vggsfm/vggsfm/models/camera_predictor.py rename to vggsfm_code/vggsfm/models/camera_predictor.py diff --git a/vggsfm/vggsfm/models/modules.py b/vggsfm_code/vggsfm/models/modules.py similarity index 100% rename from vggsfm/vggsfm/models/modules.py rename to vggsfm_code/vggsfm/models/modules.py diff --git a/vggsfm/vggsfm/models/track_modules/__init__.py b/vggsfm_code/vggsfm/models/track_modules/__init__.py similarity index 100% rename from vggsfm/vggsfm/models/track_modules/__init__.py rename to vggsfm_code/vggsfm/models/track_modules/__init__.py diff --git a/vggsfm/vggsfm/models/track_modules/base_track_predictor.py b/vggsfm_code/vggsfm/models/track_modules/base_track_predictor.py similarity index 100% rename from vggsfm/vggsfm/models/track_modules/base_track_predictor.py rename to vggsfm_code/vggsfm/models/track_modules/base_track_predictor.py diff --git a/vggsfm/vggsfm/models/track_modules/blocks.py b/vggsfm_code/vggsfm/models/track_modules/blocks.py similarity index 100% rename from vggsfm/vggsfm/models/track_modules/blocks.py rename to vggsfm_code/vggsfm/models/track_modules/blocks.py diff --git a/vggsfm/vggsfm/models/track_modules/refine_track.py b/vggsfm_code/vggsfm/models/track_modules/refine_track.py similarity index 100% rename from vggsfm/vggsfm/models/track_modules/refine_track.py rename to vggsfm_code/vggsfm/models/track_modules/refine_track.py diff --git a/vggsfm/vggsfm/models/track_predictor.py b/vggsfm_code/vggsfm/models/track_predictor.py similarity index 100% rename from vggsfm/vggsfm/models/track_predictor.py rename to vggsfm_code/vggsfm/models/track_predictor.py diff --git a/vggsfm/vggsfm/models/triangulator.py b/vggsfm_code/vggsfm/models/triangulator.py similarity index 100% rename from vggsfm/vggsfm/models/triangulator.py rename to vggsfm_code/vggsfm/models/triangulator.py diff --git a/vggsfm/vggsfm/models/utils.py b/vggsfm_code/vggsfm/models/utils.py similarity index 100% rename from vggsfm/vggsfm/models/utils.py rename to vggsfm_code/vggsfm/models/utils.py diff --git a/vggsfm/vggsfm/models/vggsfm.py b/vggsfm_code/vggsfm/models/vggsfm.py similarity index 100% rename from vggsfm/vggsfm/models/vggsfm.py rename to vggsfm_code/vggsfm/models/vggsfm.py diff --git a/vggsfm/vggsfm/two_view_geo/essential.py b/vggsfm_code/vggsfm/two_view_geo/essential.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/essential.py rename to vggsfm_code/vggsfm/two_view_geo/essential.py diff --git a/vggsfm/vggsfm/two_view_geo/estimate_preliminary.py b/vggsfm_code/vggsfm/two_view_geo/estimate_preliminary.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/estimate_preliminary.py rename to vggsfm_code/vggsfm/two_view_geo/estimate_preliminary.py diff --git a/vggsfm/vggsfm/two_view_geo/fundamental.py b/vggsfm_code/vggsfm/two_view_geo/fundamental.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/fundamental.py rename to vggsfm_code/vggsfm/two_view_geo/fundamental.py diff --git a/vggsfm/vggsfm/two_view_geo/homography.py b/vggsfm_code/vggsfm/two_view_geo/homography.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/homography.py rename to vggsfm_code/vggsfm/two_view_geo/homography.py diff --git a/vggsfm/vggsfm/two_view_geo/perspective_n_points.py b/vggsfm_code/vggsfm/two_view_geo/perspective_n_points.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/perspective_n_points.py rename to vggsfm_code/vggsfm/two_view_geo/perspective_n_points.py diff --git a/vggsfm/vggsfm/two_view_geo/pnp.py b/vggsfm_code/vggsfm/two_view_geo/pnp.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/pnp.py rename to vggsfm_code/vggsfm/two_view_geo/pnp.py diff --git a/vggsfm/vggsfm/two_view_geo/utils.py b/vggsfm_code/vggsfm/two_view_geo/utils.py similarity index 100% rename from vggsfm/vggsfm/two_view_geo/utils.py rename to vggsfm_code/vggsfm/two_view_geo/utils.py diff --git a/vggsfm/vggsfm/utils/metric.py b/vggsfm_code/vggsfm/utils/metric.py similarity index 100% rename from vggsfm/vggsfm/utils/metric.py rename to vggsfm_code/vggsfm/utils/metric.py diff --git a/vggsfm/vggsfm/utils/tensor_to_pycolmap.py b/vggsfm_code/vggsfm/utils/tensor_to_pycolmap.py similarity index 100% rename from vggsfm/vggsfm/utils/tensor_to_pycolmap.py rename to vggsfm_code/vggsfm/utils/tensor_to_pycolmap.py diff --git a/vggsfm/vggsfm/utils/triangulation.py b/vggsfm_code/vggsfm/utils/triangulation.py similarity index 100% rename from vggsfm/vggsfm/utils/triangulation.py rename to vggsfm_code/vggsfm/utils/triangulation.py diff --git a/vggsfm/vggsfm/utils/triangulation_helpers.py b/vggsfm_code/vggsfm/utils/triangulation_helpers.py similarity index 100% rename from vggsfm/vggsfm/utils/triangulation_helpers.py rename to vggsfm_code/vggsfm/utils/triangulation_helpers.py diff --git a/vggsfm/vggsfm/utils/utils.py b/vggsfm_code/vggsfm/utils/utils.py similarity index 100% rename from vggsfm/vggsfm/utils/utils.py rename to vggsfm_code/vggsfm/utils/utils.py diff --git a/viz_utils/__pycache__/viz_fn.cpython-310.pyc b/viz_utils/__pycache__/viz_fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275454792daea52b84a1b1cee8e2637219649c88 Binary files /dev/null and b/viz_utils/__pycache__/viz_fn.cpython-310.pyc differ diff --git a/viz_utils/viz_fn.py b/viz_utils/viz_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..5143a6987cdbdfa459c5b3414c5487cacb4a17b5 --- /dev/null +++ b/viz_utils/viz_fn.py @@ -0,0 +1,148 @@ +import os +import cv2 +import torch +import numpy as np +import gradio as gr + +import trimesh +import sys +import os + +# sys.path.append('vggsfm_code/') +import shutil +from datetime import datetime + +# from vggsfm_code.hf_demo import demo_fn +# from omegaconf import DictConfig, OmegaConf +# from viz_utils.viz_fn import add_camera + +from scipy.spatial.transform import Rotation +import PIL + + +def add_camera(scene, pose_c2w, edge_color, image=None, + focal=None, imsize=None, + screen_width=0.03, marker=None): + # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py + + opengl_mat = np.array([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + + if image is not None: + image = np.asarray(image) + H, W, THREE = image.shape + assert THREE == 3 + if image.dtype != np.uint8: + image = np.uint8(255*image) + elif imsize is not None: + W, H = imsize + elif focal is not None: + H = W = focal / 1.1 + else: + H = W = 1 + + + if isinstance(focal, np.ndarray): + focal = focal[0] + if not focal: + focal = min(H,W) * 1.1 # default value + + # create fake camera + height = max( screen_width/10, focal * screen_width / H ) + width = screen_width * 0.5**0.5 + rot45 = np.eye(4) + rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix() + rot45[2, 3] = -height # set the tip of the cone = optical center + aspect_ratio = np.eye(4) + aspect_ratio[0, 0] = W/H + transform = pose_c2w @ opengl_mat @ aspect_ratio @ rot45 + cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform) + + # this is the image + if image is not None: + vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]]) + faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]]) + img = trimesh.Trimesh(vertices=vertices, faces=faces) + uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]]) + img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image)) + scene.add_geometry(img) + + # this is the camera mesh + rot2 = np.eye(4) + rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix() + vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)] + vertices = geotrf(transform, vertices) + faces = [] + for face in cam.faces: + if 0 in face: + continue + a, b, c = face + a2, b2, c2 = face + len(cam.vertices) + a3, b3, c3 = face + 2*len(cam.vertices) + + # add 3 pseudo-edges + faces.append((a, b, b2)) + faces.append((a, a2, c)) + faces.append((c2, b, c)) + + faces.append((a, b, b3)) + faces.append((a, a3, c)) + faces.append((c3, b, c)) + + # no culling + faces += [(c, b, a) for a, b, c in faces] + + cam = trimesh.Trimesh(vertices=vertices, faces=faces) + cam.visual.face_colors[:, :3] = edge_color + scene.add_geometry(cam) + + if marker == 'o': + marker = trimesh.creation.icosphere(3, radius=screen_width/4) + marker.vertices += pose_c2w[:3,3] + marker.visual.face_colors[:,:3] = edge_color + scene.add_geometry(marker) + +def geotrf(Trf, pts, ncol=None, norm=False): + # learned from https://github.com/naver/dust3r/blob/main/dust3r/ + + assert Trf.ndim >= 2 + pts = np.asarray(pts) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + +