import os
import matplotlib
import matplotlib.pyplot as plt
import copy
from evo.core.trajectory import PosePath3D, PoseTrajectory3D
from evo.main_ape import ape
from evo.tools import plot
from evo.core import sync
from evo.tools import file_interface
from evo.core import metrics
import evo
import torch
import numpy as np
from scipy.spatial.transform import Slerp
from scipy.spatial.transform import Rotation as R
import scipy.interpolate as si


def interp_poses(c2ws, N_views):
    """Interpolate ``N_views`` poses along a sequence of camera poses.

    Rotations are interpolated with Slerp and translations linearly. Both use
    the same uniform parametrisation over [0, 1], so the k-th output rotation
    and translation correspond to the same time step and both endpoints of
    the input sequence are reproduced exactly.

    :param c2ws: (N, 3+, 4) torch tensor of poses; only the top 3x4 part
        is used.
    :param N_views: number of output poses.
    :return: (N_views, 4, 4) torch tensor.
    """
    N_inputs = c2ws.shape[0]

    # (N, 3, 1) -> (1, 3, N): each translation axis becomes a channel and the
    # pose index becomes the 1-D length dimension expected by interpolate().
    trans = c2ws[:, :3, 3:].permute(2, 1, 0)
    # align_corners=True samples input index k * (N_inputs - 1) / (N_views - 1),
    # i.e. exactly the times used by the Slerp below. The default
    # (align_corners=False) uses a different mapping and clamps near the ends.
    interp_trans = torch.nn.functional.interpolate(
        trans, size=N_views, mode='linear', align_corners=True).permute(2, 1, 0)

    # scipy needs a plain CPU array; detach so poses that are being optimised
    # (requires_grad) or that live on the GPU can be interpolated too.
    rots = R.from_matrix(c2ws[:, :3, :3].detach().cpu().numpy())
    slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
    interp_rots = torch.tensor(
        slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32),
        device=c2ws.device)

    render_poses = torch.cat([interp_rots, interp_trans], dim=2)
    return convert3x4_4x4(render_poses)


def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
    """Resample poses with a B-spline on translations and Slerp on rotations.

    :param c2ws: (N, 3+, 4) poses; only the top 3x4 part is used.
    :param N_novel_imgs: number of output poses.
    :param input_times: timestamps of the N input poses, used for the
        rotation Slerp (Slerp requires them to be increasing).
    :param degree: B-spline degree for the translation curve.
    :return: (N_novel_imgs, 4, 4) torch tensor.
    """
    # NOTE(review): translations are sampled on a uniform spline parameter and
    # ignore ``input_times``; only rotations follow ``input_times``.
    target_trans = torch.tensor(scipy_bspline(
        c2ws[:, :3, 3], n=N_novel_imgs, degree=degree,
        periodic=False).astype(np.float32)).unsqueeze(2)

    rots = R.from_matrix(c2ws[:, :3, :3])
    slerp = Slerp(input_times, rots)
    target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
    target_rots = torch.tensor(
        slerp(target_times).as_matrix().astype(np.float32))

    target_poses = torch.cat([target_rots, target_trans], dim=2)
    return convert3x4_4x4(target_poses)


