"""Visualization utilities for 3D reconstruction results using Viser. Provides tools to visualize predicted camera poses, 3D point clouds, and confidence thresholding through an interactive web interface. """ import time from pathlib import Path from typing import List, Optional import numpy as np import tyro from tqdm.auto import tqdm import cv2 import viser import viser.transforms as tf import glob import os from scipy.spatial.transform import Rotation as R # from camera import closed_form_inverse_se3 import torch import threading def viser_wrapper( pred_dict: dict, port: int = None, init_conf_threshold: float = 3.0, ) -> None: """Visualize Args: pred_dict: Dictionary containing predictions port: Optional port number for the viser server. If None, a random port will be used. """ print(f"Starting viser server on port {port}") # Debug print server = viser.ViserServer(host="0.0.0.0", port=port) # server = viser.ViserServer(port=port) server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") # Unpack and preprocess inputs images = pred_dict["images"] world_points = pred_dict["pred_world_points"] conf = pred_dict["pred_world_points_conf"] extrinsics = pred_dict["last_pred_extrinsic"] # Handle batch dimension if present if len(images.shape) > 4: images = images[0] world_points = world_points[0] conf = conf[0] extrinsics = extrinsics[0] colors = images.transpose(0, 2, 3, 1) # Convert to (B, H, W, C) # Reshape for visualization S, H, W, _ = world_points.shape colors = (colors.reshape(-1, 3) * 255).astype(np.uint8) # Convert to 0-255 range conf = conf.reshape(-1) world_points = world_points.reshape(-1, 3) # Calculate camera poses in world coordinates cam_to_world = closed_form_inverse_se3(extrinsics) extrinsics = cam_to_world[:, :3, :] # Center scene for better visualization scene_center = np.mean(world_points, axis=0) world_points -= scene_center extrinsics[..., -1] -= scene_center # set points3d as world_points points = world_points # frame_mask frame_indices = np.arange(S) frame_indices = frame_indices[:, None, None] # Shape: (S, 1, 1, 1) frame_indices = np.tile(frame_indices, (1, H, W)) # Shape: (S, H, W, 3) frame_indices = frame_indices.reshape(-1) ############################################################ ############################################################ gui_points_conf = server.gui.add_slider( "Confidence Thres", min=0.1, max=20, step=0.05, initial_value=init_conf_threshold, ) gui_point_size = server.gui.add_slider( "Point size", min=0.00001, max=0.01, step=0.0001, initial_value=0.00001 ) # Change from "Frame Selector" to more descriptive name gui_frame_selector = server.gui.add_dropdown( "Filter by Frame", # More action-oriented name options=["All"] + [str(i) for i in range(S)], initial_value="All", ) # Initial mask shows all points passing confidence threshold init_conf_mask = conf > init_conf_threshold point_cloud = server.scene.add_point_cloud( name="viser_pcd", points=points[init_conf_mask], colors=colors[init_conf_mask], point_size=gui_point_size.value, point_shape="circle", ) frames: List[viser.FrameHandle] = [] def visualize_frames(extrinsics: np.ndarray, intrinsics: np.ndarray, images: np.ndarray) -> None: """Send all COLMAP elements to viser for visualization. This could be optimized a ton!""" extrinsics = np.copy(extrinsics) # Remove existing image frames. for frame in frames: frame.remove() frames.clear() def attach_callback( frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle ) -> None: @frustum.on_click def _(_) -> None: for client in server.get_clients().values(): client.camera.wxyz = frame.wxyz client.camera.position = frame.position img_ids = sorted(range(S)) for img_id in tqdm(img_ids): cam_to_world = extrinsics[img_id] T_world_camera = tf.SE3.from_matrix(cam_to_world) ratio = 1 frame = server.scene.add_frame( f"frame_{img_id}", wxyz=T_world_camera.rotation().wxyz, position=T_world_camera.translation(), axes_length=0.05/ratio, axes_radius=0.002/ratio, origin_radius = 0.002/ratio ) frames.append(frame) img = images[img_id] img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) # import pdb;pdb.set_trace() H, W = img.shape[:2] # fy = intrinsics[img_id, 1, 1] * H fy = 1.1 * H image = img # image = image[::downsample_factor, ::downsample_factor] frustum = server.scene.add_camera_frustum( f"frame_{img_id}/frustum", fov=2 * np.arctan2(H / 2, fy), aspect=W / H, scale=0.05/ratio, image=image, line_width=1.0, # line_thickness=0.01, ) attach_callback(frustum, frame) @gui_points_conf.on_update def _(_) -> None: conf_mask = conf > gui_points_conf.value frame_mask = np.ones_like(conf_mask) # Default to all frames if gui_frame_selector.value != "All": selected_idx = int(gui_frame_selector.value) frame_mask = (frame_indices == selected_idx) combined_mask = conf_mask & frame_mask point_cloud.points = points[combined_mask] point_cloud.colors = colors[combined_mask] @gui_point_size.on_update def _(_) -> None: point_cloud.point_size = gui_point_size.value @gui_frame_selector.on_update def _(_) -> None: """Update points based on frame selection.""" conf_mask = conf > gui_points_conf.value if gui_frame_selector.value == "All": # Show all points passing confidence threshold point_cloud.points = points[conf_mask] point_cloud.colors = colors[conf_mask] else: # Show only selected frame's points selected_idx = int(gui_frame_selector.value) frame_mask = (frame_indices == selected_idx) combined_mask = conf_mask & frame_mask point_cloud.points = points[combined_mask] point_cloud.colors = colors[combined_mask] # Move camera to selected frame # if 0 <= selected_idx < len(frames): # selected_frame = frames[selected_idx] # for client in server.get_clients().values(): # client.camera.wxyz = selected_frame.wxyz # client.camera.position = selected_frame.position # Initial visualization visualize_frames(extrinsics, None, images) # # Start server update loop in a background thread def server_loop(): while True: time.sleep(1e-3) # Small sleep to prevent CPU hogging thread = threading.Thread(target=server_loop, daemon=True) thread.start() def closed_form_inverse_se3(se3, R=None, T=None): """ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. If `R` and `T` are provided, they must correspond to the rotation and translation components of `se3`. Otherwise, they will be extracted from `se3`. Args: se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. R (optional): Nx3x3 array or tensor of rotation matrices. T (optional): Nx3x1 array or tensor of translation vectors. Returns: Inverted SE3 matrices with the same type and device as `se3`. Shapes: se3: (N, 4, 4) R: (N, 3, 3) T: (N, 3, 1) """ # Check if se3 is a numpy array or a torch tensor is_numpy = isinstance(se3, np.ndarray) # Validate shapes if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") # Extract R and T if not provided if R is None: R = se3[:, :3, :3] # (N,3,3) if T is None: T = se3[:, :3, 3:] # (N,3,1) # Transpose R if is_numpy: # Compute the transpose of the rotation for NumPy R_transposed = np.transpose(R, (0, 2, 1)) # -R^T t for NumPy top_right = -np.matmul(R_transposed, T) inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) else: R_transposed = R.transpose(1, 2) # (N,3,3) top_right = -torch.bmm(R_transposed, T) # (N,3,1) inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) inverted_matrix[:, :3, :3] = R_transposed inverted_matrix[:, :3, 3:] = top_right return inverted_matrix