Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from plyfile import PlyData, PlyElement | |
from scipy.spatial.transform import Rotation | |
import einops | |
import numpy as np | |
import torch | |
import torchvision | |
import trimesh | |
import lightning as L | |
import utils.loss_mask as loss_mask | |
from src.mast3r_src.dust3r.dust3r.viz import OPENGL, pts3d_to_trimesh, cat_meshes | |
class SaveBatchData(L.Callback):
    """A Lightning callback that occasionally saves batch inputs and outputs to disk.

    It is not critical to the training process, and can be disabled if unwanted.

    At the end of every ``*_save_interval``-th train / validation / test batch
    (on global rank 0 only) the batch is run through the model again and the
    results are written with ``log_batch_files`` into a sub-directory of
    ``save_dir`` named ``{prefix}_epoch_{epoch}_batch_{batch_idx}``.

    Args:
        save_dir: Root directory that receives one sub-directory per saved batch.
        train_save_interval: Save every N-th training batch. A falsy value
            (``0`` / ``None``) disables saving for this stage.
        val_save_interval: Same as above, for validation batches.
        test_save_interval: Same as above, for test batches.
    """

    def __init__(self, save_dir, train_save_interval=100, val_save_interval=100, test_save_interval=100):
        super().__init__()
        self.save_dir = save_dir
        self.train_save_interval = train_save_interval
        self.val_save_interval = val_save_interval
        self.test_save_interval = test_save_interval

    @staticmethod
    def _should_save(interval, trainer, batch_idx):
        """Return True when this batch is due for saving on this process."""
        # A falsy interval means "never save"; this also avoids a modulo by zero.
        if not interval:
            return False
        return batch_idx % interval == 0 and trainer.global_rank == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._should_save(self.train_save_interval, trainer, batch_idx):
            self.save_batch_data('train', trainer, pl_module, batch, batch_idx)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._should_save(self.val_save_interval, trainer, batch_idx):
            self.save_batch_data('val', trainer, pl_module, batch, batch_idx)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._should_save(self.test_save_interval, trainer, batch_idx):
            self.save_batch_data('test', trainer, pl_module, batch, batch_idx)

    def save_batch_data(self, prefix, trainer, pl_module, batch, batch_idx):
        """Re-run ``batch`` through ``pl_module`` and write the debug files.

        Args:
            prefix: Stage name used in the output directory name
                (the hooks pass 'train', 'val' or 'test').
            trainer: The Lightning trainer; ``current_epoch`` is read from it.
            pl_module: The module being trained; its ``forward`` and
                ``decoder`` are called.
            batch: Batch dict; ``batch['context']`` must hold exactly two views,
                each with an "img" entry whose last two dims are (h, w).
            batch_idx: Index of the batch within the current epoch.
        """
        print(f'Saving {prefix} data at epoch {trainer.current_epoch} and batch {batch_idx}')

        view1, view2 = batch['context']
        _, _, h, w = view1["img"].shape

        # Run the batch through the model again. The outputs are only written
        # to disk, so autograd is disabled to avoid building (and holding the
        # memory of) a second computation graph in the middle of training.
        with torch.no_grad():
            pred1, pred2 = pl_module.forward(view1, view2)
            color, depth = pl_module.decoder(batch, pred1, pred2, (h, w))
            mask = loss_mask.calculate_loss_mask(batch)

        # Save the data
        save_dir = os.path.join(
            self.save_dir,
            f"{prefix}_epoch_{trainer.current_epoch}_batch_{batch_idx}"
        )
        log_batch_files(batch, color, depth, mask, view1, view2, pred1, pred2, save_dir)
def save_as_ply(pred1, pred2, save_path):
    """Save the 3D Gaussians as a point cloud in the PLY format.

    Adapted loosely from PixelSplat. Only the first sample of the batch
    (index 0) is exported; the Gaussians predicted for both views are
    flattened into a single list of vertices.

    Args:
        pred1: Prediction dict for the first view. The keys "pts3d",
            "covariances", "sh" and "opacities" are read.
        pred2: Prediction dict for the second view. The keys
            "pts3d_in_other_view", "covariances", "sh" and "opacities" are read.
        save_path: Path of the PLY file to write.
    """

    def construct_list_of_attributes(num_rest: int) -> list[str]:
        '''Construct a list of attributes for the PLY file format. This
        corresponds to the attributes used by online readers, such as
        https://niujinshuchong.github.io/mip-splatting-demo/index.html

        `num_rest` is the number of higher-order harmonic coefficients
        (`f_rest_*` fields) to declare in addition to the three `f_dc_*`.'''
        attributes = ["x", "y", "z", "nx", "ny", "nz"]
        attributes += [f"f_dc_{i}" for i in range(3)]
        attributes += [f"f_rest_{i}" for i in range(num_rest)]
        attributes.append("opacity")
        attributes += [f"scale_{i}" for i in range(3)]
        attributes += [f"rot_{i}" for i in range(4)]
        return attributes

    def covariance_to_quaternion_and_scale(covariance):
        '''Convert a batch of 3x3 covariance matrices to four dimensional
        quaternions (scalar-first: w, x, y, z) and three dimensional scale
        vectors. Both are returned as numpy arrays.'''
        covariance = covariance.detach()

        # A covariance matrix is symmetric positive semi-definite, i.e.
        # covariance = R diag(s^2) R^T. Its singular value decomposition
        # therefore yields the principal axes in U and the squared scales in S.
        # (torch.linalg.svd returns Vh as its third output, which is not needed.)
        U, S, _ = torch.linalg.svd(covariance)

        # The scale factors are the square roots of the eigenvalues
        scale = torch.sqrt(S).cpu().numpy()

        # U is orthogonal but may be a reflection (det = -1). Negating one
        # axis leaves R diag(s^2) R^T unchanged and makes R a proper rotation,
        # which is what the quaternion conversion requires.
        det = torch.linalg.det(U)
        axis_sign = torch.ones_like(S)
        axis_sign[..., -1] = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
        rotation_matrix_np = (U * axis_sign[..., None, :]).cpu().numpy()

        # Use scipy to convert the rotation matrix to a quaternion. Scipy
        # returns scalar-last (x, y, z, w); rot_0..rot_3 in Gaussian-splat PLY
        # files are scalar-first (w, x, y, z), so reorder.
        quaternion = Rotation.from_matrix(rotation_matrix_np).as_quat()
        quaternion = quaternion[..., [3, 0, 1, 2]]

        return quaternion, scale

    # Collect the Gaussian parameters of both views along a new "view" axis
    means = torch.stack([pred1["pts3d"], pred2["pts3d_in_other_view"]], dim=1)
    covariances = torch.stack([pred1["covariances"], pred2["covariances"]], dim=1)
    harmonics = torch.stack([pred1["sh"], pred2["sh"]], dim=1)[..., 0]  # Only use the first harmonic
    opacities = torch.stack([pred1["opacities"], pred2["opacities"]], dim=1)

    # Take the first batch element and flatten (view, h, w) into one axis
    means = einops.rearrange(means[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy()
    covariances = einops.rearrange(covariances[0], "v h w i j -> (v h w) i j")
    harmonics = einops.rearrange(harmonics[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy()
    opacities = einops.rearrange(opacities[0], "view h w xyz -> (view h w) xyz").detach().cpu().numpy()

    # Convert the covariance matrices to quaternions and scales
    rotations, scales = covariance_to_quaternion_and_scale(covariances)

    # Construct the attributes. The normals (nx, ny, nz) are unused and left
    # at zero; scales are stored in log-space.
    # NOTE(review): opacities are written as-is; some Gaussian-splat readers
    # expect pre-sigmoid (logit) opacity -- confirm against the target viewer.
    normals = np.zeros_like(means)
    attributes = np.concatenate((means, normals, harmonics, opacities, np.log(scales), rotations), axis=-1)
    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
    elements = np.empty(attributes.shape[0], dtype=dtype_full)
    elements[:] = list(map(tuple, attributes))

    # Save the point cloud
    point_cloud = PlyElement.describe(elements, "vertex")
    scene = PlyData([point_cloud])
    scene.write(save_path)
def save_3d(view1, view2, pred1, pred2, save_dir, as_pointcloud=True, all_points=True):
    """Save the 3D points as a point cloud or as a mesh. Adapted from DUSt3R.

    One file per batch element is written into ``save_dir``: ``{b}.ply`` when
    ``as_pointcloud`` is true, ``{b}.glb`` otherwise.

    Args:
        view1: First input view; "original_img" and "valid_mask" are read.
        view2: Second input view; "original_img" and "valid_mask" are read.
        pred1: Prediction dict for the first view; "pts3d" is read.
        pred2: Prediction dict for the second view; "pts3d_in_other_view" is read.
        save_dir: Output directory (created if missing).
        as_pointcloud: Export a coloured point cloud instead of a mesh.
        all_points: Ignore the valid masks and export every pixel.
    """
    os.makedirs(save_dir, exist_ok=True)

    views = [view1, view2]
    batch_size = pred1["pts3d"].shape[0]

    for b in range(batch_size):

        # detach() first: the predictions may still be attached to the autograd
        # graph, and Tensor.numpy() refuses tensors that require grad.
        pts3d = [
            pred1["pts3d"][b].detach().cpu().numpy(),
            pred2["pts3d_in_other_view"][b].detach().cpu().numpy(),
        ]
        imgs = [
            einops.rearrange(view["original_img"][b], "c h w -> h w c").detach().cpu().numpy()
            for view in views
        ]
        masks = [view["valid_mask"][b].detach().cpu().numpy() for view in views]

        # Treat all pixels as valid, because we want to render the entire viewpoint
        if all_points:
            masks = [np.ones_like(m) for m in masks]

        # Construct the scene from the 3D points as a point cloud or as a mesh
        scene = trimesh.Scene()
        if as_pointcloud:
            pts = np.concatenate([p[m] for p, m in zip(pts3d, masks)])
            col = np.concatenate([img[m] for img, m in zip(imgs, masks)])
            pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
            scene.add_geometry(pct)
            save_path = os.path.join(save_dir, f"{b}.ply")
        else:
            meshes = [
                pts3d_to_trimesh(img, pts, m)
                for img, pts, m in zip(imgs, pts3d, masks)
            ]
            mesh = trimesh.Trimesh(**cat_meshes(meshes))
            scene.add_geometry(mesh)
            save_path = os.path.join(save_dir, f"{b}.glb")

        # Save the scene
        scene.export(file_obj=save_path)
def log_batch_files(batch, color, depth, mask, view1, view2, pred1, pred2, save_dir, should_save_3d=False, max_samples=4):
    '''Save all the relevant debug files for a batch.

    Writes into `save_dir` (created if missing):
      * `gaussians.ply` with the predicted Gaussians (via `save_as_ply`),
      * optionally the 3D mesh / point cloud (via `save_3d`),
      * for up to `max_samples` batch elements: the context images, original
        images, depth maps and valid masks; the target original images and
        depth maps; the rendered colour (and depth, if given); the loss
        masks; and the masked target / rendered images.

    Args:
        batch: Batch dict with "context" and "target" lists of views. Context
            views provide "img", "original_img", "depthmap" and "valid_mask";
            target views provide "original_img" and "depthmap".
        color: Rendered colour images, indexed by batch element first.
        depth: Rendered depth maps, or None to skip saving them.
        mask: Loss masks, indexed by batch element first.
        view1, view2: The two context views (forwarded to `save_3d`).
        pred1, pred2: Model predictions for the two context views.
        save_dir: Output directory.
        should_save_3d: Also export the 3D mesh and point cloud.
        max_samples: Maximum number of batch elements to write images for.
    '''
    os.makedirs(save_dir, exist_ok=True)

    def num_samples(tensor):
        # Cap the number of per-sample files written for this tensor
        return min(tensor.shape[0], max_samples)

    def out_path(name):
        return os.path.join(save_dir, name)

    def stack_views(split, key):
        # Stack one entry of every view of a split along a new view axis (dim 1)
        return torch.stack([view[key] for view in batch[split]], dim=1)

    # Save the 3D Gaussians as a .ply file
    save_as_ply(pred1, pred2, out_path("gaussians.ply"))

    # Save the 3D points as a point cloud and as a mesh (disabled)
    if should_save_3d:
        save_3d(view1, view2, pred1, pred2, out_path("3d_mesh"), as_pointcloud=False)
        save_3d(view1, view2, pred1, pred2, out_path("3d_pointcloud"), as_pointcloud=True)

    # Save the color, depth and valid masks for the input context images
    context_images = stack_views("context", "img")
    context_original_images = stack_views("context", "original_img")
    context_depthmaps = stack_views("context", "depthmap")
    context_valid_masks = stack_views("context", "valid_mask")
    for b in range(num_samples(context_images)):
        torchvision.utils.save_image(context_images[b], out_path(f"sample_{b}_img_context.jpg"))
        torchvision.utils.save_image(context_original_images[b], out_path(f"sample_{b}_original_img_context.jpg"))
        torchvision.utils.save_image(context_depthmaps[b, :, None, ...], out_path(f"sample_{b}_depthmap.jpg"), normalize=True)
        torchvision.utils.save_image(context_valid_masks[b, :, None, ...].float(), out_path(f"sample_{b}_valid_mask_context.jpg"), normalize=True)

    # Save the color and depth images for the target images
    target_original_images = stack_views("target", "original_img")
    target_depthmaps = stack_views("target", "depthmap")
    for b in range(num_samples(target_original_images)):
        torchvision.utils.save_image(target_original_images[b], out_path(f"sample_{b}_original_img_target.jpg"))
        torchvision.utils.save_image(target_depthmaps[b, :, None, ...], out_path(f"sample_{b}_depthmap_target.jpg"), normalize=True)

    # Save the rendered images and depths
    for b in range(num_samples(color)):
        torchvision.utils.save_image(color[b, ...], out_path(f"sample_{b}_rendered_color.jpg"))
    if depth is not None:
        for b in range(num_samples(depth)):
            torchvision.utils.save_image(depth[b, :, None, ...], out_path(f"sample_{b}_rendered_depth.jpg"), normalize=True)

    # Save the loss masks
    for b in range(num_samples(mask)):
        torchvision.utils.save_image(mask[b, :, None, ...].float(), out_path(f"sample_{b}_loss_mask.jpg"), normalize=True)

    # Save the masked target and rendered images. A channel axis is inserted
    # into the mask so that it broadcasts over the colour channels.
    channel_mask = mask[..., None, :, :]
    masked_target_original_images = target_original_images * channel_mask
    masked_predictions = color * channel_mask
    for b in range(num_samples(target_original_images)):
        torchvision.utils.save_image(masked_target_original_images[b], out_path(f"sample_{b}_masked_original_img_target.jpg"))
        torchvision.utils.save_image(masked_predictions[b], out_path(f"sample_{b}_masked_rendered_color.jpg"))