# File: vggsfm / vggsfm_code / minipytorch3d / harmonic_embedding.py
# Web-page header captured with this file (not part of the module):
# author JianyuanWang, commit "push" (34d287c), size 6.79 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional
import torch
class HarmonicEmbedding(torch.nn.Module):
    """Sin/cos (harmonic) feature embedding with optional Gaussian attenuation."""

    def __init__(
        self, n_harmonic_functions: int = 6, omega_0: float = 1.0, logspace: bool = True, append_input: bool = True
    ) -> None:
        """
        The harmonic embedding layer supports the classical
        NeRF positional encoding described in
        `NeRF <https://arxiv.org/abs/2003.08934>`_
        and the integrated position encoding in
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        At call time you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, each feature (i.e. vector along the last
        dimension) of `x` is converted into harmonic features. With
        `dim = x.shape[-1]` and `N = n_harmonic_functions`, the last
        dimension of the output is laid out as::

            [
                sin(f_1 * x[..., 0]), ..., sin(f_N * x[..., 0]),
                ...
                sin(f_1 * x[..., dim-1]), ..., sin(f_N * x[..., dim-1]),
                cos(f_1 * x[..., 0]), ..., cos(f_N * x[..., 0]),
                ...
                cos(f_1 * x[..., dim-1]), ..., cos(f_N * x[..., dim-1]),
                x[..., 0], ..., x[..., dim-1],  # only if append_input is True
            ]

        where f_k is a scalar denoting the k-th frequency of the embedding.

        If `diag_cov is not None`, `x` is treated as the means of Gaussians
        and `diag_cov` as their diagonal covariances. Every sin/cos entry
        above that uses frequency f_k and coordinate i is additionally
        multiplied by::

            exp(-0.5 * f_k**2 * diag_cov[..., i])

        The appended copy of `x` (if any) is not attenuated.

        If `logspace==True`, the unscaled frequencies are powers of 2:
            `2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, the unscaled frequencies are linearly spaced
        between `1.0` and `2**(n_harmonic_functions-1)`:
            `torch.linspace(1.0, 2**(n_harmonic_functions-1), n_harmonic_functions)`

        In both cases the frequencies are multiplied by the base frequency
        `omega_0`, which is equivalent to premultiplying `x` by `omega_0`
        before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic
                frequencies
            omega_0: float, base frequency
            logspace: bool, Whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (n_harmonic_functions - 1), n_harmonic_functions, dtype=torch.float32
            )

        # Non-persistent buffers: they move with the module (device / dtype)
        # but are not written to the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets [0, pi/2] used to obtain sin and cos from a single sin call.
        self.register_buffer("_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False)
        self.append_input = append_input

    def forward(self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our Gaussians, joined with x
                as means of the Gaussians.
            **kwargs: accepted and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        freqs = self._frequencies

        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * freqs
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2)
        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()

        if diag_cov is not None:
            # [..., dim, n_harmonic_functions]
            x_var = diag_cov[..., None] * torch.pow(freqs, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        leading_shape = x.shape[:-1]
        if embed.numel() > 0:
            embed = embed.reshape(*leading_shape, -1)
        else:
            # reshape cannot infer -1 for a zero-element tensor (empty batch,
            # dim == 0 or no frequencies), so spell the flattened width out.
            flat_width = 2 * x.shape[-1] * freqs.shape[0]
            embed = embed.reshape(*leading_shape, flat_width)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(input_dims: int, n_harmonic_functions: int, append_input: bool) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this instance's number of
        frequencies and `append_input` setting. The default for input_dims
        is 3 for 3D applications which use harmonic embedding for
        positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(input_dims, len(self._frequencies), self.append_input)