jadechoghari
commited on
Create ray_sampler_part.py
Browse files- ray_sampler_part.py +94 -0
ray_sampler_part.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
7 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
8 |
+
#
|
9 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
10 |
+
# property and proprietary rights in and to this material, related
|
11 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
12 |
+
# disclosure or distribution of this material and related documentation
|
13 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
14 |
+
# its affiliates is strictly prohibited.
|
15 |
+
#
|
16 |
+
# Modified by Zexin He
|
17 |
+
# The modifications are subject to the same license as the original.
|
18 |
+
|
19 |
+
|
20 |
+
"""
|
21 |
+
The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
|
22 |
+
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import torch
|
26 |
+
|
27 |
+
class RaySampler(torch.nn.Module):
|
28 |
+
def __init__(self):
|
29 |
+
super().__init__()
|
30 |
+
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
|
31 |
+
|
32 |
+
|
33 |
+
def forward(self, cam2world_matrix, intrinsics, render_size, crop_size, start_x, start_y):
|
34 |
+
"""
|
35 |
+
Create batches of rays and return origins and directions.
|
36 |
+
|
37 |
+
cam2world_matrix: (N, 4, 4)
|
38 |
+
intrinsics: (N, 3, 3)
|
39 |
+
render_size: int
|
40 |
+
|
41 |
+
ray_origins: (N, M, 3)
|
42 |
+
ray_dirs: (N, M, 2)
|
43 |
+
"""
|
44 |
+
|
45 |
+
N, M = cam2world_matrix.shape[0], crop_size**2
|
46 |
+
cam_locs_world = cam2world_matrix[:, :3, 3]
|
47 |
+
fx = intrinsics[:, 0, 0]
|
48 |
+
fy = intrinsics[:, 1, 1]
|
49 |
+
cx = intrinsics[:, 0, 2]
|
50 |
+
cy = intrinsics[:, 1, 2]
|
51 |
+
sk = intrinsics[:, 0, 1]
|
52 |
+
|
53 |
+
uv = torch.stack(torch.meshgrid(
|
54 |
+
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
55 |
+
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
56 |
+
indexing='ij',
|
57 |
+
))
|
58 |
+
if crop_size < render_size:
|
59 |
+
patch_uv = []
|
60 |
+
for i in range(cam2world_matrix.shape[0]):
|
61 |
+
patch_uv.append(uv.clone()[None, :, start_y:start_y+crop_size, start_x:start_x+crop_size])
|
62 |
+
uv = torch.cat(patch_uv, 0)
|
63 |
+
uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
|
64 |
+
else:
|
65 |
+
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
|
66 |
+
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
|
67 |
+
# uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
|
68 |
+
# uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
|
69 |
+
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
|
70 |
+
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
|
71 |
+
z_cam = torch.ones((N, M), device=cam2world_matrix.device)
|
72 |
+
|
73 |
+
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
|
74 |
+
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
|
75 |
+
|
76 |
+
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).float()
|
77 |
+
|
78 |
+
_opencv2blender = torch.tensor([
|
79 |
+
[1, 0, 0, 0],
|
80 |
+
[0, -1, 0, 0],
|
81 |
+
[0, 0, -1, 0],
|
82 |
+
[0, 0, 0, 1],
|
83 |
+
], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
|
84 |
+
|
85 |
+
# added float here
|
86 |
+
cam2world_matrix = torch.bmm(cam2world_matrix.float(), _opencv2blender.float())
|
87 |
+
|
88 |
+
world_rel_points = torch.bmm(cam2world_matrix.float(), cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
|
89 |
+
|
90 |
+
ray_dirs = world_rel_points - cam_locs_world[:, None, :]
|
91 |
+
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
|
92 |
+
|
93 |
+
ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
|
94 |
+
return ray_origins, ray_dirs
|