def poses_avg(poses):
    """Return an "average" pose with the first pose's last column appended.

    :param poses: (N, 3+, 5) numpy array whose columns 0-3 hold the rotation
        and translation, and whose last column holds extra per-pose data
        (presumably [H, W, focal] as in LLFF - confirm with caller).
    :return: (3, 5) array: averaged 3x4 pose followed by ``poses[0]``'s last
        column.
    """
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    # Mean viewing (z) direction and summed up (y) direction of all cameras.
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
    return c2w


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def viewmatrix(z, up, pos):
    """Build a 3x4 camera matrix ``[x | y | z | pos]`` from a z axis and an up hint.

    :param z: desired z axis (normalized here).
    :param up: approximate up direction; re-orthogonalized against ``z``.
    :param pos: camera position (translation column).
    :return: (3, 4) numpy array with orthonormal columns 0-2.
    """
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """Generate ``N`` poses on a spiral around ``c2w``, all looking at one focus point.

    :param c2w: (3, 5) reference pose; columns 0-3 are the pose and column 4
        is copied unchanged into every output pose.
    :param up: up direction passed to :func:`viewmatrix`.
    :param rads: per-axis (x, y, z) radii scaling the spiral offsets.
    :param focal: distance along the reference camera's -z axis of the point
        every camera looks at.
    :param zdelta: unused; kept for interface compatibility.
    :param zrate: frequency of the z oscillation relative to the rotation.
    :param rots: number of full turns of the spiral.
    :param N: number of poses to generate.
    :return: list of ``N`` (3, 5) numpy arrays.
    """
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:, 4:5]
    # N + 1 samples with the last dropped, so a closed loop does not repeat
    # its first pose.
    for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array(
            [0.2 * np.cos(theta), -0.2 * np.sin(theta),
             -np.sin(theta * zrate) * 0.1, 1.]) * rads)
        # z axis points from the focus point toward the camera position.
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def scipy_bspline(cv, n=100, degree=3, periodic=False):
    """ Calculate n samples on a bspline

        cv :      Array of control vertices
        n  :      Number of samples to return
        degree:   Curve degree
        periodic: True - Curve is closed
    """
    cv = np.asarray(cv)
    count = cv.shape[0]

    # Closed curve
    if periodic:
        kv = np.arange(-degree, count + degree + 1)
        factor, fraction = divmod(count + degree + 1, count)
        cv = np.roll(np.concatenate(
            (cv,) * factor + (cv[:fraction],)), -1, axis=0)
        degree = np.clip(degree, 1, degree)

    # Opened curve
    else:
        degree = np.clip(degree, 1, count - 1)
        # Clamped knot vector: the curve starts at the first and ends at the
        # last control vertex.
        kv = np.clip(np.arange(count + degree + 1) - degree, 0, count - degree)

    # Return samples
    max_param = count - (degree * (1 - periodic))
    spl = si.BSpline(kv, cv, degree)
    return spl(np.linspace(0, max_param, n))


def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
    """Generate a NeRF/LLFF-style spiral camera path around learned poses.

    :param learned_poses: (N, 3+, 4+) torch tensor of poses; detached and
        moved to CPU here.
    :param bds: numpy array of scene depth bounds; only its min and max are
        used.
    :param N_novel_views: number of poses on the spiral.
    :param hwf: numpy array concatenated to the 3x4 poses along the last axis
        (presumably (N, 3, 1) holding [H, W, focal] - confirm with caller).
    :return: (N_novel_views, 4, 4) float32 torch tensor.
    """
    learned_poses_ = np.concatenate((
        learned_poses[:, :3, :4].detach().cpu().numpy(),
        hwf[:len(learned_poses)]), axis=-1)

    # Average pose used as the centre of the spiral.
    c2w = poses_avg(learned_poses_)
    print('recentered', c2w.shape)

    up = normalize(learned_poses_[:, :3, 1].sum(0))

    # Find a reasonable "focus depth" for this dataset: a weighted harmonic
    # mean of the near and far bounds.
    close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
    dt = .75
    mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
    focal = mean_dz

    zdelta = close_depth * .2
    # Spiral radii: 90th percentile of the absolute camera positions per axis.
    # NOTE(review): these are world-space positions, not positions relative to
    # the average pose.
    tt = learned_poses_[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0)

    N_rots = 2
    c2ws = render_path_spiral(
        c2w, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views)
    c2ws = torch.tensor(np.stack(c2ws).astype(np.float32))
    c2ws = c2ws[:, :3, :4]
    c2ws = convert3x4_4x4(c2ws)
    return c2ws


