jadechoghari
commited on
update renderer,
Browse fileswe include math utils functions here to avoid import issues
- renderer.py +103 -2
renderer.py
CHANGED
@@ -29,6 +29,104 @@ import torch.nn.functional as F
|
|
29 |
from .ray_marcher import MipRayMarcher2
|
30 |
from . import math_utils
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def generate_planes():
|
33 |
"""
|
34 |
Defines planes by the three vectors that form the "axes" of the
|
@@ -47,6 +145,7 @@ def generate_planes():
|
|
47 |
[0, 1, 0],
|
48 |
[1, 0, 0]]], dtype=torch.float32)
|
49 |
|
|
|
50 |
def project_onto_planes(planes, coordinates):
|
51 |
"""
|
52 |
Does a projection of a 3D point onto a batch of 2D planes,
|
@@ -64,6 +163,7 @@ def project_onto_planes(planes, coordinates):
|
|
64 |
projections = torch.bmm(coordinates, inv_planes)
|
65 |
return projections[..., :2]
|
66 |
|
|
|
67 |
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
68 |
assert padding_mode == 'zeros'
|
69 |
N, n_planes, C, H, W = plane_features.shape
|
@@ -77,6 +177,7 @@ def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear',
|
|
77 |
output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
78 |
return output_features
|
79 |
|
|
|
80 |
def sample_from_3dgrid(grid, coordinates):
|
81 |
"""
|
82 |
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
|
@@ -156,7 +257,7 @@ class ImportanceRenderer(torch.nn.Module):
|
|
156 |
# self.plane_axes = self.plane_axes.to(ray_origins.device)
|
157 |
|
158 |
if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
|
159 |
-
ray_start, ray_end =
|
160 |
is_ray_valid = ray_end > ray_start
|
161 |
if torch.any(is_ray_valid).item():
|
162 |
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
|
@@ -242,7 +343,7 @@ class ImportanceRenderer(torch.nn.Module):
|
|
242 |
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
|
243 |
else:
|
244 |
if type(ray_start) == torch.Tensor:
|
245 |
-
depths_coarse =
|
246 |
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
|
247 |
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
|
248 |
else:
|
|
|
29 |
from .ray_marcher import MipRayMarcher2
|
30 |
from . import math_utils
|
31 |
|
32 |
+
# Copied from .math_utils.transform_vectors
|
33 |
+
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
|
34 |
+
"""
|
35 |
+
Left-multiplies MxM @ NxM. Returns NxM.
|
36 |
+
"""
|
37 |
+
res = torch.matmul(vectors4, matrix.T)
|
38 |
+
return res
|
39 |
+
|
40 |
+
# Copied from .math_utils.normalize_vecs
|
41 |
+
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
|
42 |
+
"""
|
43 |
+
Normalize vector lengths.
|
44 |
+
"""
|
45 |
+
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
|
46 |
+
|
47 |
+
# Copied from .math_utils.torch_dot
|
48 |
+
def torch_dot(x: torch.Tensor, y: torch.Tensor):
|
49 |
+
"""
|
50 |
+
Dot product of two tensors.
|
51 |
+
"""
|
52 |
+
return (x * y).sum(-1)
|
53 |
+
|
54 |
+
# Copied from .math_utils.get_ray_limits_box
|
55 |
+
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
|
56 |
+
"""
|
57 |
+
Author: Petr Kellnhofer
|
58 |
+
Intersects rays with the [-1, 1] NDC volume.
|
59 |
+
Returns min and max distance of entry.
|
60 |
+
Returns -1 for no intersection.
|
61 |
+
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
|
62 |
+
"""
|
63 |
+
o_shape = rays_o.shape
|
64 |
+
rays_o = rays_o.detach().reshape(-1, 3)
|
65 |
+
rays_d = rays_d.detach().reshape(-1, 3)
|
66 |
+
|
67 |
+
|
68 |
+
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
|
69 |
+
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
|
70 |
+
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
|
71 |
+
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
|
72 |
+
|
73 |
+
# Precompute inverse for stability.
|
74 |
+
invdir = 1 / rays_d
|
75 |
+
sign = (invdir < 0).long()
|
76 |
+
|
77 |
+
# Intersect with YZ plane.
|
78 |
+
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
79 |
+
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
80 |
+
|
81 |
+
# Intersect with XZ plane.
|
82 |
+
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
83 |
+
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
84 |
+
|
85 |
+
# Resolve parallel rays.
|
86 |
+
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
|
87 |
+
|
88 |
+
# Use the shortest intersection.
|
89 |
+
tmin = torch.max(tmin, tymin)
|
90 |
+
tmax = torch.min(tmax, tymax)
|
91 |
+
|
92 |
+
# Intersect with XY plane.
|
93 |
+
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
94 |
+
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
95 |
+
|
96 |
+
# Resolve parallel rays.
|
97 |
+
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
|
98 |
+
|
99 |
+
# Use the shortest intersection.
|
100 |
+
tmin = torch.max(tmin, tzmin)
|
101 |
+
tmax = torch.min(tmax, tzmax)
|
102 |
+
|
103 |
+
# Mark invalid.
|
104 |
+
tmin[torch.logical_not(is_valid)] = -1
|
105 |
+
tmax[torch.logical_not(is_valid)] = -2
|
106 |
+
|
107 |
+
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
108 |
+
|
109 |
+
# Copied from .math_utils.linspace
|
110 |
+
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
|
111 |
+
"""
|
112 |
+
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
|
113 |
+
Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
|
114 |
+
"""
|
115 |
+
# create a tensor of 'num' steps from 0 to 1
|
116 |
+
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
|
117 |
+
|
118 |
+
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
|
119 |
+
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
|
120 |
+
# "cannot statically infer the expected size of a list in this contex", hence the code below
|
121 |
+
for i in range(start.ndim):
|
122 |
+
steps = steps.unsqueeze(-1)
|
123 |
+
|
124 |
+
# the output starts at 'start' and increments until 'stop' in each dimension
|
125 |
+
out = start[None] + steps * (stop - start)[None]
|
126 |
+
|
127 |
+
return out
|
128 |
+
|
129 |
+
# Copied from .math_utils.generate_planes
|
130 |
def generate_planes():
|
131 |
"""
|
132 |
Defines planes by the three vectors that form the "axes" of the
|
|
|
145 |
[0, 1, 0],
|
146 |
[1, 0, 0]]], dtype=torch.float32)
|
147 |
|
148 |
+
# Copied from .math_utils.project_onto_planes
|
149 |
def project_onto_planes(planes, coordinates):
|
150 |
"""
|
151 |
Does a projection of a 3D point onto a batch of 2D planes,
|
|
|
163 |
projections = torch.bmm(coordinates, inv_planes)
|
164 |
return projections[..., :2]
|
165 |
|
166 |
+
# Copied from .math_utils.sample_from_planes
|
167 |
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
168 |
assert padding_mode == 'zeros'
|
169 |
N, n_planes, C, H, W = plane_features.shape
|
|
|
177 |
output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
178 |
return output_features
|
179 |
|
180 |
+
# Copied from .math_utils.sample_from_3dgrid
|
181 |
def sample_from_3dgrid(grid, coordinates):
|
182 |
"""
|
183 |
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
|
|
|
257 |
# self.plane_axes = self.plane_axes.to(ray_origins.device)
|
258 |
|
259 |
if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
|
260 |
+
ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
|
261 |
is_ray_valid = ray_end > ray_start
|
262 |
if torch.any(is_ray_valid).item():
|
263 |
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
|
|
|
343 |
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
|
344 |
else:
|
345 |
if type(ray_start) == torch.Tensor:
|
346 |
+
depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
|
347 |
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
|
348 |
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
|
349 |
else:
|