jadechoghari commited on
Commit
7b3f3a7
·
verified ·
1 Parent(s): a99ee1d

Create ray_sampler_part.py

Browse files
Files changed (1) hide show
  1. ray_sampler_part.py +94 -0
ray_sampler_part.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
+ #
9
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
+ # property and proprietary rights in and to this material, related
11
+ # documentation and any modifications thereto. Any use, reproduction,
12
+ # disclosure or distribution of this material and related documentation
13
+ # without an express license agreement from NVIDIA CORPORATION or
14
+ # its affiliates is strictly prohibited.
15
+ #
16
+ # Modified by Zexin He
17
+ # The modifications are subject to the same license as the original.
18
+
19
+
20
+ """
21
+ The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
22
+ Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
23
+ """
24
+
25
+ import torch
26
+
27
+ class RaySampler(torch.nn.Module):
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
31
+
32
+
33
+ def forward(self, cam2world_matrix, intrinsics, render_size, crop_size, start_x, start_y):
34
+ """
35
+ Create batches of rays and return origins and directions.
36
+
37
+ cam2world_matrix: (N, 4, 4)
38
+ intrinsics: (N, 3, 3)
39
+ render_size: int
40
+
41
+ ray_origins: (N, M, 3)
42
+ ray_dirs: (N, M, 2)
43
+ """
44
+
45
+ N, M = cam2world_matrix.shape[0], crop_size**2
46
+ cam_locs_world = cam2world_matrix[:, :3, 3]
47
+ fx = intrinsics[:, 0, 0]
48
+ fy = intrinsics[:, 1, 1]
49
+ cx = intrinsics[:, 0, 2]
50
+ cy = intrinsics[:, 1, 2]
51
+ sk = intrinsics[:, 0, 1]
52
+
53
+ uv = torch.stack(torch.meshgrid(
54
+ torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
55
+ torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
56
+ indexing='ij',
57
+ ))
58
+ if crop_size < render_size:
59
+ patch_uv = []
60
+ for i in range(cam2world_matrix.shape[0]):
61
+ patch_uv.append(uv.clone()[None, :, start_y:start_y+crop_size, start_x:start_x+crop_size])
62
+ uv = torch.cat(patch_uv, 0)
63
+ uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
64
+ else:
65
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
66
+ uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
67
+ # uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
68
+ # uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
69
+ x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
70
+ y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
71
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
72
+
73
+ x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
74
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
75
+
76
+ cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).float()
77
+
78
+ _opencv2blender = torch.tensor([
79
+ [1, 0, 0, 0],
80
+ [0, -1, 0, 0],
81
+ [0, 0, -1, 0],
82
+ [0, 0, 0, 1],
83
+ ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
84
+
85
+ # added float here
86
+ cam2world_matrix = torch.bmm(cam2world_matrix.float(), _opencv2blender.float())
87
+
88
+ world_rel_points = torch.bmm(cam2world_matrix.float(), cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
89
+
90
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
91
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
92
+
93
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
94
+ return ray_origins, ray_dirs