import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Tuple
from utils.stepfun import sample_np, sample
import scipy
import scipy.interpolate  # `import scipy` alone does not load submodules on older SciPy releases


def _as_tensor(x):
    """Return ``x`` unchanged if it is a tensor, else convert it with ``torch.tensor``.

    Converted inputs are moved to the GPU when CUDA is available and left on the
    CPU otherwise, so CPU-only machines do not crash on ``.cuda()``.
    """
    if isinstance(x, torch.Tensor):
        return x
    t = torch.tensor(x)
    return t.cuda() if torch.cuda.is_available() else t


def quad2rotation(q):
    """Convert a batch of quaternions to rotation matrices.

    Only differentiable torch operations are used, so gradients flow through.

    Args:
        q: (B, 4) quaternions, real part first (r, x, y, z). Each row is
            normalized before conversion. Non-tensor input is converted to a tensor.

    Returns:
        (B, 3, 3) rotation matrices with the same dtype/device as ``q``.
    """
    q = _as_tensor(q)
    norm = torch.sqrt(
        q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]
    )
    q = q / norm[:, None]
    rot = torch.zeros((q.size(0), 3, 3)).to(q)
    r, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
    rot[:, 0, 1] = 2 * (x * y - r * z)
    rot[:, 0, 2] = 2 * (x * z + r * y)
    rot[:, 1, 0] = 2 * (x * y + r * z)
    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
    rot[:, 1, 2] = 2 * (y * z - r * x)
    rot[:, 2, 0] = 2 * (x * z - r * y)
    rot[:, 2, 1] = 2 * (y * z + r * x)
    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return rot


def get_camera_from_tensor(inputs):
    """Convert a quaternion + translation vector into a 4x4 transformation matrix.

    Args:
        inputs: 7-vector ``[qr, qx, qy, qz, tx, ty, tz]`` or a (1, 7) batch of one.
            Non-tensor input is converted to a tensor. A batch of more than one
            row is not supported (the rotation is written into a single 3x3 slot).

    Returns:
        (4, 4) float32 matrix with the rotation in the top-left 3x3, the
        translation in the last column and ``[0, 0, 0, 1]`` as the last row.
    """
    inputs = _as_tensor(inputs)
    if inputs.dim() == 1:
        inputs = inputs.unsqueeze(0)
    quad, T = inputs[:, :4], inputs[:, 4:]
    w2c = torch.eye(4).to(inputs).float()
    w2c[:3, :3] = quad2rotation(quad)
    w2c[:3, 3] = T
    return w2c


