# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations


class TestRaysampler(unittest.TestCase):
    """
    Tests of `NeRFRaysampler` and `ProbabilisticRaysampler`.
    """

    def setUp(self) -> None:
        # Seed the global torch RNG so the randomly drawn cameras and
        # weights are the same on every run.
        torch.manual_seed(42)

    @staticmethod
    def _random_cameras(n_cameras):
        """
        Build a `PerspectiveCameras` batch of `n_cameras` random cameras.

        The random tensors are drawn in the order rotation, translation,
        focal length, principal point. Keep this order: it determines which
        values each tensor receives for a given seed.
        """
        R = random_rotations(n_cameras)
        T = torch.randn(n_cameras, 3)
        # Shift the focal length by 0.5, i.e. into the range [0.5, 1.5).
        focal_length = torch.rand(n_cameras, 2) + 0.5
        principal_point = torch.randn(n_cameras, 2)
        return PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

    def test_raysampler_caching(self, batch_size=10):
        """
        Check that the NeRF raysampler is consistent with its ray cache.

        Rays are first sampled with a plain call for each of `batch_size`
        single-camera batches. After `precache_rays` has been run on the same
        cameras, each camera is queried again through its `camera_hash` and
        every element of the result has to be close to the reference rays.
        """
        raysampler = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=10,
            min_depth=0.1,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )
        raysampler.eval()

        # Sample the reference rays right after creating each camera so that
        # camera creation and ray sampling stay interleaved.
        cameras = []
        reference_rays = []
        for _ in range(batch_size):
            camera = self._random_cameras(1)
            cameras.append(camera)
            reference_rays.append(raysampler(camera))

        # The position of a camera in the list serves as its hash.
        camera_hashes = list(range(batch_size))
        raysampler.precache_rays(cameras, camera_hashes)

        for camera_hash, (camera, expected) in enumerate(
            zip(cameras, reference_rays)
        ):
            cached = raysampler(
                cameras=camera,
                chunksize=None,
                chunk_idx=0,
                camera_hash=camera_hash,
                caching=False,
            )
            for value, value_cached in zip(expected, cached):
                self.assertTrue(torch.allclose(value, value_cached))

    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Smoke test: the probabilistic raysampler must run without raising
        for every combination of its stratification flags, in both train
        and eval mode.
        """
        raysampler_grid = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=1.0,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        camera = self._random_cameras(batch_size)

        # The grid sampler output, paired with random weights shaped like
        # the ray lengths, is the input of the probabilistic sampler.
        raysampler_grid.eval()
        ray_bundle = raysampler_grid(cameras=camera)
        ray_weights = torch.rand_like(ray_bundle.lengths)

        # Just check that we don't crash for all possible settings.
        flag_settings = [
            (stratified_test, stratified)
            for stratified_test in (True, False)
            for stratified in (True, False)
        ]
        for stratified_test, stratified in flag_settings:
            raysampler_prob = ProbabilisticRaysampler(
                n_pts_per_ray=n_pts_per_ray,
                stratified=stratified,
                stratified_test=stratified_test,
                add_input_samples=True,
            )
            for mode in ("train", "eval"):
                # Switch the module mode via its `train()` / `eval()` method.
                getattr(raysampler_prob, mode)()
                for _ in range(10):
                    raysampler_prob(ray_bundle, ray_weights)