jadechoghari
committed on
Create synthesizer_part.py
Browse files- synthesizer_part.py +194 -0
synthesizer_part.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import itertools
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from .renderer import ImportanceRenderer
|
13 |
+
from .ray_sampler_part import RaySampler
|
14 |
+
|
15 |
+
|
16 |
+
class OSGDecoder(nn.Module):
    """
    Triplane decoder that gives RGB and sigma values from sampled features.
    Using ReLU here instead of Softplus in the original implementation.

    The per-plane features of every point are concatenated (not averaged) and
    passed through an MLP made of ``num_layers`` linear layers. The last layer
    emits 4 channels: channel 0 is the raw sigma, channels 1-3 are RGB logits.

    Args:
        n_features: feature channels per plane. The MLP input width is
            ``3 * n_features``, so three planes are expected.
        hidden_dim: width of every hidden layer.
        num_layers: total number of ``nn.Linear`` layers. Values below 2 still
            build the two-layer minimum (input layer + output layer).
        activation: activation module class (called with no arguments); a
            fresh instance is placed after every non-final linear layer.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        # Layers are created strictly in network order so that seeded weight
        # initialisation and the Sequential indices (state-dict keys) are stable.
        layers = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        layers.append(nn.Linear(hidden_dim, 1 + 3))
        self.net = nn.Sequential(*layers)

        # init all bias to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, sampled_features, ray_directions):
        """
        Decode sampled triplane features into colour and density.

        Args:
            sampled_features: tensor of shape (N, n_planes, M, C).
            ray_directions: accepted for interface compatibility; not used.

        Returns:
            dict with ``'rgb'`` of shape (N, M, 3) and ``'sigma'`` of shape
            (N, M, 1). ``sigma`` is the raw network output (no activation).
        """
        # Aggregate features by concatenation: (N, n_planes, M, C) -> (N, M, n_planes*C)
        n_batch, n_planes, n_points, n_channels = sampled_features.shape
        x = sampled_features.permute(0, 2, 1, 3).reshape(
            n_batch, n_points, n_planes * n_channels)

        # nn.Linear acts on the last dimension, so no flatten/unflatten round
        # trip is needed. Skipping it also keeps empty inputs (M == 0) working,
        # where ``view(N, M, -1)`` cannot infer the trailing dimension.
        x = self.net(x)

        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
|
58 |
+
|
59 |
+
|
60 |
+
class TriplaneSynthesizer(nn.Module):
    """
    Synthesizer that renders a triplane volume with planes and a camera.

    Owns an ``ImportanceRenderer``, a ``RaySampler`` and an ``OSGDecoder`` and
    exposes three entry points: ``forward`` (image rendering),
    ``forward_grid`` (decoder queried on a regular 3D grid) and
    ``forward_points`` (decoder queried at arbitrary points).

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    # Defaults passed to the renderer through ``self.rendering_kwargs``.
    # ``sampler_bbox_min`` / ``sampler_bbox_max`` also define the default
    # query box used by ``forward_grid``.
    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 2.,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'sampler_bbox_min': -1.,
        'sampler_bbox_max': 1.,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int):
        """
        Args:
            triplane_dim: feature channels per plane (decoder ``n_features``).
            samples_per_ray: split evenly (integer division) between
                ``depth_resolution`` and ``depth_resolution_importance``.
        """
        super().__init__()

        # attributes
        self.triplane_dim = triplane_dim
        self.rendering_kwargs = {
            **self.DEFAULT_RENDERING_KWARGS,
            'depth_resolution': samples_per_ray // 2,
            'depth_resolution_importance': samples_per_ray // 2,
        }

        # renderings
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()

        # modules
        self.decoder = OSGDecoder(n_features=triplane_dim)

    def forward(self, planes, cameras, render_size: int, crop_size: int, start_x: int, start_y: int):
        """
        Volume-render one image crop per camera.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            cameras: (N, M, D_cam). The first 16 entries of the last axis are
                read as a 4x4 cam2world matrix and the next 9 as 3x3 intrinsics.
            render_size: forwarded to the ray sampler.
            crop_size: forwarded to the ray sampler; also the height and width
                of the returned images.
            start_x, start_y: forwarded to the ray sampler (presumably the crop
                origin inside the full render -- confirm in ``RaySampler``).

        Returns:
            dict with ``'images_rgb'`` (N, M, C, crop_size, crop_size), plus
            ``'images_depth'`` and ``'images_weight'``, both
            (N, M, 1, crop_size, crop_size).
        """
        assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
        N, M = cameras.shape[:2]
        cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
        intrinsics = cameras[..., 16:25].view(N, M, 3, 3)

        # Create a batch of rays for volume rendering (cameras flattened to N*M)
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
            intrinsics=intrinsics.reshape(-1, 3, 3),
            render_size=render_size,
            crop_size=crop_size,
            start_x=start_x,
            start_y=start_y,
        )
        assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"

        # Perform volume rendering; each plane set is repeated once per camera
        # so it lines up with the flattened (N*M) ray batch.
        rgb_samples, depth_samples, weights_samples = self.renderer(
            planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
        )

        # Reshape into 'raw' neural-rendered image
        Himg = Wimg = crop_size
        rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
        depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
        weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)

        return {
            'images_rgb': rgb_images,
            'images_depth': depth_images,
            'images_weight': weight_images,
        }

    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
        """
        Query the decoder on a regular ``grid_size``^3 lattice.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            grid_size: number of samples along each axis.
            aabb: (N, 2, 3) per-sample box as [min xyz, max xyz]. When ``None``
                the box spans ``sampler_bbox_min`` .. ``sampler_bbox_max`` on
                every axis.

        Returns:
            dict with the same keys as ``forward_points``, each value reshaped
            to (N, grid_size, grid_size, grid_size, -1).
        """
        if aabb is None:
            aabb = torch.tensor([
                [self.rendering_kwargs['sampler_bbox_min']] * 3,
                [self.rendering_kwargs['sampler_bbox_max']] * 3,
            ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]

        # create grid points for triplane query
        grid_points = []
        for i in range(N):
            axes = [
                torch.linspace(aabb[i, 0, axis], aabb[i, 1, axis], grid_size, device=planes.device)
                for axis in range(3)
            ]
            # 'ij' indexing keeps lattice index [a, b, c] <-> point (x_a, y_b, z_c)
            grid_points.append(
                torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1).reshape(-1, 3))
        # The axes are already created on planes.device, so no extra transfer is needed.
        cube_grid = torch.stack(grid_points, dim=0)

        features = self.forward_points(planes, cube_grid)

        # reshape into grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }
        return features

    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
        """
        Query the decoder at arbitrary 3D points, in chunks.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            points: (N, P, 3) query coordinates. Must contain at least one
                point, because the output keys are taken from the first chunk.
            chunk_size: maximum number of points per renderer call.

        Returns:
            dict mapping each key returned by
            ``renderer.run_model_activated`` to the per-chunk outputs
            concatenated along dim 1.
        """
        P = points.shape[1]

        # query triplane in chunks
        outs = []
        for i in range(0, P, chunk_size):
            chunk_points = points[:, i:i+chunk_size]

            # query triplane; directions are passed as all zeros
            chunk_out = self.renderer.run_model_activated(
                planes=planes,
                decoder=self.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)

        # concatenate the outputs
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

        return point_features
|