# gokaygokay's picture
# Upload 93 files
# 0a88b62 verified
# ORIGINAL LICENSE
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Modified by Zexin He
# The modifications are subject to the same license as the original.
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.renderer import ImportanceRenderer, sample_from_planes
from .utils.ray_sampler import RaySampler
from ...utils.ops import get_rank
class OSGDecoder(nn.Module):
    """
    Triplane decoder that gives RGB, SDF and normal values from sampled features.

    Using ReLU here instead of Softplus in the original implementation.
    The MLP maps the concatenated per-plane features of each sample point to
    4 channels: channel 0 is the SDF, channels 1-3 are (pre-activation) RGB.
    Normals are estimated with forward finite differences of the SDF.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self, n_features: int,
                 hidden_dim: int = 64,
                 num_layers: int = 2,
                 activation: nn.Module = nn.ReLU,
                 sdf_bias='sphere',
                 sdf_bias_params=0.5,
                 output_normal=True,
                 normal_type='finite_difference'):
        """
        Args:
            n_features: channels per plane; the MLP input width is 3 * n_features
                because the features of the three planes are concatenated.
            hidden_dim: width of the hidden layers.
            num_layers: total number of Linear layers; values below 2 still
                build the input and output layers (2 layers).
            activation: activation class, instantiated once per hidden layer.
            sdf_bias: analytic SDF prior used by `get_shifted_sdf`; only
                'sphere' is supported.
            sdf_bias_params: sphere radius for the 'sphere' prior.
            output_normal, normal_type: stored on the instance; not read by
                any method of this class.
        """
        super().__init__()
        self.sdf_bias = sdf_bias
        self.sdf_bias_params = sdf_bias_params
        self.output_normal = output_normal
        self.normal_type = normal_type

        hidden_layers = itertools.chain.from_iterable(
            (nn.Linear(hidden_dim, hidden_dim), activation())
            for _ in range(num_layers - 2)
        )
        self.net = nn.Sequential(
            nn.Linear(3 * n_features, hidden_dim),
            activation(),
            *hidden_layers,
            nn.Linear(hidden_dim, 1 + 3),
        )
        # init all bias to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def _decode(self, plane_axes, planes, coordinates, options):
        """Sample triplane features at `coordinates` and run the MLP.

        Returns the raw (pre-activation) MLP output reshaped to (N, M, 4),
        where N and M come from the sampled feature tensor (N, planes, M, C).
        """
        sampled_features = sample_from_planes(
            plane_axes, planes, coordinates,
            padding_mode='zeros', box_warp=options['box_warp'],
        )
        n, n_planes, m, c = sampled_features.shape
        # Aggregate features by concatenation: (N, P, M, C) -> (N, M, P*C).
        features = sampled_features.permute(0, 2, 1, 3).reshape(n, m, n_planes * c)
        return self.net(features).view(n, m, -1)

    def forward(self, ray_directions, sample_coordinates, plane_axes, planes, options):
        """Decode RGB, SDF, SDF gradient and normal at the sample points.

        Args:
            ray_directions: unused; kept for interface compatibility.
            sample_coordinates: query points, indexed below as (N, M, 3).
            plane_axes, planes: forwarded to `sample_from_planes`.
            options: reads 'box_warp', 'sampler_bbox_min' and 'sampler_bbox_max'.

        Returns:
            dict with 'rgb' (N, M, 3), 'sdf' (N, M, 1), 'normal' (N, M, 3,
            unit length) and 'sdf_grad' (N, M, 3).
        """
        x = self._decode(plane_axes, planes, sample_coordinates, options)
        # Uses sigmoid clamping from MipNeRF (range slightly wider than [0, 1]).
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sdf = x[..., 0:1]

        # Normal via forward finite differences along x, y and z. Offset points
        # are clamped to the sampler bounding box before being decoded.
        eps = 0.01
        offsets = torch.as_tensor(
            [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
        ).to(sample_coordinates)
        points_offset = (
            sample_coordinates[..., None, :] + offsets  # (N, M, 3, 3)
        ).clamp(options['sampler_bbox_min'], options['sampler_bbox_max'])
        sdf_offset = torch.stack(
            [
                self.forward_sdf(plane_axes, planes, points_offset[:, :, axis, :], options)
                for axis in range(points_offset.shape[-2])
            ],
            dim=-2,
        )  # (N, M, 3, 1)
        sdf_grad = (sdf_offset[..., 0] - sdf) / eps
        normal = F.normalize(sdf_grad, dim=-1).to(sdf.dtype)
        return {'rgb': rgb, 'sdf': sdf, 'normal': normal, 'sdf_grad': sdf_grad}

    def forward_sdf(self, plane_axes, planes, points_offset, options):
        """Return only the SDF channel, shape (N, M, 1), at `points_offset`."""
        return self._decode(plane_axes, planes, points_offset, options)[..., 0:1]

    def get_shifted_sdf(self, points, sdf):
        """Add the analytic SDF prior selected by `self.sdf_bias` to `sdf`.

        For 'sphere' the prior is ||points|| - radius (radius = sdf_bias_params).

        Raises:
            ValueError: if `sdf_bias` is not a supported mode.
            TypeError: if the sphere radius is not an int or float.
        """
        if self.sdf_bias != "sphere":
            raise ValueError(f"Unknown sdf bias {self.sdf_bias}")
        radius = self.sdf_bias_params
        if not isinstance(radius, (int, float)):
            raise TypeError(f"sphere sdf bias expects a numeric radius, got {radius!r}")
        sdf_bias = (points ** 2).sum(dim=-1, keepdim=True).sqrt() - radius
        return sdf + sdf_bias.to(sdf.dtype)
class TriplaneSynthesizer(nn.Module):
    """
    Synthesizer that renders a triplane volume with planes and a camera.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    DEFAULT_RENDERING_KWARGS = {
        'ray_start': 'auto',
        'ray_end': 'auto',
        'box_warp': 1.2,
        'white_back': True,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        # Same bound on every axis; also used by `forward_grid` as the default
        # extraction box.
        'sampler_bbox_min': -0.6,
        'sampler_bbox_max': 0.6,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int, osg_decoder='default'):
        """
        Args:
            triplane_dim: channels per triplane plane (decoder input features).
            samples_per_ray: value stored as 'depth_resolution' in rendering_kwargs.
            osg_decoder: decoder variant; only 'default' is supported.

        Raises:
            NotImplementedError: for any other `osg_decoder` value.
        """
        super().__init__()
        self.triplane_dim = triplane_dim
        # Per-instance copy of the defaults, extended with the sampling budget.
        self.rendering_kwargs = {
            **self.DEFAULT_RENDERING_KWARGS,
            'depth_resolution': samples_per_ray,
            # 0 presumably disables ImportanceRenderer's importance pass — confirm.
            'depth_resolution_importance': 0,
        }
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()
        if osg_decoder != 'default':
            raise NotImplementedError(f"Unsupported osg_decoder: {osg_decoder!r}")
        self.decoder = OSGDecoder(n_features=triplane_dim)

    @staticmethod
    def _to_image(samples, height, width):
        """Reshape per-ray samples (B, H*W, C) into a contiguous image (B, C, H, W)."""
        return samples.permute(0, 2, 1).reshape(
            samples.shape[0], samples.shape[-1], height, width
        ).contiguous()

    def forward(self, planes, ray_origins, ray_directions, render_size, bgcolor=None):
        """Volume-render the triplanes along the given rays.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            ray_origins: 3-D tensor of ray origins (asserted below).
            ray_directions: ray directions, forwarded to the renderer.
            render_size: side length of the square output image; the renderer
                output is assumed to hold render_size ** 2 rays per batch element
                in row-major order — confirm against the ray sampler.
            bgcolor: forwarded to the renderer.

        Returns:
            dict with 'comp_rgb' (N, C, H, W), 'comp_depth' (N, 1, H, W),
            'opacity' (N, 1, H, W), 'sdf_grad' (as returned by the renderer) and
            'comp_normal' (N, 3, H, W) mapped to [0, 1] and weighted by opacity.
        """
        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"

        rgb_samples, depth_samples, weights_samples, sdf_grad, normal_samples = self.renderer(
            planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, bgcolor
        )

        # Map unit normals from [-1, 1] to [0, 1] for visualization, then scale
        # by the accumulated weight so zero-opacity pixels become 0.
        normal_samples = (F.normalize(normal_samples, dim=-1) + 1.0) / 2.0
        normal_samples = normal_samples * weights_samples

        height = width = render_size
        return {
            'comp_rgb': self._to_image(rgb_samples, height, width),
            'comp_depth': self._to_image(depth_samples, height, width),
            'opacity': self._to_image(weights_samples, height, width),
            'sdf_grad': sdf_grad,
            'comp_normal': self._to_image(normal_samples, height, width),
        }

    @staticmethod
    def _build_grid(aabb, grid_size, device=None):
        """Return (N, grid_size**3, 3) grid points spanning each box aabb[i].

        Points are ordered with 'ij' meshgrid indexing (z varies fastest).
        """
        device = aabb.device if device is None else device
        grids = []
        for box_min, box_max in zip(aabb[:, 0], aabb[:, 1]):
            axes = [
                torch.linspace(box_min[d], box_max[d], grid_size, device=device)
                for d in range(3)
            ]
            grids.append(torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1).reshape(-1, 3))
        return torch.stack(grids, dim=0)

    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
        """Query the decoder on a dense regular grid.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            grid_size: number of grid points per axis.
            aabb: (N, 2, 3) per-element [min, max] corners; defaults to the
                sampler bbox from `rendering_kwargs` for every batch element.

        Returns:
            dict mapping each `forward_points` output key to a tensor of shape
            (N, grid_size, grid_size, grid_size, C).
        """
        if aabb is None:
            bbox_min = self.rendering_kwargs['sampler_bbox_min']
            bbox_max = self.rendering_kwargs['sampler_bbox_max']
            aabb = torch.tensor(
                [[bbox_min] * 3, [bbox_max] * 3],
                device=planes.device, dtype=planes.dtype,
            ).unsqueeze(0).repeat(planes.shape[0], 1, 1)
        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        batch = planes.shape[0]

        cube_grid = self._build_grid(aabb, grid_size, planes.device)
        features = self.forward_points(planes, cube_grid)
        return {
            key: value.reshape(batch, grid_size, grid_size, grid_size, -1)
            for key, value in features.items()
        }

    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
        """Evaluate the renderer's model at arbitrary points, chunked along dim 1.

        Args:
            planes: (N, 3, D', H', W') triplane features.
            points: (N, P, 3) query points.
            chunk_size: maximum number of points per `run_model` call.

        Returns:
            dict of the `run_model` outputs, concatenated along dim 1.

        Raises:
            ValueError: if `points` has no points along dim 1 or chunk_size < 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if points.shape[1] == 0:
            raise ValueError("forward_points requires at least one query point")

        outs = [
            self.renderer.run_model(
                planes=planes,
                decoder=self.decoder,
                sample_coordinates=chunk,
                # Zero directions are passed; presumably the decoder is
                # view-independent — confirm if the decoder changes.
                sample_directions=torch.zeros_like(chunk),
                options=self.rendering_kwargs,
            )
            for chunk in points.split(chunk_size, dim=1)
        ]
        return {key: torch.cat([out[key] for out in outs], dim=1) for key in outs[0]}