def quadmultiply(q1, q2):
    """Return the Hamilton product ``q1 * q2`` of quaternions (real part first).

    Both inputs have shape (..., 4); the result has the broadcast shape (..., 4).
    """
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)
    return torch.stack(
        [
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ],
        dim=-1,
    )


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Return torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.

    Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def rotation2quad(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions.

    Args:
        matrix: rotation matrices of shape (..., 3, 3). Non-tensor input (e.g. a
            numpy array) is converted to a tensor before the shape check.

    Returns:
        Quaternions with real part first, shape (..., 4).

    Raises:
        ValueError: if the last two dimensions are not 3x3.

    Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
    """
    # Convert first: on a numpy array `.size` is an int, so the shape check
    # below would raise TypeError before the conversion could happen.
    matrix = _as_tensor(matrix)
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # We produce the desired quaternion multiplied by each of r, i, j, k.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is
    # small, the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # If not for numerical problems, all quat_candidates[i] would agree (up to
    # sign); pick the best-conditioned one (largest denominator).
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))


def get_tensor_from_camera(RT, Tquad=False):
    """Convert a transformation matrix into a quaternion + translation vector.

    Args:
        RT: matrix whose top-left 3x3 is a rotation and whose ``[:3, 3]`` is a
            translation. Non-tensor input is converted to a tensor.
        Tquad: if True return ``[tx, ty, tz, qr, qx, qy, qz]``; otherwise
            (default) return ``[qr, qx, qy, qz, tx, ty, tz]``.

    Returns:
        7-element tensor, detached from the autograd graph.
    """
    RT = _as_tensor(RT)
    rot = RT[:3, :3].unsqueeze(0).detach()
    quat = rotation2quad(rot).squeeze()
    tran = RT[:3, 3].detach()
    if Tquad:
        return torch.cat([tran, quat])
    return torch.cat([quat, tran])


def normalize(x):
    """Return ``x`` divided by its (Frobenius/2-) norm."""
    return x / np.linalg.norm(x)


def viewmatrix(lookdir, up, position, subtract_position=False):
    """Construct a lookat view matrix.

    Returns a (3, 4) matrix whose columns are the x, y, z axes and ``position``.
    The z axis is ``lookdir`` (or ``lookdir - position`` when
    ``subtract_position``), normalized.
    """
    vec2 = normalize((lookdir - position) if subtract_position else lookdir)
    vec0 = normalize(np.cross(up, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    return np.stack([vec0, vec1, vec2, position], axis=1)


def poses_avg(poses):
    """New (3, 4) pose from the average position, z-axis and up (y) vector of poses."""
    position = poses[:, :3, 3].mean(0)
    z_axis = poses[:, :3, 2].mean(0)
    up = poses[:, :3, 1].mean(0)
    return viewmatrix(z_axis, up, position)


def focus_point_fn(poses):
    """Least-squares nearest point to all lines through pose origins along their z-axes."""
    directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
    m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
    mt_m = np.transpose(m, [0, 2, 1]) @ m
    focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
    return focus_pt


def pad_poses(p):
    """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
    bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
    return np.concatenate([p[..., :3, :4], bottom], axis=-2)


def unpad_poses(p):
    """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
    return p[..., :3, :4]


def transform_poses_pca(poses):
    """Transform poses so their position principal components lie on the XYZ axes.

    Args:
        poses: (N, 3, 4) or (N, 4, 4) array; only the top 3x4 of each is used.

    Returns:
        A tuple ``(poses_recentered, transform)``. ``poses_recentered`` is
        (N, 3, 4) with positions scaled into the [-1, 1]^3 cube; ``transform`` is
        the 4x4 matrix applied, i.e. ``poses_recentered = transform @ pose``.
    """
    t = poses[:, :3, 3]
    t_mean = t.mean(axis=0)
    t = t - t_mean

    eigval, eigvec = np.linalg.eig(t.T @ t)
    # Sort eigenvectors in order of largest to smallest eigenvalue.
    inds = np.argsort(eigval)[::-1]
    eigvec = eigvec[:, inds]
    rot = eigvec.T
    if np.linalg.det(rot) < 0:
        # Keep a right-handed frame.
        rot = np.diag(np.array([1, 1, -1])) @ rot

    transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
    poses_recentered = unpad_poses(transform @ pad_poses(poses))
    transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)

    # Flip coordinate system if z component of y-axis is negative.
    if poses_recentered.mean(axis=0)[2, 1] < 0:
        poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
        transform = np.diag(np.array([1, -1, -1, 1])) @ transform

    # Scale so that all camera positions lie in the [-1, 1]^3 cube.
    scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
    poses_recentered[:, :3, 3] *= scale_factor
    transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform

    return poses_recentered, transform


def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Recenter poses around the origin.

    Returns the (N, 3, 4) recentered poses and the 4x4 transform that was applied
    (the inverse of the padded average pose).
    """
    cam2world = poses_avg(poses)
    transform = np.linalg.inv(pad_poses(cam2world))
    poses = transform @ pad_poses(poses)
    return unpad_poses(poses), transform


def _views_to_poses(views):
    """Stack (4, 4) pose matrices built from objects exposing ``R`` and ``T``.

    For each view, ``[R^T | T]`` is padded to 4x4, inverted, and its columns 1
    and 2 are negated. Presumably this yields camera-to-world matrices in an
    OpenGL-style axis convention -- TODO confirm against the view class.
    """
    poses = []
    for view in views:
        tmp_view = np.eye(4)
        tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
        tmp_view = np.linalg.inv(tmp_view)
        tmp_view[:, 1:3] *= -1
        poses.append(tmp_view)
    return np.stack(poses, 0)


def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
    """Generate an elliptical camera path around the focus point of ``views``.

    Args:
        views: objects exposing ``R`` (3x3) and ``T`` (3,) arrays.
        n_frames: number of poses returned.
        const_speed: resample the ellipse so consecutive poses are closer to
            equally spaced.
        z_variation: amplitude factor of the height oscillation (0 = flat path).
        z_phase: phase of the height oscillation, in fractions of a full turn.

    Returns:
        List of ``n_frames`` (4, 4) matrices, each the inverse of a path pose
        mapped back through the PCA transform with columns 1-2 negated.
    """
    poses = _views_to_poses(views)
    poses, transform = transform_poses_pca(poses)

    # Calculate the focal point for the path (cameras point toward this).
    center = focus_point_fn(poses)
    # Path height sits at z=0 (in middle of zero-mean capture pattern).
    offset = np.array([center[0], center[1], 0])
    # Calculate scaling for ellipse axes based on input camera positions.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
    # Use ellipse that is symmetric about the focal point in xy.
    low = -sc + offset
    high = sc + offset
    # Optional height variation need not be symmetric.
    z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
    z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)

    def get_positions(theta):
        # Interpolate between bounds with trig functions to get ellipse in x-y.
        # Optionally also interpolate in z to change camera height along path.
        return np.stack([
            (low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)),
            (low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)),
            z_variation * (z_low[2] + (z_high - z_low)[2] *
                           (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
        ], -1)

    theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)

    if const_speed:
        # Resample theta angles so that the velocity is closer to constant.
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        theta = sample_np(None, theta, np.log(lengths), n_frames + 1)
        positions = get_positions(theta)

    # Throw away duplicated last position.
    positions = positions[:-1]

    # Set path's up vector to axis closest to average of input pose up vectors.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    ind_up = np.argmax(np.abs(avg_up))
    up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])

    inv_transform = np.linalg.inv(transform)  # loop-invariant
    render_poses = []
    for p in positions:
        render_pose = np.eye(4)
        render_pose[:3] = viewmatrix(p - center, up, p)
        render_pose = inv_transform @ render_pose
        render_pose[:3, 1:3] *= -1
        render_poses.append(np.linalg.inv(render_pose))
    return render_poses


