jadechoghari committed on
Commit
cd769b1
·
verified ·
1 Parent(s): 82e91bf

Create synthesizer_part.py

Browse files
Files changed (1) hide show
  1. synthesizer_part.py +194 -0
synthesizer_part.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import itertools
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .renderer import ImportanceRenderer
13
+ from .ray_sampler_part import RaySampler
14
+
15
+
16
class OSGDecoder(nn.Module):
    """
    MLP head that turns sampled triplane features into RGB and sigma.

    The activation defaults to ReLU, whereas the original implementation
    referenced below uses Softplus.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        """
        Args:
            n_features: channels per plane; the first linear layer takes
                ``3 * n_features`` inputs (three planes concatenated).
            hidden_dim: width of every hidden layer.
            num_layers: total number of linear layers (input, hidden, output).
            activation: zero-argument factory for the activation module
                inserted after every linear layer except the last.
        """
        super().__init__()

        # Input projection, then (num_layers - 2) hidden blocks, then the
        # 4-channel output head (1 sigma + 3 rgb).
        layers = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        layers.append(nn.Linear(hidden_dim, 1 + 3))
        self.net = nn.Sequential(*layers)

        # Zero every linear bias; weights keep their default initialisation.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.zeros_(module.bias)

    def forward(self, sampled_features, ray_directions):
        """
        Decode per-point features into colour and density.

        Args:
            sampled_features: 4-D tensor unpacked as
                (batch, n_planes, n_points, channels).
            ray_directions: accepted for interface compatibility; not read
                by this decoder.

        Returns:
            dict with ``'rgb'`` of shape (batch, n_points, 3) and ``'sigma'``
            of shape (batch, n_points, 1).
        """
        batch, n_planes, n_points, channels = sampled_features.shape

        # Aggregate the planes by concatenating their channels per point,
        # then flatten batch and points so the MLP sees one row per sample.
        per_point = sampled_features.permute(0, 2, 1, 3)
        flat = per_point.reshape(batch * n_points, n_planes * channels)

        decoded = self.net(flat).view(batch, n_points, -1)

        # Channel 0 is the raw density; the remaining channels are colour.
        sigma = decoded[..., 0:1]
        # Uses sigmoid clamping from MipNeRF: stretch the sigmoid slightly
        # past [0, 1] so the extremes are reachable.
        rgb = torch.sigmoid(decoded[..., 1:])*(1 + 2*0.001) - 0.001

        return {'rgb': rgb, 'sigma': sigma}
58
+
59
+
60
class TriplaneSynthesizer(nn.Module):
    """
    Synthesizer that renders a triplane volume with planes and a camera.

    Three entry points are provided:
    - ``forward``: volume-render image crops for a batch of cameras.
    - ``forward_grid``: query the decoder on a regular 3D grid.
    - ``forward_points``: query the decoder at arbitrary 3D points.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    # Baseline options handed to the renderer; per-instance depth resolutions
    # are merged on top of these in __init__.
    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 2.,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'sampler_bbox_min': -1.,
        'sampler_bbox_max': 1.,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int):
        """
        Args:
            triplane_dim: channels per plane, forwarded to the decoder as
                ``n_features``.
            samples_per_ray: sample budget per ray; half is assigned to
                ``depth_resolution`` and half to
                ``depth_resolution_importance``.
        """
        super().__init__()

        # attributes
        self.triplane_dim = triplane_dim
        self.rendering_kwargs = {
            **self.DEFAULT_RENDERING_KWARGS,
            'depth_resolution': samples_per_ray // 2,
            'depth_resolution_importance': samples_per_ray // 2,
        }

        # renderings
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()

        # modules
        self.decoder = OSGDecoder(n_features=triplane_dim)

    def forward(self, planes, cameras, render_size: int, crop_size: int, start_x: int, start_y:int):
        """
        Volume-render a square crop for every camera.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            cameras: (N, M, D_cam); the first 16 entries of the last axis are
                read as a flattened 4x4 cam2world matrix and entries 16:25 as
                a flattened 3x3 intrinsics matrix.
            render_size: forwarded to the ray sampler.
            crop_size: side length of the rendered crop; also the output
                height and width.
            start_x: forwarded to the ray sampler.
            start_y: forwarded to the ray sampler.

        Returns:
            dict with ``'images_rgb'`` (N, M, C, crop_size, crop_size),
            ``'images_depth'`` and ``'images_weight'``
            (both N, M, 1, crop_size, crop_size).
        """
        # planes: (N, 3, D', H', W')
        # cameras: (N, M, D_cam)
        # render_size: int
        assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
        N, M = cameras.shape[:2]
        cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
        intrinsics = cameras[..., 16:25].view(N, M, 3, 3)

        # Create a batch of rays for volume rendering; batch and view axes
        # are merged so the sampler sees N*M independent cameras.
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
            intrinsics=intrinsics.reshape(-1, 3, 3),
            render_size=render_size,
            crop_size=crop_size,
            start_x=start_x,
            start_y=start_y,
        )
        assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"

        # Perform volume rendering; each plane set is repeated M times so it
        # lines up with the N*M flattened cameras.
        rgb_samples, depth_samples, weights_samples = self.renderer(
            planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
        )

        # Reshape into 'raw' neural-rendered image
        Himg = Wimg = crop_size
        rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
        depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
        weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)

        return {
            'images_rgb': rgb_images,
            'images_depth': depth_images,
            'images_weight': weight_images,
        }

    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
        """
        Query the decoder on a regular ``grid_size``^3 grid inside a box.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            grid_size: number of samples along each axis.
            aabb: (N, 2, 3) per-item box as [min_xyz, max_xyz]. When omitted,
                the cube spanned by ``sampler_bbox_min`` / ``sampler_bbox_max``
                from the rendering options is used for every item.

        Returns:
            dict with the same keys as ``forward_points``, each reshaped to
            (N, grid_size, grid_size, grid_size, -1).
        """
        # planes: (N, 3, D', H', W')
        # grid_size: int
        # aabb: (N, 2, 3)
        if aabb is None:
            aabb = torch.tensor([
                [self.rendering_kwargs['sampler_bbox_min']] * 3,
                [self.rendering_kwargs['sampler_bbox_max']] * 3,
            ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]

        # create grid points for triplane query; 'ij' indexing keeps the
        # first grid axis aligned with x, the second with y, the third with z
        grid_points = []
        for i in range(N):
            grid_points.append(torch.stack(torch.meshgrid(
                torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
                indexing='ij',
            ), dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)

        features = self.forward_points(planes, cube_grid)

        # reshape into grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }
        return features

    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
        """
        Query the decoder at arbitrary 3D points, processed in chunks.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            points: (N, P, 3) query coordinates.
            chunk_size: maximum number of points (along axis 1) sent to the
                renderer per call, bounding peak memory.

        Returns:
            dict mapping each key returned by
            ``renderer.run_model_activated`` to the per-chunk results
            concatenated along axis 1.
        """
        # planes: (N, 3, D', H', W')
        # points: (N, P, 3)

        # query triplane in chunks
        outs = []
        for i in range(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i+chunk_size]

            # query triplane; directions are passed as zeros for point queries
            chunk_out = self.renderer.run_model_activated(
                planes=planes,
                decoder=self.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)

        # concatenate the outputs
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

        # No statistics are printed here: a library query must not write to
        # stdout or trigger extra reductions over the full sigma tensor.
        return point_features