import os from plyfile import PlyData, PlyElement from scipy.spatial.transform import Rotation import einops import numpy as np import torch import torchvision import trimesh import lightning as L import utils.loss_mask as loss_mask from src.mast3r_src.dust3r.dust3r.viz import OPENGL, pts3d_to_trimesh, cat_meshes class SaveBatchData(L.Callback): '''A Lightning callback that occasionally saves batch inputs and outputs to disk. It is not critical to the training process, and can be disabled if unwanted.''' def __init__(self, save_dir, train_save_interval=100, val_save_interval=100, test_save_interval=100): self.save_dir = save_dir self.train_save_interval = train_save_interval self.val_save_interval = val_save_interval self.test_save_interval = test_save_interval def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx % self.train_save_interval == 0 and trainer.global_rank == 0: self.save_batch_data('train', trainer, pl_module, batch, batch_idx) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx % self.val_save_interval == 0 and trainer.global_rank == 0: self.save_batch_data('val', trainer, pl_module, batch, batch_idx) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx % self.test_save_interval == 0 and trainer.global_rank == 0: self.save_batch_data('test', trainer, pl_module, batch, batch_idx) def save_batch_data(self, prefix, trainer, pl_module, batch, batch_idx): print(f'Saving {prefix} data at epoch {trainer.current_epoch} and batch {batch_idx}') # Run the batch through the model again _, _, h, w = batch["context"][0]["img"].shape view1, view2 = batch['context'] pred1, pred2 = pl_module.forward(view1, view2) color, depth = pl_module.decoder(batch, pred1, pred2, (h, w)) mask = loss_mask.calculate_loss_mask(batch) # Save the data save_dir = os.path.join( self.save_dir, f"{prefix}_epoch_{trainer.current_epoch}_batch_{batch_idx}" ) log_batch_files(batch, color, depth, mask, view1, view2, pred1, pred2, save_dir) def save_as_ply(pred1, pred2, save_path): """Save the 3D Gaussians as a point cloud in the PLY format. Adapted loosely from PixelSplat""" def construct_list_of_attributes(num_rest: int) -> list[str]: '''Construct a list of attributes for the PLY file format. This corresponds to the attributes used by online readers, such as https://niujinshuchong.github.io/mip-splatting-demo/index.html''' attributes = ["x", "y", "z", "nx", "ny", "nz"] for i in range(3): attributes.append(f"f_dc_{i}") for i in range(num_rest): attributes.append(f"f_rest_{i}") attributes.append("opacity") for i in range(3): attributes.append(f"scale_{i}") for i in range(4): attributes.append(f"rot_{i}") return attributes def covariance_to_quaternion_and_scale(covariance): '''Convert the covariance matrix to a four dimensional quaternion and a three dimensional scale vector''' # Perform singular value decomposition U, S, V = torch.linalg.svd(covariance) # The scale factors are the square roots of the eigenvalues scale = torch.sqrt(S) scale = scale.detach().cpu().numpy() # The rotation matrix is U*Vt rotation_matrix = torch.bmm(U, V.transpose(-2, -1)) rotation_matrix_np = rotation_matrix.detach().cpu().numpy() # Use scipy to convert the rotation matrix to a quaternion rotation = Rotation.from_matrix(rotation_matrix_np) quaternion = rotation.as_quat() return quaternion, scale # Collect the Gaussian parameters means = torch.stack([pred1["pts3d"], pred2["pts3d_in_other_view"]], dim=1) covariances = torch.stack([pred1["covariances"], pred2["covariances"]], dim=1) harmonics = torch.stack([pred1["sh"], pred2["sh"]], dim=1)[..., 0] # Only use the first harmonic opacities = torch.stack([pred1["opacities"], pred2["opacities"]], dim=1) # Rearrange the tensors to the correct shape means = einops.rearrange(means[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy() covariances = einops.rearrange(covariances[0], "v h w i j -> (v h w) i j") harmonics = einops.rearrange(harmonics[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy() opacities = einops.rearrange(opacities[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy() # Convert the covariance matrices to quaternions and scales rotations, scales = covariance_to_quaternion_and_scale(covariances) # Construct the attributes rest = np.zeros_like(means) attributes = np.concatenate((means, rest, harmonics, opacities, np.log(scales), rotations), axis=-1) dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] elements = np.empty(attributes.shape[0], dtype=dtype_full) elements[:] = list(map(tuple, attributes)) # Save the point cloud point_cloud = PlyElement.describe(elements, "vertex") scene = PlyData([point_cloud]) scene.write(save_path) def save_3d(view1, view2, pred1, pred2, save_dir, as_pointcloud=True, all_points=True): """Save the 3D points as a point cloud or as a mesh. Adapted from DUSt3R""" os.makedirs(save_dir, exist_ok=True) batch_size = pred1["pts3d"].shape[0] views = [view1, view2] for b in range(batch_size): pts3d = [pred1["pts3d"][b].cpu().numpy()] + [pred2["pts3d_in_other_view"][b].cpu().numpy()] imgs = [einops.rearrange(view["original_img"][b], "c h w -> h w c").cpu().numpy() for view in views] mask = [view["valid_mask"][b].cpu().numpy() for view in views] # Treat all pixels as valid, because we want to render the entire viewpoint if all_points: mask = [np.ones_like(m) for m in mask] # Construct the scene from the 3D points as a point cloud or as a mesh scene = trimesh.Scene() if as_pointcloud: pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) scene.add_geometry(pct) save_path = os.path.join(save_dir, f"{b}.ply") else: meshes = [] for i in range(len(imgs)): meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i])) mesh = trimesh.Trimesh(**cat_meshes(meshes)) scene.add_geometry(mesh) save_path = os.path.join(save_dir, f"{b}.glb") # Save the scene scene.export(file_obj=save_path) @torch.no_grad() def log_batch_files(batch, color, depth, mask, view1, view2, pred1, pred2, save_dir, should_save_3d=False): '''Save all the relevant debug files for a batch''' os.makedirs(save_dir, exist_ok=True) # Save the 3D Gaussians as a .ply file save_as_ply(pred1, pred2, os.path.join(save_dir, f"gaussians.ply")) # Save the 3D points as a point cloud and as a mesh (disabled) if should_save_3d: save_3d(view1, view2, pred1, pred2, os.path.join(save_dir, "3d_mesh"), as_pointcloud=False) save_3d(view1, view2, pred1, pred2, os.path.join(save_dir, "3d_pointcloud"), as_pointcloud=True) # Save the color, depth and valid masks for the input context images context_images = torch.stack([view["img"] for view in batch["context"]], dim=1) context_original_images = torch.stack([view["original_img"] for view in batch["context"]], dim=1) context_depthmaps = torch.stack([view["depthmap"] for view in batch["context"]], dim=1) context_valid_masks = torch.stack([view["valid_mask"] for view in batch["context"]], dim=1) for b in range(min(context_images.shape[0], 4)): torchvision.utils.save_image(context_images[b], os.path.join(save_dir, f"sample_{b}_img_context.jpg")) torchvision.utils.save_image(context_original_images[b], os.path.join(save_dir, f"sample_{b}_original_img_context.jpg")) torchvision.utils.save_image(context_depthmaps[b, :, None, ...], os.path.join(save_dir, f"sample_{b}_depthmap.jpg"), normalize=True) torchvision.utils.save_image(context_valid_masks[b, :, None, ...].float(), os.path.join(save_dir, f"sample_{b}_valid_mask_context.jpg"), normalize=True) # Save the color and depth images for the target images target_original_images = torch.stack([view["original_img"] for view in batch["target"]], dim=1) target_depthmaps = torch.stack([view["depthmap"] for view in batch["target"]], dim=1) context_valid_masks = torch.stack([view["valid_mask"] for view in batch["context"]], dim=1) for b in range(min(target_original_images.shape[0], 4)): torchvision.utils.save_image(target_original_images[b], os.path.join(save_dir, f"sample_{b}_original_img_target.jpg")) torchvision.utils.save_image(target_depthmaps[b, :, None, ...], os.path.join(save_dir, f"sample_{b}_depthmap_target.jpg"), normalize=True) # Save the rendered images and depths for b in range(min(color.shape[0], 4)): torchvision.utils.save_image(color[b, ...], os.path.join(save_dir, f"sample_{b}_rendered_color.jpg")) if depth is not None: for b in range(min(color.shape[0], 4)): torchvision.utils.save_image(depth[b, :, None, ...], os.path.join(save_dir, f"sample_{b}_rendered_depth.jpg"), normalize=True) # Save the loss masks for b in range(min(mask.shape[0], 4)): torchvision.utils.save_image(mask[b, :, None, ...].float(), os.path.join(save_dir, f"sample_{b}_loss_mask.jpg"), normalize=True) # Save the masked target and rendered images target_original_images = torch.stack([view["original_img"] for view in batch["target"]], dim=1) masked_target_original_images = target_original_images * mask[..., None, :, :] masked_predictions = color * mask[..., None, :, :] for b in range(min(target_original_images.shape[0], 4)): torchvision.utils.save_image(masked_target_original_images[b], os.path.join(save_dir, f"sample_{b}_masked_original_img_target.jpg")) torchvision.utils.save_image(masked_predictions[b], os.path.join(save_dir, f"sample_{b}_masked_rendered_color.jpg"))