def generate_spiral_path(poses_arr,
                         n_frames: int = 180,
                         n_rots: int = 2,
                         zrate: float = .5) -> np.ndarray:
    """Calculate a forward-facing spiral path for rendering.

    Args:
        poses_arr: (N, 17) array; the first 15 columns reshape to a (3, 5) pose
            each, the last 2 columns are per-pose bounds. Not modified.
        n_frames: number of poses generated.
        n_rots: number of full rotations of the spiral.
        zrate: frequency of the depth oscillation relative to the rotation.

    Returns:
        (n_frames, 4, 4) array of inverted path poses.
    """
    poses = poses_arr[:, :-2].reshape([-1, 3, 5])
    bounds = poses_arr[:, -2:]
    fix_rotation = np.array([
        [0, -1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    poses = poses[:, :3, :4] @ fix_rotation

    scale = 1. / (bounds.min() * .75)
    poses[:, :3, 3] *= scale
    # Out-of-place: `bounds` is a view into the caller's poses_arr, and an
    # in-place `*=` would silently rescale the caller's data.
    bounds = bounds * scale
    poses, transform = recenter_poses(poses)

    close_depth, inf_depth = bounds.min() * .9, bounds.max() * 5.
    dt = .75
    focal = 1 / (((1 - dt) / close_depth + dt / inf_depth))

    # Get radii for spiral path using 90th percentile of camera positions.
    positions = poses[:, :3, 3]
    radii = np.percentile(np.abs(positions), 90, 0)
    radii = np.concatenate([radii, [1.]])

    # Generate poses for spiral path.
    render_poses = []
    cam2world = poses_avg(poses)
    up = poses[:, :3, 1].mean(0)
    inv_transform = np.linalg.inv(transform)  # loop-invariant
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
        t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
        position = cam2world @ t
        lookat = cam2world @ [0, 0, -focal, 1.]
        z_axis = position - lookat
        render_pose = np.eye(4)
        render_pose[:3] = viewmatrix(z_axis, up, position)
        render_pose = inv_transform @ render_pose
        render_pose[:3, 1:3] *= -1
        render_pose[:3, 3] /= scale
        render_poses.append(np.linalg.inv(render_pose))
    return np.stack(render_poses, axis=0)


def generate_interpolated_path(
    views,
    n_interp,
    spline_degree=5,
    smoothness=0.03,
    rot_weight=0.1,
    lock_up=False,
    fixed_up_vector=None,
    lookahead_i=None,
    frames_per_colmap=None,
    const_speed=False,
    n_buffer=None,
    periodic=False,
    n_interp_as_total=False,
):
    """Create a smooth spline path between input keyframe camera poses.

    The spline is fitted to poses expressed as (position, lookat-point, up-point).

    Args:
        views: keyframe objects exposing ``R`` (3x3) and ``T`` (3,) arrays.
        n_interp: frames per keyframe interval, or the total frame count when
            ``n_interp_as_total`` is set.
        spline_degree: polynomial degree of the B-spline (capped at n - 1).
        smoothness: spline smoothing factor; 0 forces exact interpolation.
        rot_weight: distance of the lookat/up points from the position, i.e. the
            relative weighting of rotation vs translation in the spline fit.
        lock_up: currently unused.
        fixed_up_vector: replace the interpolated up direction with this vector.
        lookahead_i: look at the position ``lookahead_i`` frames ahead. Frames
            with no such frame keep the previous look direction (or their own
            interpolated look-at if there is no previous one).
        frames_per_colmap: if set, recompute the frame count from the path length
            to reach this average density.
        const_speed: resample the spline so consecutive poses are closer to
            equally spaced.
        n_buffer: number of extrapolated buffer poses added before the first and
            after the last keyframe to keep the path ends straight; they are
            removed again after fitting. ``None`` or 0 disables this.
        periodic: make the spline periodic (closed loop).
        n_interp_as_total: treat ``n_interp`` as the total number of poses.

    Returns:
        (M, 3, 4) array of view matrices (see ``viewmatrix``); the final sampled
        pose is dropped. M depends on the options above.
    """
    poses = _views_to_poses(views)

    def poses_to_points(poses, dist):
        """Convert pose matrices to (position, lookat, up) point triplets."""
        pos = poses[:, :3, -1]
        lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
        up = poses[:, :3, -1] + dist * poses[:, :3, 1]
        return np.stack([pos, lookat, up], 1)

    def points_to_poses(points):
        """Convert (position, lookat, up) point triplets back to view matrices."""
        poses = []
        lookat = None
        for i, (pos, lookat_point, up_point) in enumerate(points):
            if lookahead_i is None:
                lookat = pos - lookat_point
            elif i + lookahead_i < len(points):
                lookat = pos - points[i + lookahead_i][0]
            elif lookat is None:
                # No frame far enough ahead has been seen yet.
                lookat = pos - lookat_point
            # Otherwise keep the previous frame's look direction.
            up = (up_point - pos) if fixed_up_vector is None else fixed_up_vector
            poses.append(viewmatrix(lookat, up, pos))
        return np.array(poses)

    def insert_buffer_poses(poses, n_buffer):
        """Insert extra poses at the start and end of the path."""

        def average_distance(points):
            distances = np.linalg.norm(points[1:] - points[0:-1], axis=-1)
            return np.mean(distances)

        def shift(pose, dz):
            result = np.copy(pose)
            z = result[:3, 2]
            z /= np.linalg.norm(z)
            # Move along forward-backward axis. -z is forward.
            result[:3, 3] += z * dz
            return result

        dz = average_distance(poses[:, :3, 3])
        prefix = np.stack([shift(poses[0], (i + 1) * dz) for i in range(n_buffer)])
        prefix = prefix[::-1]  # reverse order
        suffix = np.stack(
            [shift(poses[-1], -(i + 1) * dz) for i in range(n_buffer)]
        )
        return np.concatenate([prefix, poses, suffix])

    def remove_buffer_poses(poses, u, n_frames, u_keyframes, n_buffer):
        """Drop samples whose parameter lies outside the original keyframe range."""
        u_keyframes = u_keyframes[n_buffer:-n_buffer]
        mask = (u >= u_keyframes[0]) & (u <= u_keyframes[-1])
        poses = poses[mask]
        u = u[mask]
        n_frames = len(poses)
        return poses, u, n_frames, u_keyframes

    def interp(points, u, k, s):
        """Run multidimensional B-spline interpolation on the input points."""
        sh = points.shape
        pts = np.reshape(points, (sh[0], -1))
        k = min(k, sh[0] - 1)
        tck, u_keyframes = scipy.interpolate.splprep(pts.T, k=k, s=s, per=periodic)
        new_points = np.array(scipy.interpolate.splev(u, tck))
        new_points = np.reshape(new_points.T, (len(u), sh[1], sh[2]))
        return new_points, u_keyframes

    # `if n_buffer:` rather than `is not None`: n_buffer=0 would otherwise crash
    # in np.stack([]) and in the u_keyframes[0:-0] slice.
    if n_buffer:
        poses = insert_buffer_poses(poses, n_buffer)
    points = poses_to_points(poses, dist=rot_weight)
    if n_interp_as_total:
        n_frames = n_interp + 1  # Add extra since final pose is discarded.
    else:
        n_frames = n_interp * (points.shape[0] - 1)
    u = np.linspace(0, 1, n_frames, endpoint=True)
    new_points, u_keyframes = interp(points, u=u, k=spline_degree, s=smoothness)
    poses = points_to_poses(new_points)
    if n_buffer:
        poses, u, n_frames, u_keyframes = remove_buffer_poses(
            poses, u, n_frames, u_keyframes, n_buffer
        )

    if frames_per_colmap is not None:
        # Recalculate the number of frames to achieve desired average velocity.
        positions = poses[:, :3, -1]
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        total_length_colmap = lengths.sum()
        print('old n_frames:', n_frames)
        print('total_length_colmap:', total_length_colmap)
        n_frames = int(total_length_colmap * frames_per_colmap)
        print('new n_frames:', n_frames)
        u = np.linspace(
            np.min(u_keyframes), np.max(u_keyframes), n_frames, endpoint=True
        )
        new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
        poses = points_to_poses(new_points)

    if const_speed:
        # Resample timesteps so that the velocity is nearly constant.
        positions = poses[:, :3, -1]
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        # NOTE(review): numpy arrays are passed to `sample` here, while
        # generate_ellipse_path uses `sample_np` for the same purpose. If
        # utils.stepfun.sample is a torch-only variant this branch fails --
        # confirm which one is intended.
        u = sample(None, u, np.log(lengths), n_frames + 1)
        new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
        poses = points_to_poses(new_points)

    return poses[:-1]


def depth_to_pts3d(K, pose, W, H, depth):
    """Back-project depth maps to 3-D points and transform them with ``pose``.

    Args:
        K: (B, 3, 3) intrinsics; fx must equal fy (asserted).
        pose: (B, 4, 4) (or (B, 3, 4)) transforms applied via ``geotrf``.
        W, H: image width and height.
        depth: (B, H, W) depth tensor.

    Returns:
        (B, H*W, 3) tensor of transformed points.
    """
    assert (K[:, 0, 0] == K[:, 1, 1]).all()
    focals = K[:, 0, 0]
    pp = K[:, :2, 2]
    # One (H, W, 2) pixel grid built directly on the device and broadcast over
    # the batch, instead of B numpy grids converted element by element.
    grid = xy_grid(W, H, device=depth.device).expand(len(depth), -1, -1, -1)

    # Pointmaps in the camera frame.
    rel_ptmaps = _fast_depthmap_to_pts3d(depth, grid, focals, pp=pp)

    # Apply the pose transform.
    return geotrf(pose, rel_ptmaps)


def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
    """Build a pixel-coordinate grid.

    With the defaults, returns an (H, W, 2) integer array (numpy when ``device``
    is None, otherwise a torch tensor on ``device``) with
        output[j, i, 0] = i + origin[0]
        output[j, i, 1] = j + origin[1]

    ``homogeneous`` appends a third channel of ones; ``unsqueeze`` (torch only)
    unsqueezes every channel at that dim; ``cat_dim=None`` returns the channels
    as a tuple instead of stacking them.
    """
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # tuple(): np.meshgrid returns a list on NumPy < 2, and `list + tuple`
    # below would raise TypeError.
    grid = tuple(meshgrid(tw, th, indexing='xy'))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # Keep every channel (including the homogeneous one).
        grid = tuple(g.unsqueeze(unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Args:
        Trf: transformation matrix (numpy or torch) of shape (..., d, d) or
            (..., d+1, d+1), e.g. a 3x3 homography or a 4x4 rigid transform.
            Leading batch dims must match those of ``pts``.
        pts: numpy/torch/tuple of coordinates with shape (..., d).
        ncol: number of columns kept in the result (defaults to d).
        norm: if non-zero, divide by the last coordinate and then scale by
            ``norm``, i.e. project onto the z=norm plane.

    Returns:
        Transformed points with shape ``pts.shape[:-1] + (ncol,)``.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    """Back-project depth maps into camera-frame 3-D points.

    Args:
        depth: (B, H, W) depths.
        pixel_grid: (B, H, W, 2) pixel coordinates.
        focal: (B,) focal lengths (single value per image).
        pp: (B, 2) principal points.

    Returns:
        (B, H*W, 3) tensor ``[depth * (pixel - pp) / focal, depth]``.
    """
    pp = pp.unsqueeze(1)
    focal = focal[:, None, None]
    assert focal.shape == (len(depth), 1, 1)
    assert pp.shape == (len(depth), 1, 2)
    assert pixel_grid.shape == depth.shape + (2,)
    depth = depth.unsqueeze(-1)
    pixel_grid = pixel_grid.reshape(len(depth), -1, 2)
    depth = depth.reshape(len(depth), -1, 1)
    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)