Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional | |
import torch | |
class HarmonicEmbedding(torch.nn.Module):
    """Sine/cosine (harmonic) positional encoding with optional Gaussian damping."""

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        The harmonic embedding layer supports the classical
        NeRF positional encoding described in
        `NeRF <https://arxiv.org/abs/2003.08934>`_
        and the integrated positional encoding in
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        At call time you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, each feature (i.e. vector along the last
        dimension) of `x` is converted into a series of harmonic features
        `embedding`, laid out along the last dimension as::

            [
                sin(f_1 * x[..., 0]), ..., sin(f_N * x[..., 0]),
                ...
                sin(f_1 * x[..., dim-1]), ..., sin(f_N * x[..., dim-1]),
                cos(f_1 * x[..., 0]), ..., cos(f_N * x[..., 0]),
                ...
                cos(f_1 * x[..., dim-1]), ..., cos(f_N * x[..., dim-1]),
                x[..., 0], ..., x[..., dim-1],  # only if append_input is True
            ]

        where N equals `n_harmonic_functions` and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `diag_cov is not None`, `x` is interpreted as the means of
        Gaussians and `diag_cov` as their diagonal covariances (the
        MIP-NeRF integrated positional encoding). Every sine/cosine term
        above is then damped by the factor matching its frequency and
        input coordinate::

            sin(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])
            cos(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])

        The appended copy of `x` (if any) is not damped.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        In both cases the frequencies are additionally multiplied by the
        base frequency `omega_0`, which is equivalent to premultiplying
        `x` by `omega_0` before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic
                features (frequencies)
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Registered as non-persistent buffers: they are derived purely from the
        # constructor arguments, so they are kept out of the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets used to obtain sin and cos from a single sin() call.
        self.register_buffer(
            "_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our
                Gaussians, joined with x as means of the Gaussians.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1]
        #   => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2)
        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()

        if diag_cov is not None:
            # Per-coordinate, per-frequency variance: [..., dim, n_harmonic_functions]
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # Broadcast the damping over the sin/cos axis:
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        # Flatten (2, dim, n_harmonic_functions) into a single feature axis.
        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Declared as a static method so it can be called both on the class
        and on an instance (as `get_output_dim` does).

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding

        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this module's own number of
        frequencies and `append_input` setting. The default for input_dims
        is 3 for 3D applications which use harmonic embedding for
        positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )