splatt3r / data /scannetpp /scannetpp.py
brandonsmart's picture
Initial commit
5ed9923
raw
history blame
No virus
6.9 kB
import json
import logging
import os
import sys
import cv2
import numpy as np
# Add dust3r to the sys.path
sys.path.append('src/dust3r_src')
from data.data import crop_resize_if_necessary, DUST3RSplattingDataset, DUST3RSplattingTestDataset
from src.mast3r_src.dust3r.dust3r.utils.image import imread_cv2
logger = logging.getLogger(__name__)
class ScanNetPPData():
    """Per-scene index of ScanNet++ DSLR frames for one dataset stage.

    On construction the split file for the requested stage is read and, for
    every scene in it, the undistorted image paths, the matching depth-map
    paths, the camera-to-world poses and one intrinsics matrix are collected.
    Individual frames are loaded lazily through :meth:`get_view`.

    Attributes:
        root: Dataset root; expected to contain ``raw/`` and ``processed/``.
        stage: The stage string passed to the constructor.
        png_depth_scale: Divisor applied to the stored depth PNG values
            (presumably millimetres -> metres; confirm against the
            preprocessing that wrote the depth maps).
        sequences: Scene ids that have at least one usable frame.
        color_paths / depth_paths / c2ws: ``scene id -> list`` with one entry
            per usable frame, index-aligned with each other.
        intrinsics: ``scene id -> 4x4 float32`` matrix holding fl_x, fl_y, cx
            and cy in its upper-left block.
    """

    def __init__(self, root, stage):
        """Scan the dataset on disk and build the per-scene lookup tables.

        Args:
            root: Dataset root directory.
            stage: ``"train"``, ``"val"`` or ``"test"``. ``"val"`` and
                ``"test"`` both read the validation split file.

        Raises:
            ValueError: If ``stage`` is not one of the supported values.
        """
        self.root = root
        self.stage = stage
        self.png_depth_scale = 1000.0

        # Dictionaries to store the data for each scene
        self.color_paths = {}
        self.depth_paths = {}
        self.intrinsics = {}
        self.c2ws = {}

        # Fetch the sequences to use
        if stage == "train":
            split_file_name = "nvs_sem_train.txt"
            bad_scenes = ['303745abc7']
        elif stage in ("val", "test"):
            split_file_name = "nvs_sem_val.txt"
            bad_scenes = ['cc5237fd77']
        else:
            # Fail early with a clear message instead of hitting an
            # UnboundLocalError when the split file is opened below.
            raise ValueError(
                f"Unknown stage {stage!r}; expected 'train', 'val' or 'test'"
            )
        sequence_file = os.path.join(self.root, "raw", "splits", split_file_name)
        with open(sequence_file, "r") as f:
            self.sequences = f.read().splitlines()

        # Remove scenes that have frames with no valid depths
        logger.info(f"Removing scenes that have frames with no valid depths: {bad_scenes}")
        self.sequences = [s for s in self.sequences if s not in bad_scenes]

        # Negates the Y and Z axes; applied on both sides of each pose below.
        # NOTE(review): presumably converts the nerfstudio (OpenGL-style)
        # camera/world convention to an OpenCV-style one - confirm.
        P = np.array([
            [1, 0, 0, 0],
            [0, -1, 0, 0],
            [0, 0, -1, 0],
            [0, 0, 0, 1]]
        ).astype(np.float32)

        # Collect information for every sequence
        scenes_with_no_good_frames = []
        for sequence in self.sequences:

            input_raw_folder = os.path.join(self.root, 'raw', 'data', sequence)
            input_processed_folder = os.path.join(self.root, 'processed', sequence)

            # Load Train & Test Splits
            train_test_list = self._load_json(
                os.path.join(input_raw_folder, "dslr", "train_test_lists.json")
            )

            # Camera Metadata
            cams_metadata = self._load_json(
                f"{input_processed_folder}/dslr/nerfstudio/transforms_undistorted.json"
            )

            # Load the nerfstudio/transforms.json file to check whether each image is blurry
            nerfstudio_transforms = self._load_json(
                f"{input_raw_folder}/dslr/nerfstudio/transforms.json"
            )

            # Create a reverse mapping from image name to the frame information and nerfstudio transform
            # (as transforms_undistorted.json does not store the frames in the same order as train_test_lists.json)
            file_path_to_frame_metadata = {
                frame["file_path"]: frame for frame in cams_metadata["frames"]
            }
            file_path_to_nerfstudio_transform = {
                frame["file_path"]: frame for frame in nerfstudio_transforms["frames"]
            }

            # Fetch the pose for every frame
            sequence_color_paths = []
            sequence_depth_paths = []
            sequence_c2ws = []
            for train_file_name in train_test_list["train"]:
                # Skip frames flagged as bad in the raw nerfstudio transforms
                if file_path_to_nerfstudio_transform[train_file_name]["is_bad"]:
                    continue
                sequence_color_paths.append(f"{input_processed_folder}/dslr/undistorted_images/{train_file_name}")
                sequence_depth_paths.append(f"{input_processed_folder}/dslr/undistorted_depths/{train_file_name.replace('.JPG', '.png')}")
                frame_metadata = file_path_to_frame_metadata[train_file_name]
                c2w = np.array(frame_metadata["transform_matrix"], dtype=np.float32)
                c2w = P @ c2w @ P.T
                sequence_c2ws.append(c2w)

            if not sequence_color_paths:
                logger.info(f"No good frames for sequence: {sequence}")
                scenes_with_no_good_frames.append(sequence)
                continue

            # Get the intrinsics data for the frame
            # (one matrix per scene, read from the top level of the metadata)
            K = np.eye(4, dtype=np.float32)
            K[0, 0] = cams_metadata["fl_x"]
            K[1, 1] = cams_metadata["fl_y"]
            K[0, 2] = cams_metadata["cx"]
            K[1, 2] = cams_metadata["cy"]

            self.color_paths[sequence] = sequence_color_paths
            self.depth_paths[sequence] = sequence_depth_paths
            self.c2ws[sequence] = sequence_c2ws
            self.intrinsics[sequence] = K

        # Remove scenes with no good frames
        empty_scenes = set(scenes_with_no_good_frames)
        self.sequences = [s for s in self.sequences if s not in empty_scenes]

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    def get_view(self, sequence, view_idx, resolution):
        """Load one frame of a scene and package it as a view dictionary.

        Args:
            sequence: Scene id; must be a key of the tables built in
                ``__init__``.
            view_idx: Index of the frame within that scene's lists.
            resolution: Forwarded unchanged to ``crop_resize_if_necessary``.

        Returns:
            A dict with the image, the depth map divided by
            ``png_depth_scale``, the camera-to-world pose, the intrinsics as
            returned by ``crop_resize_if_necessary``, identification labels,
            and ``sky_mask``, which is True where the depth is <= 0.
        """
        # RGB Image
        rgb_path = self.color_paths[sequence][view_idx]
        rgb_image = imread_cv2(rgb_path)

        # Depthmap
        depth_path = self.depth_paths[sequence][view_idx]
        depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
        depthmap = depthmap.astype(np.float32)
        depthmap = depthmap / self.png_depth_scale

        # C2W Pose
        c2w = self.c2ws[sequence][view_idx]

        # Camera Intrinsics
        intrinsics = self.intrinsics[sequence]

        # Resize
        rgb_image, depthmap, intrinsics = crop_resize_if_necessary(
            rgb_image, depthmap, intrinsics, resolution
        )

        view = {
            'original_img': rgb_image,
            'depthmap': depthmap,
            'camera_pose': c2w,
            'camera_intrinsics': intrinsics,
            'dataset': 'scannet++',
            'label': f"scannet++/{sequence}",
            'instance': f'{view_idx}',
            'is_metric_scale': True,
            # Pixels without a positive depth value
            'sky_mask': depthmap <= 0.0,
        }
        return view