def convert3x4_4x4(input):
    """
    Append the homogeneous row [0, 0, 0, 1] to 3x4 pose matrices.

    :param input:  (N, 3, 4) or (3, 4) torch or np
    :return:       (N, 4, 4) or (4, 4) torch or np
    """
    # NOTE: the parameter name shadows the builtin ``input``; it is kept so
    # keyword callers keep working.
    if torch.is_tensor(input):
        if len(input.shape) == 3:
            output = torch.cat([input, torch.zeros_like(
                input[:, 0:1])], dim=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = torch.cat([input, torch.tensor(
                [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0)  # (4, 4)
    else:
        if len(input.shape) == 3:
            output = np.concatenate(
                [input, np.zeros_like(input[:, 0:1])], axis=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = np.concatenate(
                [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0)  # (4, 4)
            output[3, 3] = 1.0
    return output


plt.rc('legend', fontsize=20)  # numeric legend font size (points)


def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
    """Plot an estimated camera trajectory against a reference trajectory.

    The estimate is aligned to the reference with scale correction
    (``correct_scale=True``) before plotting.

    :param ref_poses: iterable of reference poses (presumably 4x4 SE(3)
        matrices as expected by evo's ``PosePath3D``).
    :param est_poses: iterable of estimated poses, or a dict whose values are
        the poses (used in the dict's iteration order).
    :param output_path: directory (``str`` or path-like) in which
        ``pose_vis.png`` is written.
    :param args: run configuration; not read by the current body, kept for
        interface compatibility.
    :param vid: if True, also write one frame per pose prefix to
        ``<dirname(output_path)>/pose_vid/pose_vis_XXX.png``.
    """
    ref_poses = list(ref_poses)
    if isinstance(est_poses, dict):
        est_poses = list(est_poses.values())
    else:
        est_poses = list(est_poses)

    traj_ref = PosePath3D(poses_se3=ref_poses)
    traj_est = PosePath3D(poses_se3=est_poses)
    traj_est_aligned = copy.deepcopy(traj_est)
    traj_est_aligned.align(traj_ref, correct_scale=True, correct_only_scale=False)

    plot_mode = plot.PlotMode.xyz

    if vid:
        # NOTE(review): frames go next to ``output_path`` (its dirname), while
        # the summary plot below goes *inside* ``output_path``. Confirm intended.
        frame_dir = os.path.join(os.path.dirname(output_path), 'pose_vid')
        os.makedirs(frame_dir, exist_ok=True)
        frame_colors = ['r', 'b']
        frame_styles = ['-', '--']
        for p_idx in range(len(ref_poses)):
            fig = plt.figure()
            # Trajectories truncated to the first p_idx + 1 poses.
            current_est_aligned = PosePath3D(
                poses_se3=traj_est_aligned.poses_se3[:p_idx + 1])
            current_ref = PosePath3D(poses_se3=traj_ref.poses_se3[:p_idx + 1])
            traj_by_label = {
                "Ours (aligned)": current_est_aligned,
                "Ground-truth": current_ref
            }
            ax = fig.add_subplot(111, projection="3d")
            ax.xaxis.set_tick_params(labelbottom=False)
            ax.yaxis.set_tick_params(labelleft=False)
            ax.zaxis.set_tick_params(labelleft=False)
            for idx, (label, traj) in enumerate(traj_by_label.items()):
                plot.traj(ax, plot_mode, traj,
                          frame_styles[idx], frame_colors[idx], label)
            ax.view_init(elev=10., azim=45)
            plt.tight_layout()
            pose_vis_path = os.path.join(
                frame_dir, 'pose_vis_{:03d}.png'.format(p_idx))
            print(pose_vis_path)
            fig.savefig(pose_vis_path)
            # Close each frame; otherwise every frame's figure stays alive.
            plt.close(fig)

    # Summary plot of the full trajectories (always written).
    fig = plt.figure()
    fig.patch.set_facecolor('white')  # set the figure background to pure white
    traj_by_label = {
        "Ours (aligned)": traj_est_aligned,
        "COLMAP (GT)": traj_ref
    }
    ax = fig.add_subplot(111, projection="3d")
    ax.set_facecolor('white')  # set the axes background to pure white
    ax.xaxis.set_tick_params(labelbottom=True)
    ax.yaxis.set_tick_params(labelleft=True)
    ax.zaxis.set_tick_params(labelleft=True)
    colors = ['#2c9e38', '#d12920']  # green for the estimate, red for GT
    styles = ['-', '--']
    for idx, (label, traj) in enumerate(traj_by_label.items()):
        plot.traj(ax, plot_mode, traj, styles[idx], colors[idx], label)
    ax.view_init(elev=30., azim=45)
    plt.tight_layout()
    # os.path.join accepts both str and path-like output_path
    # (``output_path / ...`` failed for plain strings).
    pose_vis_path = os.path.join(output_path, 'pose_vis.png')
    fig.savefig(pose_vis_path)
    plt.close(fig)