def get_scannet_dataset(root, stage, resolution, num_epochs_per_epoch=1):
    """Build the ScanNet++ splatting dataset for the given stage.

    Creates a ``ScanNetPPData`` index for ``root``/``stage``, reads the
    per-scene coverage JSON from ``./data/scannetpp/coverage/`` for every
    scene in that index, and wraps both in a ``DUST3RSplattingDataset``.

    Args:
        root: Dataset root directory, forwarded to ``ScanNetPPData``.
        stage: Dataset stage, forwarded to ``ScanNetPPData``.
        resolution: Forwarded to ``DUST3RSplattingDataset``.
        num_epochs_per_epoch: Forwarded to ``DUST3RSplattingDataset``.

    Returns:
        The constructed ``DUST3RSplattingDataset``.
    """
    scene_index = ScanNetPPData(root, stage)

    # Each coverage file is keyed by its own scene id; keep only that entry.
    coverage_by_scene = {}
    for scene_id in scene_index.sequences:
        coverage_path = f'./data/scannetpp/coverage/{scene_id}.json'
        with open(coverage_path, 'r') as handle:
            coverage_by_scene[scene_id] = json.load(handle)[scene_id]

    return DUST3RSplattingDataset(
        scene_index,
        coverage_by_scene,
        resolution,
        num_epochs_per_epoch=num_epochs_per_epoch,
    )
def get_scannet_test_dataset(root, alpha, beta, resolution, use_every_n_sample=100):
    """Build the ScanNet++ test dataset from a precomputed sample list.

    The samples are read from ``data/scannetpp/test_set_{alpha}_{beta}.json``
    and thinned by keeping every ``use_every_n_sample``-th entry. The
    underlying scene index is always built for the ``'val'`` stage.

    Args:
        root: Dataset root directory, forwarded to ``ScanNetPPData``.
        alpha: First component of the sample file name.
        beta: Second component of the sample file name.
        resolution: Forwarded to ``DUST3RSplattingTestDataset``.
        use_every_n_sample: Stride used when subsampling the sample list.

    Returns:
        The constructed ``DUST3RSplattingTestDataset``.
    """
    scene_index = ScanNetPPData(root, 'val')

    samples_file = f'data/scannetpp/test_set_{alpha}_{beta}.json'
    print(f"Loading samples from: {samples_file}")
    with open(samples_file, 'r') as handle:
        all_samples = json.load(handle)

    # Keep one sample out of every `use_every_n_sample`
    kept_samples = all_samples[::use_every_n_sample]

    return DUST3RSplattingTestDataset(scene_index, kept_samples, resolution)