Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
from bdb import set_trace | |
import copy | |
from email import generator | |
import imp | |
import math | |
from platform import architecture | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import grad | |
from training.networks import * | |
from dnnlib.camera import * | |
from dnnlib.geometry import ( | |
positional_encoding, upsample, downsample | |
) | |
from dnnlib.util import dividable, hash_func, EasyDict | |
from torch_utils.ops.hash_sample import hash_sample | |
from torch_utils.ops.grid_sample_gradfix import grid_sample | |
from torch_utils.ops.nerf_utils import topp_masking | |
from einops import repeat, rearrange | |
# --------------------------------- basic modules ------------------------------------------- # | |
class Style2Layer(nn.Module): | |
def __init__(self, | |
in_channels, | |
out_channels, | |
w_dim, | |
activation='lrelu', | |
resample_filter=[1,3,3,1], | |
magnitude_ema_beta = -1, # -1 means not using magnitude ema | |
**unused_kwargs): | |
# simplified version of SynthesisLayer | |
# no noise, kernel size forced to be 1x1, used in NeRF block | |
super().__init__() | |
self.activation = activation | |
self.conv_clamp = None | |
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) | |
self.padding = 0 | |
self.act_gain = bias_act.activation_funcs[activation].def_gain | |
self.w_dim = w_dim | |
self.in_features = in_channels | |
self.out_features = out_channels | |
memory_format = torch.contiguous_format | |
if w_dim > 0: | |
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) | |
self.weight = torch.nn.Parameter( | |
torch.randn([out_channels, in_channels, 1, 1]).to(memory_format=memory_format)) | |
self.bias = torch.nn.Parameter(torch.zeros([out_channels])) | |
else: | |
self.weight = torch.nn.Parameter(torch.Tensor(out_channels, in_channels)) | |
self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) | |
self.weight_gain = 1. | |
# initialization | |
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) | |
bound = 1 / math.sqrt(fan_in) | |
torch.nn.init.uniform_(self.bias, -bound, bound) | |
self.magnitude_ema_beta = magnitude_ema_beta | |
if magnitude_ema_beta > 0: | |
self.register_buffer('w_avg', torch.ones([])) | |
def extra_repr(self) -> str: | |
return 'in_features={}, out_features={}, style={}'.format( | |
self.in_features, self.out_features, self.w_dim | |
) | |
def forward(self, x, w=None, fused_modconv=None, gain=1, up=1, **unused_kwargs): | |
flip_weight = True # (up == 1) # slightly faster HACK | |
act = self.activation | |
if (self.magnitude_ema_beta > 0): | |
if self.training: # updating EMA. | |
with torch.autograd.profiler.record_function('update_magnitude_ema'): | |
magnitude_cur = x.detach().to(torch.float32).square().mean() | |
self.w_avg.copy_(magnitude_cur.lerp(self.w_avg, self.magnitude_ema_beta)) | |
input_gain = self.w_avg.rsqrt() | |
x = x * input_gain | |
if fused_modconv is None: | |
with misc.suppress_tracer_warnings(): # this value will be treated as a constant | |
fused_modconv = not self.training | |
if self.w_dim > 0: # modulated convolution | |
assert x.ndim == 4, "currently not support modulated MLP" | |
styles = self.affine(w) # Batch x style_dim | |
if x.size(0) > styles.size(0): | |
styles = repeat(styles, 'b c -> (b s) c', s=x.size(0) // styles.size(0)) | |
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=None, up=up, | |
padding=self.padding, resample_filter=self.resample_filter, | |
flip_weight=flip_weight, fused_modconv=fused_modconv) | |
act_gain = self.act_gain * gain | |
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None | |
x = bias_act.bias_act(x, self.bias.to(x.dtype), act=act, gain=act_gain, clamp=act_clamp) | |
else: | |
if x.ndim == 2: # MLP mode | |
x = F.relu(F.linear(x, self.weight, self.bias.to(x.dtype))) | |
else: | |
x = F.relu(F.conv2d(x, self.weight[:,:,None, None], self.bias)) | |
# x = bias_act.bias_act(x, self.bias.to(x.dtype), act='relu') | |
return x | |
class SDFDensityLaplace(nn.Module): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) | |
def __init__(self, params_init={}, noise_std=0.0, beta_min=0.001, exp_beta=False): | |
super().__init__() | |
self.noise_std = noise_std | |
for p in params_init: | |
param = nn.Parameter(torch.tensor(params_init[p])) | |
setattr(self, p, param) | |
self.beta_min = beta_min | |
self.exp_beta = exp_beta | |
if (exp_beta == 'upper') or exp_beta: | |
self.register_buffer("steps", torch.scalar_tensor(0).float()) | |
def density_func(self, sdf, beta=None): | |
if beta is None: | |
beta = self.get_beta() | |
alpha = 1 / beta | |
return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) # TODO: need abs maybe, not sure | |
def get_beta(self): | |
if self.exp_beta == 'upper': | |
beta_upper = 0.12 * torch.exp(-0.003 * (self.steps / 1e3)) | |
beta = min(self.beta.abs(), beta_upper) + self.beta_min | |
elif self.exp_beta: | |
if self.steps < 500000: | |
beta = self.beta.abs() + self.beta_min | |
else: | |
beta = self.beta.abs().detach() + self.beta_min | |
else: | |
beta = self.beta.abs() + self.beta_min | |
return beta | |
def set_steps(self, steps): | |
if hasattr(self, "steps"): | |
self.steps = self.steps * 0 + steps | |
# ------------------------------------------------------------------------------------------- # | |
class NeRFBlock(nn.Module): | |
''' | |
Predicts volume density and color from 3D location, viewing | |
direction, and latent code z. | |
''' | |
# dimensions | |
input_dim = 3 | |
w_dim = 512 # style latent | |
z_dim = 0 # input latent | |
rgb_out_dim = 128 | |
hidden_size = 128 | |
n_blocks = 8 | |
img_channels = 3 | |
magnitude_ema_beta = -1 | |
disable_latents = False | |
max_batch_size = 2 ** 18 | |
shuffle_factor = 1 | |
implementation = 'batch_reshape' # option: [flatten_2d, batch_reshape] | |
# architecture settings | |
activation = 'lrelu' | |
use_skip = False | |
use_viewdirs = False | |
add_rgb = False | |
predict_rgb = False | |
inverse_sphere = False | |
merge_sigma_feat = False # use one MLP for sigma and features | |
no_sigma = False # do not predict sigma, only output features | |
tcnn_backend = False | |
use_style = None | |
use_normal = False | |
use_sdf = None | |
volsdf_exp_beta = False | |
normalized_feat = False | |
final_sigmoid_act = False | |
# positional encoding inpuut | |
use_pos = False | |
n_freq_posenc = 10 | |
n_freq_posenc_views = 4 | |
downscale_p_by = 1 | |
gauss_dim_pos = 20 | |
gauss_dim_view = 4 | |
gauss_std = 10. | |
positional_encoding = "normal" | |
def __init__(self, nerf_kwargs): | |
super().__init__() | |
for key in nerf_kwargs: | |
if hasattr(self, key): | |
setattr(self, key, nerf_kwargs[key]) | |
self.sdf_mode = self.use_sdf | |
self.use_sdf = self.use_sdf is not None | |
if self.use_sdf == 'volsdf': | |
self.density_transform = SDFDensityLaplace( | |
params_init={'beta': 0.1}, | |
beta_min=0.0001, | |
exp_beta=self.volsdf_exp_beta) | |
# ----------- input module ------------------------- | |
D = self.input_dim if not self.inverse_sphere else self.input_dim + 1 | |
if self.positional_encoding == 'gauss': | |
rng = np.random.RandomState(2021) | |
B_pos = self.gauss_std * torch.from_numpy(rng.randn(D, self.gauss_dim_pos * D)).float() | |
B_view = self.gauss_std * torch.from_numpy(rng.randn(3, self.gauss_dim_view * 3)).float() | |
self.register_buffer("B_pos", B_pos) | |
self.register_buffer("B_view", B_view) | |
dim_embed = D * self.gauss_dim_pos * 2 | |
dim_embed_view = 3 * self.gauss_dim_view * 2 | |
elif self.positional_encoding == 'normal': | |
dim_embed = D * self.n_freq_posenc * 2 | |
dim_embed_view = 3 * self.n_freq_posenc_views * 2 | |
else: # not using positional encoding | |
dim_embed, dim_embed_view = D, 3 | |
if self.use_pos: | |
dim_embed, dim_embed_view = dim_embed + D, dim_embed_view + 3 | |
self.dim_embed = dim_embed | |
self.dim_embed_view = dim_embed_view | |
# ------------ Layers -------------------------- | |
assert not (self.add_rgb and self.predict_rgb), "only one could be achieved" | |
assert not ((self.use_viewdirs or self.use_normal) and (self.merge_sigma_feat or self.no_sigma)), \ | |
"merged MLP does not support." | |
if self.disable_latents: | |
w_dim = 0 | |
elif self.z_dim > 0: # if input global latents, disable using style vectors | |
w_dim, dim_embed, dim_embed_view = 0, dim_embed + self.z_dim, dim_embed_view + self.z_dim | |
else: | |
w_dim = self.w_dim | |
final_in_dim = self.hidden_size | |
if self.use_normal: | |
final_in_dim += D | |
final_out_dim = self.rgb_out_dim * self.shuffle_factor | |
if self.merge_sigma_feat: | |
final_out_dim += self.shuffle_factor # predicting sigma | |
if self.add_rgb: | |
final_out_dim += self.img_channels | |
# start building the model | |
if self.tcnn_backend: | |
try: | |
import tinycudann as tcnn | |
except ImportError: | |
raise ImportError("This sample requires the tiny-cuda-nn extension for PyTorch.") | |
assert self.merge_sigma_feat and (not self.predict_rgb) and (not self.add_rgb) | |
assert w_dim == 0, "do not use any modulating inputs" | |
tcnn_config = {"otype": "FullyFusedMLP", "activation": "ReLU", "output_activation": "None", "n_neurons": 64, "n_hidden_layers": 1} | |
self.network = tcnn.Network(dim_embed, final_out_dim, tcnn_config) | |
self.num_ws = 0 | |
else: | |
self.fc_in = Style2Layer(dim_embed, self.hidden_size, w_dim, activation=self.activation) | |
self.num_ws = 1 | |
self.skip_layer = self.n_blocks // 2 - 1 if self.use_skip else None | |
if self.n_blocks > 1: | |
self.blocks = nn.ModuleList([ | |
Style2Layer( | |
self.hidden_size if i != self.skip_layer else self.hidden_size + dim_embed, | |
self.hidden_size, | |
w_dim, activation=self.activation, | |
magnitude_ema_beta=self.magnitude_ema_beta) | |
for i in range(self.n_blocks - 1)]) | |
self.num_ws += (self.n_blocks - 1) | |
if not (self.merge_sigma_feat or self.no_sigma): | |
self.sigma_out = ToRGBLayer(self.hidden_size, self.shuffle_factor, w_dim, kernel_size=1) | |
self.num_ws += 1 | |
self.feat_out = ToRGBLayer(final_in_dim, final_out_dim, w_dim, kernel_size=1) | |
if (self.z_dim == 0 and (not self.disable_latents)): | |
self.num_ws += 1 | |
else: | |
self.num_ws = 0 | |
if self.use_viewdirs: | |
assert self.predict_rgb, "only works when predicting RGB" | |
self.from_ray = Conv2dLayer(dim_embed_view, final_out_dim, kernel_size=1, activation='linear') | |
if self.predict_rgb: # predict RGB over features | |
self.to_rgb = Conv2dLayer(final_out_dim, self.img_channels * self.shuffle_factor, kernel_size=1, activation='linear') | |
def set_steps(self, steps): | |
if hasattr(self, "steps"): | |
self.steps.fill_(steps) | |
def transform_points(self, p, views=False): | |
p = p / self.downscale_p_by | |
if self.positional_encoding == 'gauss': | |
B = self.B_view if views else self.B_pos | |
p_transformed = positional_encoding(p, B, 'gauss', self.use_pos) | |
elif self.positional_encoding == 'normal': | |
L = self.n_freq_posenc_views if views else self.n_freq_posenc | |
p_transformed = positional_encoding(p, L, 'normal', self.use_pos) | |
else: | |
p_transformed = p | |
return p_transformed | |
def forward(self, p_in, ray_d, z_shape=None, z_app=None, ws=None, shape=None, requires_grad=False, impl=None): | |
with torch.set_grad_enabled(self.training or self.use_sdf or requires_grad): | |
impl = 'mlp' if self.tcnn_backend else impl | |
option, p_in = self.forward_inputs(p_in, shape=shape, impl=impl) | |
if self.tcnn_backend: | |
with torch.cuda.amp.autocast(): | |
p = p_in.squeeze(-1).squeeze(-1) | |
o = self.network(p) | |
sigma_raw, feat = o[:, :self.shuffle_factor], o[:, self.shuffle_factor:] | |
sigma_raw = rearrange(sigma_raw, '(b s) d -> b s d', s=option[2]).to(p_in.dtype) | |
feat = rearrange(feat, '(b s) d -> b s d', s=option[2]).to(p_in.dtype) | |
else: | |
feat, sigma_raw = self.forward_nerf(option, p_in, ray_d, ws=ws, z_shape=z_shape, z_app=z_app) | |
return feat, sigma_raw | |
def forward_inputs(self, p_in, shape=None, impl=None): | |
# prepare the inputs | |
impl = impl if impl is not None else self.implementation | |
if (shape is not None) and (impl == 'batch_reshape'): | |
height, width, n_steps = shape[1:] | |
elif impl == 'flatten_2d': | |
(height, width), n_steps = dividable(p_in.shape[1]), 1 | |
elif impl == 'mlp': | |
height, width, n_steps = 1, 1, p_in.shape[1] | |
else: | |
raise NotImplementedError("looking for more efficient implementation.") | |
p_in = rearrange(p_in, 'b (h w s) d -> (b s) d h w', h=height, w=width, s=n_steps) | |
use_normal = self.use_normal or self.use_sdf | |
if use_normal: | |
p_in.requires_grad_(True) | |
return (height, width, n_steps, use_normal), p_in | |
def forward_nerf(self, option, p_in, ray_d=None, ws=None, z_shape=None, z_app=None): | |
height, width, n_steps, use_normal = option | |
# forward nerf feature networks | |
p = self.transform_points(p_in.permute(0,2,3,1)) | |
if (self.z_dim > 0) and (not self.disable_latents): | |
assert (z_shape is not None) and (ws is None) | |
z_shape = repeat(z_shape, 'b c -> (b s) h w c', h=height, w=width, s=n_steps) | |
p = torch.cat([p, z_shape], -1) | |
p = p.permute(0,3,1,2) # BS x C x H x W | |
if height == width == 1: # MLP | |
p = p.squeeze(-1).squeeze(-1) | |
net = self.fc_in(p, ws[:, 0] if ws is not None else None) | |
if self.n_blocks > 1: | |
for idx, layer in enumerate(self.blocks): | |
ws_i = ws[:, idx + 1] if ws is not None else None | |
if (self.skip_layer is not None) and (idx == self.skip_layer): | |
net = torch.cat([net, p], 1) | |
net = layer(net, ws_i, up=1) | |
# forward to get the final results | |
w_idx = self.n_blocks # fc_in, self.blocks | |
feat_inputs = [net] | |
if not (self.merge_sigma_feat or self.no_sigma): | |
ws_i = ws[:, w_idx] if ws is not None else None | |
sigma_out = self.sigma_out(net, ws_i) | |
if use_normal: | |
gradients, = grad( | |
outputs=sigma_out, inputs=p_in, | |
grad_outputs=torch.ones_like(sigma_out, requires_grad=False), | |
retain_graph=True, create_graph=True, only_inputs=True) | |
feat_inputs.append(gradients) | |
ws_i = ws[:, -1] if ws is not None else None | |
net = torch.cat(feat_inputs, 1) if len(feat_inputs) > 1 else net | |
feat_out = self.feat_out(net, ws_i) # this is used for lowres output | |
if self.merge_sigma_feat: # split sigma from the feature | |
sigma_out, feat_out = feat_out[:, :self.shuffle_factor], feat_out[:, self.shuffle_factor:] | |
elif self.no_sigma: | |
sigma_out = None | |
if self.predict_rgb: | |
if self.use_viewdirs and ray_d is not None: | |
ray_d = ray_d / torch.norm(ray_d, dim=-1, keepdim=True) | |
ray_d = self.transform_points(ray_d, views=True) | |
if self.z_dim > 0: | |
ray_d = torch.cat([ray_d, repeat(z_app, 'b c -> b (h w s) c', h=height, w=width, s=n_steps)], -1) | |
ray_d = rearrange(ray_d, 'b (h w s) d -> (b s) d h w', h=height, w=width, s=n_steps) | |
feat_ray = self.from_ray(ray_d) | |
rgb = self.to_rgb(F.leaky_relu(feat_out + feat_ray)) | |
else: | |
rgb = self.to_rgb(feat_out) | |
if self.final_sigmoid_act: | |
rgb = torch.sigmoid(rgb) | |
if self.normalized_feat: | |
feat_out = feat_out / (1e-7 + feat_out.norm(dim=-1, keepdim=True)) | |
feat_out = torch.cat([rgb, feat_out], 1) | |
# transform back | |
if feat_out.ndim == 2: # mlp mode | |
sigma_out = rearrange(sigma_out, '(b s) d -> b s d', s=n_steps) if sigma_out is not None else None | |
feat_out = rearrange(feat_out, '(b s) d -> b s d', s=n_steps) | |
else: | |
sigma_out = rearrange(sigma_out, '(b s) d h w -> b (h w s) d', s=n_steps) if sigma_out is not None else None | |
feat_out = rearrange(feat_out, '(b s) d h w -> b (h w s) d', s=n_steps) | |
return feat_out, sigma_out | |
class CameraGenerator(torch.nn.Module): | |
def __init__(self, in_dim=2, hi_dim=128, out_dim=2): | |
super().__init__() | |
self.affine1 = FullyConnectedLayer(in_dim, hi_dim, activation='lrelu') | |
self.affine2 = FullyConnectedLayer(hi_dim, hi_dim, activation='lrelu') | |
self.proj = FullyConnectedLayer(hi_dim, out_dim) | |
def forward(self, x): | |
cam = self.proj(self.affine2(self.affine1(x))) | |
return cam | |
class CameraRay(object): | |
range_u = (0, 0) | |
range_v = (0.25, 0.25) | |
range_radius = (2.732, 2.732) | |
depth_range = [0.5, 6.] | |
gaussian_camera = False | |
angular_camera = False | |
intersect_ball = False | |
fov = 49.13 | |
bg_start = 1.0 | |
depth_transform = None # "LogWarp" or "InverseWarp" | |
dists_normalized = False # use normalized interval instead of real dists | |
random_rotate = False | |
ray_align_corner = True | |
nonparam_cameras = None | |
def __init__(self, camera_kwargs, **other_kwargs): | |
if len(camera_kwargs) == 0: # for compitatbility of old checkpoints | |
camera_kwargs.update(other_kwargs) | |
for key in camera_kwargs: | |
if hasattr(self, key): | |
setattr(self, key, camera_kwargs[key]) | |
self.camera_matrix = get_camera_mat(fov=self.fov) | |
def prepare_pixels(self, img_res, tgt_res, vol_res, camera_matrices, theta, margin=0, **unused): | |
if self.ray_align_corner: | |
all_pixels = self.get_pixel_coords(img_res, camera_matrices, theta=theta) | |
all_pixels = rearrange(all_pixels, 'b (h w) c -> b c h w', h=img_res, w=img_res) | |
tgt_pixels = F.interpolate(all_pixels, size=(tgt_res, tgt_res), mode='nearest') if tgt_res < img_res else all_pixels.clone() | |
vol_pixels = F.interpolate(tgt_pixels, size=(vol_res, vol_res), mode='nearest') if tgt_res > vol_res else tgt_pixels.clone() | |
vol_pixels = rearrange(vol_pixels, 'b c h w -> b (h w) c') | |
else: # coordinates not aligned! | |
tgt_pixels = self.get_pixel_coords(tgt_res, camera_matrices, corner_aligned=False, theta=theta) | |
vol_pixels = self.get_pixel_coords(vol_res, camera_matrices, corner_aligned=False, theta=theta, margin=margin) \ | |
if (tgt_res > vol_res) or (margin > 0) else tgt_pixels.clone() | |
tgt_pixels = rearrange(tgt_pixels, 'b (h w) c -> b c h w', h=tgt_res, w=tgt_res) | |
return vol_pixels, tgt_pixels | |
def prepare_pixels_regularization(self, tgt_pixels, n_reg_samples): | |
# only apply when size is bigger than voxel resolution | |
pace = tgt_pixels.size(-1) // n_reg_samples | |
idxs = torch.arange(0, tgt_pixels.size(-1), pace, device=tgt_pixels.device) # n_reg_samples | |
u_xy = torch.rand(tgt_pixels.size(0), 2, device=tgt_pixels.device) | |
u_xy = (u_xy * pace).floor().long() # batch_size x 2 | |
x_idxs, y_idxs = idxs[None,:] + u_xy[:,:1], idxs[None,:] + u_xy[:,1:] | |
rand_indexs = (x_idxs[:,None,:] + y_idxs[:,:,None] * tgt_pixels.size(-1)).reshape(tgt_pixels.size(0), -1) | |
tgt_pixels = rearrange(tgt_pixels, 'b c h w -> b (h w) c') | |
rand_pixels = tgt_pixels.gather(1, rand_indexs.unsqueeze(-1).repeat(1,1,2)) | |
return rand_pixels, rand_indexs | |
def get_roll(self, ws, training=True, theta=None, **unused): | |
if (self.random_rotate is not None) and training: | |
theta = torch.randn(ws.size(0)).to(ws.device) * self.random_rotate / 2 | |
theta = theta / 180 * math.pi | |
else: | |
if theta is not None: | |
theta = torch.ones(ws.size(0)).to(ws.device) * theta | |
return theta | |
def get_camera(self, batch_size, device, mode='random', fov=None, force_uniform=False): | |
if fov is not None: | |
camera_matrix = get_camera_mat(fov) | |
else: | |
camera_matrix = self.camera_matrix | |
camera_mat = camera_matrix.repeat(batch_size, 1, 1).to(device) | |
reg_loss = None # TODO: useless | |
if isinstance(mode, list): | |
# default camera generator, we assume input mode is linear | |
if len(mode) == 3: | |
val_u, val_v, val_r = mode | |
r0 = self.range_radius[0] | |
r1 = self.range_radius[1] | |
else: | |
val_u, val_v, val_r, r_s = mode | |
r0 = self.range_radius[0] * r_s | |
r1 = self.range_radius[1] * r_s | |
world_mat = get_camera_pose( | |
self.range_u, self.range_v, [r0, r1], | |
val_u, val_v, val_r, | |
batch_size=batch_size, | |
gaussian=False, # input mode is by default uniform | |
angular=self.angular_camera).to(device) | |
elif isinstance(mode, torch.Tensor): | |
world_mat, mode = get_camera_pose_v2( | |
self.range_u, self.range_v, self.range_radius, mode, | |
gaussian=self.gaussian_camera and (not force_uniform), | |
angular=self.angular_camera) | |
world_mat = world_mat.to(device) | |
mode = torch.stack(mode, 1).to(device) | |
else: | |
world_mat, mode = get_random_pose( | |
self.range_u, self.range_v, | |
self.range_radius, batch_size, | |
gaussian=self.gaussian_camera, | |
angular=self.angular_camera) | |
world_mat = world_mat.to(device) | |
mode = torch.stack(mode, 1).to(device) | |
return camera_mat.float(), world_mat.float(), mode, reg_loss | |
def get_transformed_depth(self, di, reversed=False): | |
depth_range = self.depth_range | |
if (self.depth_transform is None) or (self.depth_transform == 'None'): | |
g_fwd, g_inv = lambda x: x, lambda x: x | |
elif self.depth_transform == 'LogWarp': | |
g_fwd, g_inv = math.log, torch.exp | |
elif self.depth_transform == 'InverseWarp': | |
g_fwd, g_inv = lambda x: 1/x, lambda x: 1/x | |
else: | |
raise NotImplementedError | |
if not reversed: | |
return g_inv(g_fwd(depth_range[1]) * di + g_fwd(depth_range[0]) * (1 - di)) | |
else: | |
d0 = (g_fwd(di) - g_fwd(depth_range[0])) / (g_fwd(depth_range[1]) - g_fwd(depth_range[0])) | |
return d0.clip(min=0, max=1) | |
def get_evaluation_points(self, pixels_world=None, camera_world=None, di=None, p_i=None, no_reshape=False, transform=None): | |
if p_i is None: | |
batch_size = pixels_world.shape[0] | |
n_steps = di.shape[-1] | |
ray_i = pixels_world - camera_world | |
p_i = camera_world.unsqueeze(-2).contiguous() + \ | |
di.unsqueeze(-1).contiguous() * ray_i.unsqueeze(-2).contiguous() | |
ray_i = ray_i.unsqueeze(-2).repeat(1, 1, n_steps, 1) | |
else: | |
assert no_reshape, "only used to transform points to a warped space" | |
if transform is None: | |
transform = self.depth_transform | |
if transform == 'LogWarp': | |
c = torch.tensor([1., 0., 0.]).to(p_i.device) | |
p_i = normalization_inverse_sqrt_dist_centered( | |
p_i, c[None, None, None, :], self.depth_range[1]) | |
elif transform == 'InverseWarp': | |
# https://arxiv.org/pdf/2111.12077.pdf | |
p_n = p_i.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-7) | |
con = p_n.ge(1).type_as(p_n) | |
p_i = p_i * (1 -con) + (2 - 1 / p_n) * (p_i / p_n) * con | |
if no_reshape: | |
return p_i | |
assert(p_i.shape == ray_i.shape) | |
p_i = p_i.reshape(batch_size, -1, 3) | |
ray_i = ray_i.reshape(batch_size, -1, 3) | |
return p_i, ray_i | |
def get_evaluation_points_bg(self, pixels_world, camera_world, di): | |
batch_size = pixels_world.shape[0] | |
n_steps = di.shape[-1] | |
n_pixels = pixels_world.shape[1] | |
ray_world = pixels_world - camera_world | |
ray_world = ray_world / ray_world.norm(dim=-1, keepdim=True) # normalize | |
camera_world = camera_world.unsqueeze(-2).expand(batch_size, n_pixels, n_steps, 3) | |
ray_world = ray_world.unsqueeze(-2).expand(batch_size, n_pixels, n_steps, 3) | |
bg_pts, _ = depth2pts_outside(camera_world, ray_world, di) # di: 1 ---> 0 | |
bg_pts = bg_pts.reshape(batch_size, -1, 4) | |
ray_world = ray_world.reshape(batch_size, -1, 3) | |
return bg_pts, ray_world | |
def add_noise_to_interval(self, di): | |
di_mid = .5 * (di[..., 1:] + di[..., :-1]) | |
di_high = torch.cat([di_mid, di[..., -1:]], dim=-1) | |
di_low = torch.cat([di[..., :1], di_mid], dim=-1) | |
noise = torch.rand_like(di_low) | |
ti = di_low + (di_high - di_low) * noise | |
return ti | |
def calc_volume_weights(self, sigma, z_vals=None, ray_vector=None, dists=None, last_dist=1e10): | |
if dists is None: | |
dists = z_vals[..., 1:] - z_vals[..., :-1] | |
if ray_vector is not None: | |
dists = dists * torch.norm(ray_vector, dim=-1, keepdim=True) | |
dists = torch.cat([dists, torch.ones_like(dists[..., :1]) * last_dist], dim=-1) | |
alpha = 1.-torch.exp(-F.relu(sigma)*dists) | |
if last_dist > 0: | |
alpha[..., -1] = 1 | |
# alpha = 1.-torch.exp(-sigma * dists) | |
T = torch.cumprod(torch.cat([ | |
torch.ones_like(alpha[:, :, :1]), | |
(1. - alpha + 1e-10), ], dim=-1), dim=-1)[..., :-1] | |
weights = alpha * T | |
return weights, T[..., -1], dists | |
def get_pixel_coords(self, tgt_res, camera_matrices, corner_aligned=True, margin=0, theta=None, invert_y=True): | |
device = camera_matrices[0].device | |
batch_size = camera_matrices[0].shape[0] | |
# margin = self.margin if margin is None else margin | |
full_pixels = arange_pixels((tgt_res, tgt_res), | |
batch_size, invert_y_axis=invert_y, margin=margin, | |
corner_aligned=corner_aligned).to(device) | |
if (theta is not None): | |
theta = theta.unsqueeze(-1) | |
x = full_pixels[..., 0] * torch.cos(theta) - full_pixels[..., 1] * torch.sin(theta) | |
y = full_pixels[..., 0] * torch.sin(theta) + full_pixels[..., 1] * torch.cos(theta) | |
full_pixels = torch.stack([x, y], -1) | |
return full_pixels | |
def get_origin_direction(self, pixels, camera_matrices): | |
camera_mat, world_mat = camera_matrices[:2] | |
if camera_mat.size(0) < pixels.size(0): | |
camera_mat = repeat(camera_mat, 'b c d -> (b s) c d', s=pixels.size(0)//camera_mat.size(0)) | |
if world_mat.size(0) < pixels.size(0): | |
world_mat = repeat(world_mat, 'b c d -> (b s) c d', s=pixels.size(0)//world_mat.size(0)) | |
pixels_world = image_points_to_world(pixels, camera_mat=camera_mat, world_mat=world_mat) | |
camera_world = origin_to_world(pixels.size(1), camera_mat=camera_mat, world_mat=world_mat) | |
ray_vector = pixels_world - camera_world | |
return pixels_world, camera_world, ray_vector | |
def set_camera_prior(self, dataset_cams): | |
self.nonparam_cameras = dataset_cams | |
class VolumeRenderer(object): | |
n_ray_samples = 14 | |
n_bg_samples = 4 | |
n_final_samples = None # final nerf steps after upsampling (optional) | |
sigma_type = 'relu' # other allowed options including, "abs", "shiftedsoftplus", "exp" | |
hierarchical = True | |
fine_only = False | |
no_background = False | |
white_background = False | |
mask_background = False | |
pre_volume_size = None | |
bound = None | |
density_p_target = 1.0 | |
tv_loss_weight = 0.0 # for now only works for density-based voxels | |
def __init__(self, renderer_kwargs, camera_ray, input_encoding=None, **other_kwargs): | |
if len(renderer_kwargs) == 0: # for compitatbility of old checkpoints | |
renderer_kwargs.update(other_kwargs) | |
for key in renderer_kwargs: | |
if hasattr(self, key): | |
setattr(self, key, renderer_kwargs[key]) | |
self.C = camera_ray | |
self.I = input_encoding | |
def split_feat(self, x, img_channels, white_color=None, split_rgb=True): | |
img = x[:, :img_channels] | |
if split_rgb: | |
x = x[:, img_channels:] | |
if (white_color is not None) and self.white_background: | |
img = img + white_color | |
return x, img | |
def get_bound(self): | |
if self.bound is not None: | |
return self.bound | |
# when applying normalization, the points are restricted inside R=2 ball | |
if self.C.depth_transform == 'InverseWarp': | |
bound = 2 | |
else: # TODO: this is a bit hacky as we assume object at origin | |
bound = (self.C.depth_range[1] - self.C.depth_range[0]) | |
return bound | |
def get_density(self, sigma_raw, fg_nerf, no_noise=False, training=False): | |
if fg_nerf.use_sdf: | |
sigma = fg_nerf.density_transform.density_func(sigma_raw) | |
elif self.sigma_type == 'relu': | |
if training and (not no_noise): # adding noise to pass gradient? | |
sigma_raw = sigma_raw + torch.randn_like(sigma_raw) | |
sigma = F.relu(sigma_raw) | |
elif self.sigma_type == 'shiftedsoftplus': # https://arxiv.org/pdf/2111.11215.pdf | |
sigma = F.softplus(sigma_raw - 1) # 1 is the shifted bias. | |
elif self.sigma_type == 'exp_truncated': # density in the log-space | |
sigma = torch.exp(5 - F.relu(5 - (sigma_raw - 1))) # up-bound = 5, also shifted by 1 | |
else: | |
sigma = sigma_raw | |
return sigma | |
def forward_hierarchical_sampling(self, di, weights, n_steps, det=False): | |
di_mid = 0.5 * (di[..., :-1] + di[..., 1:]) | |
n_bins = di_mid.size(-1) | |
batch_size = di.size(0) | |
di_fine = sample_pdf( | |
di_mid.reshape(-1, n_bins), | |
weights.reshape(-1, n_bins+1)[:, 1:-1], | |
n_steps, det=det).reshape(batch_size, -1, n_steps) | |
return di_fine | |
def forward_rendering_with_pre_density(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles): | |
pixels_world, camera_world, ray_vector = nerf_input_cams | |
z_shape_obj, z_app_obj = latent_codes[:2] | |
height, width = dividable(H.n_points) | |
fg_shape = [H.batch_size, height, width, H.n_steps] | |
bound = self.get_bound() | |
# sample points | |
di = torch.linspace(0., 1., steps=H.n_steps).to(H.device) | |
di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points) | |
if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False): | |
di = self.C.add_noise_to_interval(di) | |
di_trs = self.C.get_transformed_depth(di) | |
p_i, r_i = self.C.get_evaluation_points(pixels_world, camera_world, di_trs) | |
p_i = self.I.query_input_features(p_i, nerf_input_feats, fg_shape, bound) | |
pre_sigma_raw, p_i = p_i[...,:self.I.sigma_dim].sum(dim=-1, keepdim=True), p_i[..., self.I.sigma_dim:] | |
pre_sigma = self.get_density(rearrange(pre_sigma_raw, 'b (n s) () -> b n s', s=H.n_steps), | |
fg_nerf, training=H.training) | |
pre_weights = self.C.calc_volume_weights( | |
pre_sigma, di if self.C.dists_normalized else di_trs, ray_vector, last_dist=1e10)[0] | |
feat, _ = fg_nerf(p_i, r_i, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape) | |
feat = rearrange(feat, 'b (n s) d -> b n s d', s=H.n_steps) | |
feat = torch.sum(pre_weights.unsqueeze(-1) * feat, dim=-2) | |
output.feat += [feat] | |
output.fg_weights = pre_weights | |
output.fg_depths = (di, di_trs) | |
return output | |
def forward_sampling(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles): | |
# TODO: experimental research code. Not functional yet. | |
pixels_world, camera_world, ray_vector = nerf_input_cams | |
z_shape_obj, z_app_obj = latent_codes[:2] | |
height, width = dividable(H.n_points) | |
bound = self.get_bound() | |
# just to simulate | |
H.n_steps = 64 | |
di = torch.linspace(0., 1., steps=H.n_steps).to(H.device) | |
di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points) | |
if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False): | |
di = self.C.add_noise_to_interval(di) | |
di_trs = self.C.get_transformed_depth(di) | |
fg_shape = [H.batch_size, height, width, 1] | |
# iteration in the loop (?) | |
feats, sigmas = [], [] | |
with torch.enable_grad(): | |
di_trs.requires_grad_(True) | |
for s in range(di_trs.shape[-1]): | |
di_s = di_trs[..., s:s+1] | |
p_i, r_i = self.C.get_evaluation_points(pixels_world, camera_world, di_s) | |
if nerf_input_feats is not None: | |
p_i = self.I.query_input_features(p_i, nerf_input_feats, fg_shape, bound) | |
feat, sigma_raw = fg_nerf(p_i, r_i, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape, requires_grad=True) | |
sigma = self.get_density(sigma_raw, fg_nerf, training=H.training) | |
feats += [feat] | |
sigmas += [sigma] | |
feat, sigma = torch.stack(feats, 2), torch.cat(sigmas, 2) | |
fg_weights, bg_lambda = self.C.calc_volume_weights( | |
sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights | |
ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2] | |
fg_feat = torch.sum(fg_weights.unsqueeze(-1) * feat, dim=-2) | |
output.feat += [fg_feat] | |
output.full_out += [feat] | |
output.fg_weights = fg_weights | |
output.bg_lambda = bg_lambda | |
output.fg_depths = (di, di_trs) | |
return output | |
def forward_rendering(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles): | |
pixels_world, camera_world, ray_vector = nerf_input_cams | |
z_shape_obj, z_app_obj = latent_codes[:2] | |
height, width = dividable(H.n_points) | |
fg_shape = [H.batch_size, height, width, H.n_steps] | |
bound = self.get_bound() | |
# sample points | |
di = torch.linspace(0., 1., steps=H.n_steps).to(H.device) | |
di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points) | |
if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False): | |
di = self.C.add_noise_to_interval(di) | |
di_trs = self.C.get_transformed_depth(di) | |
p_i, r_i = self.C.get_evaluation_points(pixels_world, camera_world, di_trs) | |
if nerf_input_feats is not None: | |
p_i = self.I.query_input_features(p_i, nerf_input_feats, fg_shape, bound) | |
feat, sigma_raw = fg_nerf(p_i, r_i, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape) | |
feat = rearrange(feat, 'b (n s) d -> b n s d', s=H.n_steps) | |
sigma_raw = rearrange(sigma_raw.squeeze(-1), 'b (n s) -> b n s', s=H.n_steps) | |
sigma = self.get_density(sigma_raw, fg_nerf, training=H.training) | |
fg_weights, bg_lambda = self.C.calc_volume_weights( | |
sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights | |
ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2] | |
if self.hierarchical and (not H.get('disable_hierarchical', False)): | |
with torch.no_grad(): | |
di_fine = self.forward_hierarchical_sampling(di, fg_weights, H.n_steps, det=(not H.training)) | |
di_trs_fine = self.C.get_transformed_depth(di_fine) | |
p_f, r_f = self.C.get_evaluation_points(pixels_world, camera_world, di_trs_fine) | |
if nerf_input_feats is not None: | |
p_f = self.I.query_input_features(p_f, nerf_input_feats, fg_shape, bound) | |
feat_f, sigma_raw_f = fg_nerf(p_f, r_f, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape) | |
feat_f = rearrange(feat_f, 'b (n s) d -> b n s d', s=H.n_steps) | |
sigma_raw_f = rearrange(sigma_raw_f.squeeze(-1), 'b (n s) -> b n s', s=H.n_steps) | |
sigma_f = self.get_density(sigma_raw_f, fg_nerf, training=H.training) | |
feat = torch.cat([feat_f, feat], 2) | |
sigma = torch.cat([sigma_f, sigma], 2) | |
sigma_raw = torch.cat([sigma_raw_f, sigma_raw], 2) | |
di = torch.cat([di_fine, di], 2) | |
di_trs = torch.cat([di_trs_fine, di_trs], 2) | |
di, indices = torch.sort(di, dim=2) | |
di_trs = torch.gather(di_trs, 2, indices) | |
sigma = torch.gather(sigma, 2, indices) | |
sigma_raw = torch.gather(sigma_raw, 2, indices) | |
feat = torch.gather(feat, 2, repeat(indices, 'b n s -> b n s d', d=feat.size(-1))) | |
fg_weights, bg_lambda = self.C.calc_volume_weights( | |
sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights, | |
ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2] | |
fg_feat = torch.sum(fg_weights.unsqueeze(-1) * feat, dim=-2) | |
output.feat += [fg_feat] | |
output.full_out += [feat] | |
output.fg_weights = fg_weights | |
output.bg_lambda = bg_lambda | |
output.fg_depths = (di, di_trs) | |
return output | |
def forward_rendering_background(self, H, output, bg_nerf, nerf_input_cams, latent_codes, styles_bg): | |
pixels_world, camera_world, _ = nerf_input_cams | |
z_shape_bg, z_app_bg = latent_codes[2:] | |
height, width = dividable(H.n_points) | |
bg_shape = [H.batch_size, height, width, H.n_bg_steps] | |
if H.fixed_input_cams is not None: | |
pixels_world, camera_world, _ = H.fixed_input_cams | |
# render background, use NeRF++ inverse sphere parameterization | |
di = torch.linspace(-1., 0., steps=H.n_bg_steps).to(H.device) | |
di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points) * self.C.bg_start | |
if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False): | |
di = self.C.add_noise_to_interval(di) | |
p_bg, r_bg = self.C.get_evaluation_points_bg(pixels_world, camera_world, -di) | |
feat, sigma_raw = bg_nerf(p_bg, r_bg, z_shape_bg, z_app_bg, ws=styles_bg, shape=bg_shape) | |
feat = rearrange(feat, 'b (n s) d -> b n s d', s=H.n_bg_steps) | |
sigma_raw = rearrange(sigma_raw.squeeze(-1), 'b (n s) -> b n s', s=H.n_bg_steps) | |
sigma = self.get_density(sigma_raw, bg_nerf, training=H.training) | |
bg_weights = self.C.calc_volume_weights(sigma, di, None)[0] | |
bg_feat = torch.sum(bg_weights.unsqueeze(-1) * feat, dim=-2) | |
if output.get('bg_lambda', None) is not None: | |
bg_feat = output.bg_lambda.unsqueeze(-1) * bg_feat | |
output.feat += [bg_feat] | |
output.full_out += [feat] | |
output.bg_weights = bg_weights | |
output.bg_depths = di | |
return output | |
def forward_volume_rendering( | |
self, | |
nerf_modules, # (fg_nerf, bg_nerf) | |
camera_matrices, # camera (K, RT) | |
vol_pixels, | |
nerf_input_feats = None, | |
latent_codes = None, | |
styles = None, | |
styles_bg = None, | |
not_render_background = False, | |
only_render_background = False, | |
render_option = None, | |
return_full = False, | |
alpha = 0, | |
**unused): | |
assert (latent_codes is not None) or (styles is not None) | |
assert self.no_background or (nerf_input_feats is None), "input features do not support background field" | |
# hyper-parameters for rendering | |
H = EasyDict(**unused) | |
output = EasyDict() | |
output.reg_loss = EasyDict() | |
output.feat = [] | |
output.full_out = [] | |
if render_option is None: | |
render_option = "" | |
H.render_option = render_option | |
H.alpha = alpha | |
# prepare for rendering (parameters) | |
fg_nerf, bg_nerf = nerf_modules | |
H.training = fg_nerf.training | |
H.device = camera_matrices[0].device | |
H.batch_size = camera_matrices[0].shape[0] | |
H.img_channels = fg_nerf.img_channels | |
H.n_steps = self.n_ray_samples | |
H.n_bg_steps = self.n_bg_samples | |
if alpha == -1: | |
H.n_steps = 20 # just for memory safe. | |
if "steps" in render_option: | |
H.n_steps = [int(r.split(':')[1]) for r in H.render_option.split(',') if r[:5] == 'steps'][0] | |
# prepare for pixels for generating images | |
if isinstance(vol_pixels, tuple): | |
vol_pixels, rand_pixels = vol_pixels | |
pixels = torch.cat([vol_pixels, rand_pixels], 1) | |
H.rnd_res = int(math.sqrt(rand_pixels.size(1))) | |
else: | |
pixels, rand_pixels, H.rnd_res = vol_pixels, None, None | |
H.tgt_res, H.n_points = int(math.sqrt(vol_pixels.size(1))), pixels.size(1) | |
nerf_input_cams = self.C.get_origin_direction(pixels, camera_matrices) | |
# set up an frozen camera for background if necessary | |
if ('freeze_bg' in H.render_option) and (bg_nerf is not None): | |
pitch, yaw = 0.2 + np.pi/2, 0 | |
range_u, range_v = self.C.range_u, self.C.range_v | |
u = (yaw - range_u[0]) / (range_u[1] - range_u[0]) | |
v = (pitch - range_v[0]) / (range_v[1] - range_v[0]) | |
fixed_camera = self.C.get_camera( | |
batch_size=H.batch_size, mode=[u, v, 0.5], device=H.device) | |
H.fixed_input_cams = self.C.get_origin_direction(pixels, fixed_camera) | |
else: | |
H.fixed_input_cams = None | |
H.fg_inf_depth = (self.no_background or not_render_background) and (not self.white_background) | |
assert(not (not_render_background and only_render_background)) | |
# volume rendering options: bg_weights, bg_lambda = None, None | |
if (nerf_input_feats is not None) and \ | |
len(nerf_input_feats) == 4 and \ | |
nerf_input_feats[2] == 'tri_vector' and \ | |
self.I.sigma_dim > 0 and H.fg_inf_depth: | |
# volume rendering with pre-computed density similar to tensor-decomposition | |
output = self.forward_rendering_with_pre_density( | |
H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles) | |
else: | |
# standard volume rendering | |
if not only_render_background: | |
output = self.forward_rendering( | |
H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles) | |
# background rendering (NeRF++) | |
if (not not_render_background) and (not self.no_background): | |
output = self.forward_rendering_background( | |
H, output, bg_nerf, nerf_input_cams, latent_codes, styles_bg) | |
if ('early' in render_option) and ('value' not in render_option): | |
return self.gen_optional_output( | |
H, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles, output) | |
# ------------------------------------------- PREPARE FULL OUTPUT (NO 2D aggregation) -------------------------------------------- # | |
vol_len = vol_pixels.size(1) | |
feat_map = sum(output.feat) | |
full_x = rearrange(feat_map[:, :vol_len], 'b (h w) d -> b d h w', h=H.tgt_res) | |
split_rgb = fg_nerf.add_rgb or fg_nerf.predict_rgb | |
full_out = self.split_feat(full_x, H.img_channels, None, split_rgb=split_rgb) | |
if rand_pixels is not None: # used in full supervision (debug later) | |
if return_full: | |
assert (fg_nerf.predict_rgb or fg_nerf.add_rgb) | |
rand_outputs = [f[:,vol_pixels.size(1):] for f in output.full_out] | |
full_weights = torch.cat([output.fg_weights, output.bg_weights * output.bg_lambda.unsqueeze(-1)], -1) \ | |
if output.get('bg_weights', None) is not None else output.fg_weights | |
full_weights = full_weights[:,vol_pixels.size(1):] | |
full_weights = rearrange(full_weights, 'b (h w) s -> b s h w', h=H.rnd_res, w=H.rnd_res) | |
lh, lw = dividable(full_weights.size(1)) | |
full_x = rearrange(torch.cat(rand_outputs, 2), 'b (h w) (l m) d -> b d (l h) (m w)', | |
h=H.rnd_res, w=H.rnd_res, l=lh, m=lw) | |
full_x, full_img = self.split_feat(full_x, H.img_channels, split_rgb=split_rgb) | |
output.rand_out = (full_x, full_img, full_weights) | |
else: | |
rand_x = rearrange(feat_map[:, vol_len:], 'b (h w) d -> b d h w', h=H.rnd_res) | |
output.rand_out = self.split_feat(rand_x, H.img_channels, split_rgb=split_rgb) | |
output.full_out = full_out | |
return output | |
def post_process_outputs(self, outputs, freeze_nerf=False): | |
if freeze_nerf: | |
outputs = [x.detach() if isinstance(x, torch.Tensor) else x for x in outputs] | |
x, img = outputs[0], outputs[1] | |
probs = outputs[2] if len(outputs) == 3 else None | |
return x, img, probs | |
def gen_optional_output(self, H, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles, output): | |
_, camera_world, ray_vector = nerf_input_cams | |
z_shape_obj, z_app_obj = latent_codes[:2] | |
fg_depth_map = torch.sum(output.fg_weights * output.fg_depths[1], dim=-1, keepdim=True) | |
img = camera_world[:, :1] + fg_depth_map * ray_vector | |
img = img.permute(0,2,1).reshape(-1, 3, H.tgt_res, H.tgt_res) | |
if 'input_feats' in H.render_option: | |
a, b = [r.split(':')[1:] for r in H.render_option.split(',') if r.startswith('input_feats')][0] | |
a, b = int(a), int(b) | |
if nerf_input_feats[0] == 'volume': | |
img = nerf_input_feats[1][:,a:a+3,b,:,:] | |
elif nerf_input_feats[0] == 'tri_plane': | |
img = nerf_input_feats[1][:,b,a:a+3,:,:] | |
elif nerf_input_feats[0] == 'hash_table': | |
assert self.I.hash_mode == 'grid_hash' | |
img = nerf_input_feats[1][:,self.I.offsets[b]:self.I.offsets[b+1], :] | |
siz = int(np.ceil(img.size(1)**(1/3))) | |
img = rearrange(img, 'b (d h w) c -> b (d c) h w', h=siz, w=siz, d=siz) | |
img = img[:, a:a+3] | |
else: | |
raise NotImplementedError | |
if 'normal' in H.render_option.split(','): | |
shift_l, shift_r = img[:,:,2:,:], img[:,:,:-2,:] | |
shift_u, shift_d = img[:,:,:,2:], img[:,:,:,:-2] | |
diff_hor = normalize(shift_r - shift_l, axis=1)[0][:, :, :, 1:-1] | |
diff_ver = normalize(shift_u - shift_d, axis=1)[0][:, :, 1:-1, :] | |
normal = torch.cross(diff_hor, diff_ver, dim=1) | |
img = normalize(normal, axis=1)[0] | |
if 'gradient' in H.render_option.split(','): | |
points, _ = self.C.get_evaluation_points(camera_world + ray_vector, camera_world, output.fg_depths[1]) | |
fg_shape = [H.batch_size, H.tgt_res, H.tgt_res, output.fg_depths[1].size(-1)] | |
with torch.enable_grad(): | |
points.requires_grad_(True) | |
inputs = self.I.query_input_features(points, nerf_input_feats, fg_shape, self.get_bound(), True) \ | |
if nerf_input_feats is not None else points | |
if (nerf_input_feats is not None) and len(nerf_input_feats) == 4 and nerf_input_feats[2] == 'tri_vector' and (self.I.sigma_dim > 0): | |
sigma_out = inputs[..., :8].sum(dim=-1, keepdim=True) | |
else: | |
_, sigma_out = fg_nerf(inputs, None, ws=styles, shape=fg_shape, z_shape=z_shape_obj, z_app=z_app_obj, requires_grad=True) | |
gradients, = grad( | |
outputs=sigma_out, inputs=points, | |
grad_outputs=torch.ones_like(sigma_out, requires_grad=False), | |
retain_graph=True, create_graph=True, only_inputs=True) | |
gradients = rearrange(gradients, 'b (n s) d -> b n s d', s=output.fg_depths[1].size(-1)) | |
avg_grads = (gradients * output.fg_weights.unsqueeze(-1)).sum(-2) | |
avg_grads = F.normalize(avg_grads, p=2, dim=-1) | |
normal = rearrange(avg_grads, 'b (h w) s -> b s h w', h=H.tgt_res, w=H.tgt_res) | |
img = -normal | |
return {'full_out': (None, img)} | |
class Upsampler(object): | |
no_2d_renderer = False | |
no_residual_img = False | |
block_reses = None | |
shared_rgb_style = False | |
upsample_type = 'default' | |
img_channels = 3 | |
in_res = 32 | |
out_res = 512 | |
channel_base = 1 | |
channel_base_sz = None | |
channel_max = 512 | |
channel_dict = None | |
out_channel_dict = None | |
def __init__(self, upsampler_kwargs, **other_kwargs): | |
# for compitatbility of old checkpoints | |
for key in other_kwargs: | |
if hasattr(self, key) and (key not in upsampler_kwargs): | |
upsampler_kwargs[key] = other_kwargs[key] | |
for key in upsampler_kwargs: | |
if hasattr(self, key): | |
setattr(self, key, upsampler_kwargs[key]) | |
self.out_res_log2 = int(np.log2(self.out_res)) | |
# set up upsamplers | |
if self.block_reses is None: | |
self.block_resolutions = [2 ** i for i in range(2, self.out_res_log2 + 1)] | |
self.block_resolutions = [b for b in self.block_resolutions if b > self.in_res] | |
else: | |
self.block_resolutions = self.block_reses | |
if self.no_2d_renderer: | |
self.block_resolutions = [] | |
def build_network(self, w_dim, input_dim, **block_kwargs): | |
upsamplers = [] | |
if len(self.block_resolutions) > 0: # nerf resolution smaller than image | |
channel_base = int(self.channel_base * 32768) if self.channel_base_sz is None else self.channel_base_sz | |
fp16_resolution = self.block_resolutions[0] * 2 # do not use fp16 for the first block | |
if self.channel_dict is None: | |
channels_dict = {res: min(channel_base // res, self.channel_max) for res in self.block_resolutions} | |
else: | |
channels_dict = self.channel_dict | |
if self.out_channel_dict is not None: | |
img_channels = self.out_channel_dict | |
else: | |
img_channels = {res: self.img_channels for res in self.block_resolutions} | |
for ir, res in enumerate(self.block_resolutions): | |
res_before = self.block_resolutions[ir-1] if ir > 0 else self.in_res | |
in_channels = channels_dict[res_before] if ir > 0 else input_dim | |
out_channels = channels_dict[res] | |
use_fp16 = (res >= fp16_resolution) # TRY False | |
is_last = (ir == (len(self.block_resolutions) - 1)) | |
no_upsample = (res == res_before) | |
block = util.construct_class_by_name( | |
class_name=block_kwargs.get('block_name', "training.networks.SynthesisBlock"), | |
in_channels=in_channels, | |
out_channels=out_channels, | |
w_dim=w_dim, | |
resolution=res, | |
img_channels=img_channels[res], | |
is_last=is_last, | |
use_fp16=use_fp16, | |
disable_upsample=no_upsample, | |
block_id=ir, | |
**block_kwargs) | |
upsamplers += [{ | |
'block': block, | |
'num_ws': block.num_conv if not is_last else block.num_conv + block.num_torgb, | |
'name': f'b{res}' if res_before != res else f'b{res}_l{ir}' | |
}] | |
self.num_ws = sum([u['num_ws'] for u in upsamplers]) | |
return upsamplers | |
def forward_ws_split(self, ws, blocks): | |
block_ws, w_idx = [], 0 | |
for ir, res in enumerate(self.block_resolutions): | |
block = blocks[ir] | |
if self.shared_rgb_style: | |
w = ws.narrow(1, w_idx, block.num_conv) | |
w_img = ws.narrow(1, -block.num_torgb, block.num_torgb) # TODO: tRGB to use the same style (?) | |
block_ws.append(torch.cat([w, w_img], 1)) | |
else: | |
block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) | |
w_idx += block.num_conv | |
return block_ws | |
def forward_network(self, blocks, block_ws, x, img, target_res, alpha, skip_up=False, **block_kwargs): | |
imgs = [] | |
for index_l, (res, cur_ws) in enumerate(zip(self.block_resolutions, block_ws)): | |
if res > target_res: | |
break | |
block = blocks[index_l] | |
block_noise = block_kwargs['voxel_noise'][index_l] if "voxel_noise" in block_kwargs else None | |
x, img = block( | |
x, | |
img if not self.no_residual_img else None, | |
cur_ws, | |
block_noise=block_noise, | |
skip_up=skip_up, | |
**block_kwargs) | |
imgs += [img] | |
return imgs | |
class NeRFInput(Upsampler): | |
""" Instead of positional encoding, it learns additional features for each points. | |
However, it is important to normalize the input points | |
""" | |
output_mode = 'none' | |
input_mode = 'random' # coordinates | |
architecture = 'skip' | |
# only useful for triplane/volume inputs | |
in_res = 4 | |
out_res = 256 | |
out_dim = 32 | |
sigma_dim = 8 | |
split_size = 64 | |
# only useful for hashtable inputs | |
hash_n_min = 16 | |
hash_n_max = 512 | |
hash_size = 16 | |
hash_level = 16 | |
hash_dim_in = 32 | |
hash_dim_mid = None | |
hash_dim_out = 2 | |
hash_n_layer = 4 | |
hash_mode = 'fast_hash' # grid_hash (like volumes) | |
keep_posenc = -1 | |
keep_nerf_latents = False | |
def build_network(self, w_dim, **block_kwargs): | |
# change global settings for input field. | |
kwargs_copy = copy.deepcopy(block_kwargs) | |
kwargs_copy['kernel_size'] = 3 | |
kwargs_copy['upsample_mode'] = 'default' | |
kwargs_copy['use_noise'] = True | |
kwargs_copy['architecture'] = self.architecture | |
self._flag = 0 | |
assert self.input_mode == 'random', \ | |
"currently only support normal StyleGAN2. in the future we may work on other inputs." | |
# plane-based inputs with modulated 2D convolutions | |
if self.output_mode == 'tri_plane_reshape': | |
self.img_channels, in_channels, const = 3 * self.out_dim, 0, None | |
elif self.output_mode == 'tri_plane_product': #TODO: sigma_dim is for density | |
self.img_channels, in_channels = 3 * (self.out_dim + self.sigma_dim), 0 | |
const = torch.nn.Parameter(0.1 * torch.randn([self.img_channels, self.out_res])) | |
elif self.output_mode == 'multi_planes': | |
self.img_channels, in_channels, const = self.out_dim * self.split_size, 0, None | |
kwargs_copy['architecture'] = 'orig' | |
# volume-based inputs with modulated 3D convolutions | |
elif self.output_mode == '3d_volume': # use 3D convolution to generate | |
kwargs_copy['architecture'] = 'orig' | |
kwargs_copy['mode'] = '3d' | |
self.img_channels, in_channels, const = self.out_dim, 0, None | |
elif self.output_mode == 'ms_volume': # multi-resolution voulume, between hashtable and volumes | |
kwargs_copy['architecture'] = 'orig' | |
kwargs_copy['mode'] = '3d' | |
self.img_channels, in_channels, const = self.out_dim, 0, None | |
# embedding-based inputs with modulated MLPs | |
elif self.output_mode == 'hash_table': | |
if self.hash_mode == 'grid_hash': | |
assert self.hash_size % 3 == 0, "needs to be 3D" | |
kwargs_copy['hash_size'], self._flag = 2 ** self.hash_size, 1 | |
assert self.hash_dim_out * self.hash_level == self.out_dim, "size must matched" | |
return self.build_modulated_embedding(w_dim, **kwargs_copy) | |
elif self.output_mode == 'ms_nerf_hash': | |
self.hash_mode, self._flag = 'grid_hash', 2 | |
ms_nerf = NeRFBlock({ | |
'rgb_out_dim': self.hash_dim_out * self.hash_level, # HACK | |
'magnitude_ema_beta': block_kwargs['magnitude_ema_beta'], | |
'no_sigma': True, 'predict_rgb': False, 'add_rgb': False, | |
'n_freq_posenc': 5, | |
}) | |
self.num_ws = ms_nerf.num_ws | |
return [{'block': ms_nerf, 'num_ws': ms_nerf.num_ws, 'name': 'ms_nerf'}] | |
else: | |
raise NotImplementedError | |
networks = super().build_network(w_dim, in_channels, **kwargs_copy) | |
if const is not None: | |
networks.append({'block': const, 'num_ws': 0, 'name': 'const'}) | |
return networks | |
def forward_ws_split(self, ws, blocks): | |
if self._flag == 1: | |
return ws.split(1, dim=1)[:len(blocks)-1] | |
elif self._flag == 0: | |
return super().forward_ws_split(ws, blocks) | |
else: | |
return ws # do not split | |
def forward_network(self, blocks, block_ws, batch_size, **block_kwargs): | |
x, img, out = None, None, None | |
def _forward_conv_networks(x, img, blocks, block_ws): | |
for index_l, (res, cur_ws) in enumerate(zip(self.block_resolutions, block_ws)): | |
x, img = blocks[index_l](x, img, cur_ws, **block_kwargs) | |
return img | |
def _forward_ffn_networks(x, blocks, block_ws): | |
#TODO: FFN is implemented as 1x1 conv for now # | |
h, w = dividable(x.size(0)) | |
x = repeat(x, 'n d -> b n d', b=batch_size) | |
x = rearrange(x, 'b (h w) d -> b d h w', h=h, w=w) | |
for index_l, cur_ws in enumerate(block_ws): | |
block, cur_ws = blocks[index_l], cur_ws[:, 0] | |
x = block(x, cur_ws) | |
return x | |
# tri-plane outputs | |
if 'tri_plane' in self.output_mode: | |
img = _forward_conv_networks(x, img, blocks, block_ws) | |
if self.output_mode == 'tri_plane_reshape': | |
out = ('tri_plane', rearrange(img, 'b (s c) h w -> b s c h w', s=3)) | |
elif self.output_mode == 'tri_plane_product': | |
out = ('tri_plane', rearrange(img, 'b (s c) h w -> b s c h w', s=3), | |
'tri_vector', repeat(rearrange(blocks[-1], '(s c) d -> s c d', s=3), 's c d -> b s c d', b=img.size(0))) | |
else: | |
raise NotImplementedError("remove support for other types of tri-plane implementation.") | |
# volume/3d voxel outputs | |
elif self.output_mode == 'multi_planes': | |
img = _forward_conv_networks(x, img, blocks, block_ws) | |
out = ('volume', rearrange(img, 'b (s c) h w -> b s c h w', s=self.out_dim)) | |
elif self.output_mode == '3d_volume': | |
img = _forward_conv_networks(x, img, blocks, block_ws) | |
out = ('volume', img) | |
# multi-resolution 3d volume outputs (similar to hash-table) | |
elif self.output_mode == 'ms_volume': | |
img = _forward_conv_networks(x, img, blocks, block_ws) | |
out = ('ms_volume', rearrange(img, 'b (l m) d h w -> b l m d h w', l=self.hash_level)) | |
# hash-table outputs (need hash sample implemented #TODO# | |
elif self.output_mode == 'hash_table': | |
x, blocks = blocks[-1], blocks[:-1] | |
if len(blocks) > 0: | |
x = _forward_ffn_networks(x, blocks, block_ws) | |
out = ('hash_table', rearrange(x, 'b d h w -> b (h w) d')) | |
else: | |
out = ('hash_table', repeat(x, 'n d -> b n d', b=batch_size)) | |
elif self.output_mode == 'ms_nerf_hash': | |
# prepare inputs for nerf | |
x = torch.linspace(-1, 1, steps=self.out_res, device=block_ws.device) | |
x = torch.stack(torch.meshgrid(x,x,x), -1).reshape(-1, 3) | |
x = repeat(x, 'n s -> b n s', b=block_ws.size(0)) | |
x = blocks[0](x, None, ws=block_ws, shape=[block_ws.size(0), 32, 32, 32])[0] | |
x = rearrange(x, 'b (d h w) (l m) -> b l m d h w', l=self.hash_level, d=32, h=32, w=32) | |
out = ('ms_volume', x) | |
else: | |
raise NotImplementedError | |
return out | |
def query_input_features(self, p_i, input_feats, p_shape, bound, grad_inputs=False): | |
batch_size, height, width, n_steps = p_shape | |
p_i = p_i / bound | |
if input_feats[0] == 'tri_plane': | |
# TODO!! Our world space, x->depth, y->width, z->height | |
lh, lw = dividable(n_steps) | |
p_ds = rearrange(p_i, 'b (h w l m) d -> b (l h) (m w) d', | |
b=batch_size, h=height, w=width, l=lh, m=lw).split(1, dim=-1) | |
px, py, pz = p_ds[0], p_ds[1], p_ds[2] | |
# project points onto three planes | |
p_xy = torch.cat([px, py], -1) | |
p_xz = torch.cat([px, pz], -1) | |
p_yz = torch.cat([py, pz], -1) | |
p_gs = torch.cat([p_xy, p_xz, p_yz], 0) | |
f_in = torch.cat([input_feats[1][:, i] for i in range(3)], 0) | |
p_f = grid_sample(f_in, p_gs) # gradient-fix bilinear interpolation | |
p_f = [p_f[i * batch_size: (i+1) * batch_size] for i in range(3)] | |
# project points to three vectors (optional) | |
if len(input_feats) == 4 and input_feats[2] == 'tri_vector': | |
# TODO: PyTorch did not support grid_sample for 1D data. Maybe need custom code. | |
p_gs_vec = torch.cat([pz, py, px], 0) | |
f_in_vec = torch.cat([input_feats[3][:, i] for i in range(3)], 0) | |
p_f_vec = grid_sample(f_in_vec.unsqueeze(-1), torch.cat([torch.zeros_like(p_gs_vec), p_gs_vec], -1)) | |
p_f_vec = [p_f_vec[i * batch_size: (i+1) * batch_size] for i in range(3)] | |
# multiply on the triplane features | |
p_f = [m * v for m, v in zip(p_f, p_f_vec)] | |
p_f = sum(p_f) | |
p_f = rearrange(p_f, 'b d (l h) (m w) -> b (h w l m) d', l=lh, m=lw) | |
elif input_feats[0] == 'volume': | |
# TODO!! Our world space, x->depth, y->width, z->height | |
# (width-c, height-c, depth-c), volume (B x N x D x H x W) | |
p_ds = rearrange(p_i, 'b (h w s) d -> b s h w d', | |
b=batch_size, h=height, w=width, s=n_steps).split(1, dim=-1) | |
px, py, pz = p_ds[0], p_ds[1], p_ds[2] | |
p_yzx = torch.cat([py, -pz, px], -1) | |
p_f = F.grid_sample(input_feats[1], p_yzx, mode='bilinear', align_corners=False) | |
p_f = rearrange(p_f, 'b c s h w -> b (h w s) c') | |
elif input_feats[0] == 'ms_volume': | |
# TODO!! Multi-resolution volumes (experimental) | |
# for smoothness, maybe we should expand the volume? (TODO) | |
# print(p_i.shape) | |
ms_v = input_feats[1].new_zeros( | |
batch_size, self.hash_level, self.hash_dim_out, self.out_res+1, self.out_res+1, self.out_res+1) | |
ms_v[..., 1:, 1:, 1:] = input_feats[1].flip([3,4,5]) | |
ms_v[..., :self.out_res, :self.out_res, :self.out_res] = input_feats[1] | |
v_size = ms_v.size(-1) | |
# multi-resolutions | |
b = math.exp((math.log(self.hash_n_max) - math.log(self.hash_n_min))/(self.hash_level-1)) | |
hash_res_ls = [round(self.hash_n_min * b ** l) for l in range(self.hash_level)] | |
# prepare interpolate grids | |
p_ds = rearrange(p_i, 'b (h w s) d -> b s h w d', | |
b=batch_size, h=height, w=width, s=n_steps).split(1, dim=-1) | |
px, py, pz = p_ds[0], p_ds[1], p_ds[2] | |
p_yzx = torch.cat([py, -pz, px], -1) | |
p_yzx = ((p_yzx + 1) / 2).clamp(min=0, max=1) # normalize to 0~1 (just to be safe)
p_yzx = torch.stack([p_yzx if n < v_size else torch.fmod(p_yzx * n, v_size) / v_size for n in hash_res_ls], 1) | |
p_yzx = (p_yzx * 2 - 1).view(-1, n_steps, height, width, 3) # back to -1~1
ms_v = ms_v.view(-1, self.hash_dim_out, v_size, v_size, v_size)
p_f = F.grid_sample(ms_v, p_yzx, mode='bilinear', align_corners=False) | |
p_f = rearrange(p_f, '(b l) c s h w -> b (h w s) (l c)', l=self.hash_level) | |
elif input_feats[0] == 'hash_table': | |
# TODO!! Experimental code trying to learn the hash table used in Instant-NGP (maybe buggy)
# https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf | |
p_xyz = ((p_i + 1) / 2).clamp(min=0, max=1) # normalize to 0~1 | |
p_f = hash_sample( | |
p_xyz, input_feats[1], self.offsets.to(p_xyz.device), | |
self.beta, self.hash_n_min, grad_inputs, mode=self.hash_mode) | |
else: | |
raise NotImplementedError | |
if self.keep_posenc > -1: | |
if self.keep_posenc > 0: | |
p_f = torch.cat([p_f, positional_encoding(p_i, self.keep_posenc, use_pos=True)], -1) | |
else: | |
p_f = torch.cat([p_f, p_i], -1) | |
return p_f | |
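# A minimal, self-contained sketch of the 'tri_plane' lookup above, with made-up
# tensor sizes (not part of the model): points in [-1, 1]^3 are projected onto the
# xy/xz/yz planes, each plane is sampled bilinearly (plain F.grid_sample here, not
# the gradient-fixed version used above), and the three samples are summed.
def _example_triplane_query():
    B, C, R, N = 2, 8, 32, 100                 # batch, channels, plane resolution, #points
    planes = torch.randn(B, 3, C, R, R)        # three feature planes per sample
    p = torch.rand(B, N, 3) * 2 - 1            # query points in [-1, 1]^3
    px, py, pz = p.split(1, dim=-1)
    grids = [torch.cat([px, py], -1), torch.cat([px, pz], -1), torch.cat([py, pz], -1)]
    feats = 0
    for i, g in enumerate(grids):
        # F.grid_sample expects a (B, H, W, 2) grid; use a 1 x N "image" of points
        f = F.grid_sample(planes[:, i], g.unsqueeze(1), mode='bilinear', align_corners=False)
        feats = feats + f.squeeze(2).transpose(1, 2)   # B x N x C
    return feats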
def build_hashtable_info(self, hash_size): | |
self.beta = math.exp((math.log(self.hash_n_max) - math.log(self.hash_n_min)) / (self.hash_level-1)) | |
self.hash_res_ls = [round(self.hash_n_min * self.beta ** l) for l in range(self.hash_level)] | |
offsets, offset = [], 0 | |
for i in range(self.hash_level): | |
resolution = self.hash_res_ls[i] | |
params_in_level = min(hash_size, (resolution + 1) ** 3) | |
offsets.append(offset) | |
offset += params_in_level | |
offsets.append(offset) | |
self.offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) | |
return offset | |
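# Worked example of the per-level hash resolutions computed above (values are
# illustrative, not the defaults used elsewhere in this file): with hash_n_min=16,
# hash_n_max=256 and hash_level=4, beta = exp((ln 256 - ln 16) / 3) ≈ 2.52, giving
# resolutions [16, 40, 102, 256]; each level stores at most
# min(hash_size, (res + 1) ** 3) entries and `offsets` holds the running sum.
def _example_hash_levels(n_min=16, n_max=256, levels=4, hash_size=2 ** 14):
    beta = math.exp((math.log(n_max) - math.log(n_min)) / (levels - 1))
    res_ls = [round(n_min * beta ** l) for l in range(levels)]
    offsets, offset = [], 0
    for r in res_ls:
        offsets.append(offset)
        offset += min(hash_size, (r + 1) ** 3)
    offsets.append(offset)
    return res_ls, offsets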
def build_modulated_embedding(self, w_dim, hash_size, **block_kwargs): | |
# allocate parameters | |
offset = self.build_hashtable_info(hash_size) | |
hash_const = torch.nn.Parameter(torch.zeros( | |
[offset, self.hash_dim_in if self.hash_n_layer > -1 else self.hash_dim_out])) | |
hash_const.data.uniform_(-1e-4, 1e-4) | |
hash_networks = [] | |
if self.hash_n_layer > -1: | |
input_dim = self.hash_dim_in | |
for l in range(self.hash_n_layer): | |
output_dim = self.hash_dim_mid if self.hash_dim_mid is not None else self.hash_dim_in | |
hash_networks.append({ | |
'block': Style2Layer(input_dim, output_dim, w_dim), | |
'num_ws': 1, 'name': f'hmlp{l}' | |
}) | |
input_dim = output_dim | |
hash_networks.append({ | |
'block': ToRGBLayer(input_dim, self.hash_dim_out, w_dim, kernel_size=1), | |
'num_ws': 1, 'name': 'hmlpout'}) | |
hash_networks.append({'block': hash_const, 'num_ws': 0, 'name': 'hash_const'}) | |
self.num_ws = sum([h['num_ws'] for h in hash_networks]) | |
return hash_networks | |
class NeRFSynthesisNetwork(torch.nn.Module): | |
def __init__(self, | |
w_dim, # Intermediate latent (W) dimensionality. | |
img_resolution, # Output image resolution. | |
img_channels, # Number of color channels. | |
channel_base = 1, | |
channel_max = 1024, | |
# module settings | |
camera_kwargs = {}, | |
renderer_kwargs = {}, | |
upsampler_kwargs = {}, | |
input_kwargs = {}, | |
foreground_kwargs = {}, | |
background_kwargs = {}, | |
# nerf space settings | |
z_dim = 256, | |
z_dim_bg = 128, | |
rgb_out_dim = 256, | |
rgb_out_dim_bg = None, | |
resolution_vol = 32, | |
resolution_start = None, | |
progressive = True, | |
prog_nerf_only = False, | |
interp_steps = None, # (optional) "start_step:final_step" | |
# others (regularization) | |
regularization = [], # nv_beta, nv_vol | |
predict_camera = False, | |
camera_condition = None, | |
n_reg_samples = 0, | |
reg_full = False, | |
cam_based_sampler = False, | |
rectangular = None, | |
freeze_nerf = False, | |
**block_kwargs, # Other arguments for SynthesisBlock. | |
): | |
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 | |
super().__init__() | |
# dimensions | |
self.w_dim = w_dim | |
self.z_dim = z_dim | |
self.z_dim_bg = z_dim_bg | |
self.num_ws = 0 | |
self.rgb_out_dim = rgb_out_dim | |
self.rgb_out_dim_bg = rgb_out_dim_bg if rgb_out_dim_bg is not None else rgb_out_dim | |
self.img_resolution = img_resolution | |
self.resolution_vol = resolution_vol if resolution_vol < img_resolution else img_resolution | |
self.resolution_start = resolution_start if resolution_start is not None else resolution_vol | |
self.img_resolution_log2 = int(np.log2(img_resolution)) | |
self.img_channels = img_channels | |
# number of samples | |
self.n_reg_samples = n_reg_samples | |
self.reg_full = reg_full | |
self.use_noise = block_kwargs.get('use_noise', False) | |
# ---------------------------------- Initialize Modules ----------------------------------------- #
# camera module | |
self.C = CameraRay(camera_kwargs, **block_kwargs) | |
# input encoding module | |
if (len(input_kwargs) > 0) and (input_kwargs['output_mode'] != 'none'): # using synthesized inputs
input_kwargs['channel_base'] = input_kwargs.get('channel_base', channel_base) | |
input_kwargs['channel_max'] = input_kwargs.get('channel_max', channel_max) | |
self.I = NeRFInput(input_kwargs, **block_kwargs) | |
else: | |
self.I = None | |
# volume renderer module | |
self.V = VolumeRenderer(renderer_kwargs, camera_ray=self.C, input_encoding=self.I, **block_kwargs) | |
# upsampler module | |
upsampler_kwargs.update(dict( | |
img_channels=img_channels, | |
in_res=resolution_vol, | |
out_res=img_resolution, | |
channel_max=channel_max, | |
channel_base=channel_base)) | |
self.U = Upsampler(upsampler_kwargs, **block_kwargs) | |
# full model resolutions | |
self.block_resolutions = copy.deepcopy(self.U.block_resolutions) | |
if self.resolution_start < self.resolution_vol: | |
r = self.resolution_vol | |
while r > self.resolution_start: | |
self.block_resolutions.insert(0, r) | |
r = r // 2 | |
self.predict_camera = predict_camera | |
if predict_camera: # encoder side camera predictor (not very useful) | |
self.camera_generator = CameraGenerator() | |
self.camera_condition = camera_condition | |
if self.camera_condition is not None: # style vector modulated by the camera poses (uv) | |
self.camera_map = MappingNetwork(z_dim=0, c_dim=16, w_dim=self.w_dim, num_ws=None, w_avg_beta=None, num_layers=2) | |
# ray level choices | |
self.regularization = regularization | |
self.margin = block_kwargs.get('margin', 0) | |
self.activation = block_kwargs.get('activation', 'lrelu') | |
self.rectangular_crop = rectangular # e.g. [384, 512]
# nerf (foreground/background)
foreground_kwargs.update(dict( | |
z_dim=self.z_dim, | |
w_dim=w_dim, | |
rgb_out_dim=self.rgb_out_dim, | |
activation=self.activation)) | |
# disable positional encoding if input encoding is given | |
if self.I is not None: | |
foreground_kwargs.update(dict( | |
disable_latents=(not self.I.keep_nerf_latents), | |
input_dim=self.I.out_dim + 3 * (2 * self.I.keep_posenc + 1) | |
if self.I.keep_posenc > -1 else self.I.out_dim, | |
positional_encoding='none')) | |
self.fg_nerf = NeRFBlock(foreground_kwargs) | |
self.num_ws += self.fg_nerf.num_ws | |
if not self.V.no_background: | |
background_kwargs.update(dict( | |
z_dim=self.z_dim_bg, w_dim=w_dim, | |
rgb_out_dim=self.rgb_out_dim_bg, | |
activation=self.activation)) | |
self.bg_nerf = NeRFBlock(background_kwargs) | |
self.num_ws += self.bg_nerf.num_ws | |
else: | |
self.bg_nerf = None | |
# ---------------------------------- Build Networks ------------------------------------------ #
# input encoding (optional) | |
if self.I is not None: | |
assert self.V.no_background, "does not support background field" | |
nerf_inputs = self.I.build_network(w_dim, **block_kwargs) | |
self.input_block_names = ['in_' + i['name'] for i in nerf_inputs] | |
self.num_ws += sum([i['num_ws'] for i in nerf_inputs]) | |
for i in nerf_inputs: | |
setattr(self, 'in_' + i['name'], i['block']) | |
# upsampler | |
upsamplers = self.U.build_network(w_dim, self.fg_nerf.rgb_out_dim, **block_kwargs) | |
if len(upsamplers) > 0: | |
self.block_names = [u['name'] for u in upsamplers] | |
self.num_ws += sum([u['num_ws'] for u in upsamplers]) | |
for u in upsamplers: | |
setattr(self, u['name'], u['block']) | |
# data-sampler | |
if cam_based_sampler: | |
self.sampler = (CameraQueriedSampler, {'camera_module': self.C}) | |
# other hyperparameters
self.progressive_growing = progressive | |
self.progressive_nerf_only = prog_nerf_only | |
assert not (self.progressive_growing and self.progressive_nerf_only) | |
if prog_nerf_only: | |
assert (self.n_reg_samples == 0) and (not reg_full), "does not support regularization" | |
self.register_buffer("alpha", torch.scalar_tensor(-1)) | |
if predict_camera: | |
self.num_ws += 1 # additional w for camera | |
self.freeze_nerf = freeze_nerf | |
self.steps = None | |
self.interp_steps = [int(a) for a in interp_steps.split(':')] \ | |
if interp_steps is not None else None # TODO: two-stage training trick (from the EG3D paper, not working so far)
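# Illustration of how `block_resolutions` is extended in __init__ above (standalone
# sketch with made-up numbers): if the upsampler covers [64, 128, 256] and
# resolution_vol=32 with resolution_start=8, the loop prepends 32 and 16 so the
# progressive schedule starts from the NeRF resolution.
def _example_block_resolutions(upsampler_res=(64, 128, 256), resolution_vol=32, resolution_start=8):
    block_resolutions = list(upsampler_res)
    r = resolution_vol
    while r > resolution_start:
        block_resolutions.insert(0, r)
        r = r // 2
    return block_resolutions   # [16, 32, 64, 128, 256]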
def set_alpha(self, alpha): | |
if alpha is not None: | |
self.alpha.fill_(alpha) | |
def set_steps(self, steps): | |
if hasattr(self, "steps"): | |
if self.steps is not None: | |
self.steps = self.steps * 0 + steps / 1000.0 | |
else: | |
self.steps = steps / 1000.0 | |
def forward(self, ws, **block_kwargs): | |
block_ws, imgs, rand_imgs = [], [], [] | |
batch_size = block_kwargs['batch_size'] = ws.size(0) | |
n_levels, end_l, _, target_res = self.get_current_resolution() | |
# save ws for potential usage. | |
block_kwargs['ws_detach'] = ws.detach() | |
# cameras, background codes | |
if self.camera_condition is not None: | |
cam_cond = self.get_camera_samples(batch_size, ws, block_kwargs, gen_cond=True) | |
if "camera_matrices" not in block_kwargs: | |
block_kwargs['camera_matrices'] = self.get_camera_samples(batch_size, ws, block_kwargs) | |
if (self.camera_condition is not None) and (cam_cond is None): | |
cam_cond = block_kwargs['camera_matrices'] | |
block_kwargs['theta'] = self.C.get_roll(ws, self.training, **block_kwargs) | |
# get latent codes instead of style vectors (used in GRAF & GIRAFFE) | |
if "latent_codes" not in block_kwargs: | |
block_kwargs["latent_codes"] = self.get_latent_codes(batch_size, device=ws.device) | |
if (self.camera_condition is not None) and (self.camera_condition == 'full'): | |
cam_cond = normalize_2nd_moment(self.camera_map(None, cam_cond[1].reshape(-1, 16))) | |
ws = ws * cam_cond[:, None, :] | |
# generate features for input points (optional, not used by default)
with torch.autograd.profiler.record_function('nerf_input_feats'): | |
if self.I is not None: | |
ws = ws.to(torch.float32) | |
blocks = [getattr(self, name) for name in self.input_block_names] | |
block_ws = self.I.forward_ws_split(ws, blocks) | |
nerf_input_feats = self.I.forward_network(blocks, block_ws, **block_kwargs) | |
ws = ws[:, self.I.num_ws:] | |
else: | |
nerf_input_feats = None | |
# prepare for NeRF part | |
with torch.autograd.profiler.record_function('prepare_nerf_path'): | |
if self.progressive_nerf_only and (self.alpha > -1): | |
cur_resolution = int(self.resolution_start * (1 - self.alpha) + self.resolution_vol * self.alpha) | |
elif (end_l == 0) or len(self.block_resolutions) == 0: | |
cur_resolution = self.resolution_start | |
else: | |
cur_resolution = self.block_resolutions[end_l-1] | |
vol_resolution = self.resolution_vol if self.resolution_vol < cur_resolution else cur_resolution | |
nerf_resolution = vol_resolution | |
if (self.interp_steps is not None) and (self.steps is not None) and (self.alpha > 0): # interpolation trick (experimental)
if self.steps < self.interp_steps[0]: | |
nerf_resolution = vol_resolution // 2 | |
elif self.steps < self.interp_steps[1]: | |
nerf_resolution = (self.steps - self.interp_steps[0]) / (self.interp_steps[1] - self.interp_steps[0]) | |
nerf_resolution = int(nerf_resolution * (vol_resolution / 2) + vol_resolution / 2) | |
vol_pixels, tgt_pixels = self.C.prepare_pixels(self.img_resolution, cur_resolution, nerf_resolution, **block_kwargs) | |
if (end_l > 0) and (self.n_reg_samples > 0) and self.training: | |
rand_pixels, rand_indexs = self.C.prepare_pixels_regularization(tgt_pixels, self.n_reg_samples) | |
else: | |
rand_pixels, rand_indexs = None, None | |
if self.fg_nerf.num_ws > 0: # use style vector instead of latent codes? | |
block_kwargs["styles"] = ws[:, :self.fg_nerf.num_ws] | |
ws = ws[:, self.fg_nerf.num_ws:] | |
if (self.bg_nerf is not None) and self.bg_nerf.num_ws > 0: | |
block_kwargs["styles_bg"] = ws[:, :self.bg_nerf.num_ws] | |
ws = ws[:, self.bg_nerf.num_ws:] | |
# volume rendering | |
with torch.autograd.profiler.record_function('nerf'): | |
if (rand_pixels is not None) and self.training: | |
vol_pixels = (vol_pixels, rand_pixels) | |
outputs = self.V.forward_volume_rendering( | |
nerf_modules=(self.fg_nerf, self.bg_nerf), | |
vol_pixels=vol_pixels, | |
nerf_input_feats=nerf_input_feats, | |
return_full=self.reg_full, | |
alpha=self.alpha, | |
**block_kwargs) | |
reg_loss = outputs.get('reg_loss', {}) | |
x, img, _ = self.V.post_process_outputs(outputs['full_out'], self.freeze_nerf) | |
if nerf_resolution < vol_resolution: | |
x = F.interpolate(x, vol_resolution, mode='bilinear', align_corners=False) | |
img = F.interpolate(img, vol_resolution, mode='bilinear', align_corners=False) | |
# early output from the network (used for visualization) | |
if 'meshes' in block_kwargs: | |
from dnnlib.geometry import render_mesh | |
block_kwargs['voxel_noise'] = render_mesh(block_kwargs['meshes'], block_kwargs["camera_matrices"]) | |
if (len(self.U.block_resolutions) == 0) or \ | |
(x is None) or \ | |
(block_kwargs.get("render_option", None) is not None and | |
'early' in block_kwargs['render_option']): | |
if block_kwargs.get('render_option') is not None and 'value' in block_kwargs['render_option']:
img = x[:,:3] | |
img = img / img.norm(dim=1, keepdim=True) | |
assert img is not None, "need to add RGB" | |
return img | |
if 'rand_out' in outputs: | |
x_rand, img_rand, rand_probs = self.V.post_process_outputs(outputs['rand_out'], self.freeze_nerf) | |
lh, lw = dividable(rand_probs.size(1)) | |
rand_imgs += [img_rand] | |
# append low-resolution image | |
if img is not None: | |
if self.progressive_nerf_only and (img.size(-1) < self.resolution_vol): | |
x = upsample(x, self.resolution_vol) | |
img = upsample(img, self.resolution_vol) | |
block_kwargs['img_nerf'] = img | |
# Use 2D upsampler | |
if (cur_resolution > self.resolution_vol) or self.progressive_nerf_only: | |
imgs += [img] | |
if (self.camera_condition is not None) and (self.camera_condition != 'full'): | |
cam_cond = normalize_2nd_moment(self.camera_map(None, cam_cond[1].reshape(-1, 16))) | |
ws = ws * cam_cond[:, None, :] | |
# 2D feature map upsampling | |
with torch.autograd.profiler.record_function('upsampling'): | |
ws = ws.to(torch.float32) | |
blocks = [getattr(self, name) for name in self.block_names] | |
block_ws = self.U.forward_ws_split(ws, blocks) | |
imgs += self.U.forward_network(blocks, block_ws, x, img, target_res, self.alpha, **block_kwargs) | |
img = imgs[-1] | |
if len(rand_imgs) > 0: # nerf path regularization | |
rand_imgs += self.U.forward_network( | |
blocks, block_ws, x_rand, img_rand, target_res, self.alpha, skip_up=True, **block_kwargs) | |
img_rand = rand_imgs[-1] | |
with torch.autograd.profiler.record_function('rgb_interp'): | |
if (self.alpha > -1) and (not self.progressive_nerf_only) and self.progressive_growing: | |
if (self.alpha < 1) and (self.alpha > 0): | |
alpha, _ = math.modf(self.alpha * n_levels) | |
img_nerf = imgs[-2] | |
if img_nerf.size(-1) < img.size(-1): # need upsample image | |
img_nerf = upsample(img_nerf, 2 * img_nerf.size(-1)) | |
img = img_nerf * (1 - alpha) + img * alpha | |
if len(rand_imgs) > 0: | |
img_rand = rand_imgs[-2] * (1 - alpha) + img_rand * alpha | |
with torch.autograd.profiler.record_function('nerf_path_reg_loss'): | |
if len(rand_imgs) > 0: # and self.training: # random pixel regularization?? | |
assert self.progressive_growing | |
if self.reg_full: # aggregate RGB in the end. | |
lh, lw = img_rand.size(2) // self.n_reg_samples, img_rand.size(3) // self.n_reg_samples | |
img_rand = rearrange(img_rand, 'b d (l h) (m w) -> b d (l m) h w', l=lh, m=lw) | |
img_rand = (img_rand * rand_probs[:, None]).sum(2) | |
if self.V.white_background: | |
img_rand = img_rand + (1 - rand_probs.sum(1, keepdim=True)) | |
rand_indexs = repeat(rand_indexs, 'b n -> b d n', d=img_rand.size(1)) | |
img_ff = rearrange(rearrange(img, 'b d l h -> b d (l h)').gather(2, rand_indexs), 'b d (l h) -> b d l h', l=self.n_reg_samples) | |
def l2(img_ff, img_nf): | |
batch_size = img_nf.size(0) | |
return ((img_ff - img_nf) ** 2).sum(1).reshape(batch_size, -1).mean(-1, keepdim=True) | |
reg_loss['reg_loss'] = l2(img_ff, img_rand) * 2.0 | |
if len(reg_loss) > 0: | |
for key in reg_loss: | |
block_kwargs[key] = reg_loss[key] | |
if self.rectangular_crop is not None: # crop to a rectangular aspect ratio if requested
h, w = self.rectangular_crop | |
c = int(img.size(-1) * (1 - h / w) / 2) | |
mask = torch.ones_like(img) | |
mask[:, :, c:-c, :] = 0 | |
img = img.masked_fill(mask > 0, -1) | |
block_kwargs['img'] = img | |
return block_kwargs | |
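# Worked example of the progressive RGB blending near the end of forward()
# (illustrative numbers only): with n_levels=4 and a global alpha of 0.6,
# math.modf(0.6 * 4) gives a fractional part of ~0.4, so the output becomes
# img_nerf * (1 - 0.4) + img * 0.4, i.e. the newly added resolution is faded in.
def _example_progressive_blend(global_alpha=0.6, n_levels=4):
    local_alpha, level = math.modf(global_alpha * n_levels)
    return level, local_alpha   # ~ (2.0, 0.4)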
def get_current_resolution(self): | |
n_levels = len(self.block_resolutions) | |
if not self.progressive_growing: | |
end_l = n_levels | |
elif (self.alpha > -1) and (not self.progressive_nerf_only): | |
if self.alpha == 0: | |
end_l = 0 | |
elif self.alpha == 1: | |
end_l = n_levels | |
elif self.alpha < 1: | |
end_l = int(math.modf(self.alpha * n_levels)[1] + 1) | |
else: | |
end_l = n_levels | |
target_res = self.resolution_start if end_l <= 0 else self.block_resolutions[end_l-1] | |
before_res = self.resolution_start if end_l <= 1 else self.block_resolutions[end_l-2] | |
return n_levels, end_l, before_res, target_res | |
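# Sketch of the schedule above (illustrative numbers, standalone): with
# block_resolutions [16, 32, 64, 128] and alpha=0.6, modf(0.6 * 4) has integer
# part 2, so end_l=3, target_res=64 and before_res=32.
def _example_resolution_schedule(alpha=0.6, block_resolutions=(16, 32, 64, 128), resolution_start=16):
    n_levels = len(block_resolutions)
    if alpha == 0:
        end_l = 0
    elif alpha == 1:
        end_l = n_levels
    else:
        end_l = int(math.modf(alpha * n_levels)[1] + 1)
    target_res = resolution_start if end_l <= 0 else block_resolutions[end_l - 1]
    before_res = resolution_start if end_l <= 1 else block_resolutions[end_l - 2]
    return n_levels, end_l, before_res, target_res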
def get_latent_codes(self, batch_size=32, device="cpu", tmp=1.): | |
z_dim, z_dim_bg = self.z_dim, self.z_dim_bg | |
def sample_z(*size):
return torch.randn(*size).to(device) * tmp
z_shape_obj = sample_z(batch_size, z_dim) | |
z_app_obj = sample_z(batch_size, z_dim) | |
z_shape_bg = sample_z(batch_size, z_dim_bg) if not self.V.no_background else None | |
z_app_bg = sample_z(batch_size, z_dim_bg) if not self.V.no_background else None | |
return z_shape_obj, z_app_obj, z_shape_bg, z_app_bg | |
def get_camera(self, *args, **kwargs): # for compatibility
return self.C.get_camera(*args, **kwargs) | |
def get_camera_samples(self, batch_size, ws, block_kwargs, gen_cond=False): | |
if gen_cond: # camera condition for the generator (a special variant)
if ('camera_matrices' in block_kwargs) and (not self.training): # this is for rendering | |
camera_matrices = self.get_camera(batch_size, device=ws.device, mode=[0.5, 0.5, 0.5]) | |
elif self.training and (np.random.rand() > 0.5): | |
camera_matrices = self.get_camera(batch_size, device=ws.device) | |
else: | |
camera_matrices = None | |
elif 'camera_mode' in block_kwargs: | |
camera_matrices = self.get_camera(batch_size, device=ws.device, mode=block_kwargs["camera_mode"]) | |
else: | |
if self.predict_camera: | |
rand_mode = ws.new_zeros(ws.size(0), 2) | |
if self.C.gaussian_camera: | |
rand_mode = rand_mode.normal_() | |
pred_mode = self.camera_generator(rand_mode) | |
else: | |
rand_mode = rand_mode.uniform_() | |
pred_mode = self.camera_generator(rand_mode - 0.5) | |
mode = rand_mode if self.alpha <= 0 else rand_mode + pred_mode * 0.1 | |
camera_matrices = self.get_camera(batch_size, device=ws.device, mode=mode) | |
else: | |
camera_matrices = self.get_camera(batch_size, device=ws.device) | |
if ('camera_RT' in block_kwargs) or ('camera_UV' in block_kwargs): | |
camera_matrices = list(camera_matrices) | |
camera_mask = torch.rand(batch_size).type_as(camera_matrices[1]).lt(self.alpha) | |
if 'camera_RT' in block_kwargs: | |
image_RT = block_kwargs['camera_RT'].reshape(-1, 4, 4) | |
camera_matrices[1][camera_mask] = image_RT[camera_mask] # replacing with inferred cameras | |
else: # sample uv instead of sampling the extrinsic matrix | |
image_UV = block_kwargs['camera_UV'] | |
image_RT = self.get_camera(batch_size, device=ws.device, mode=image_UV, force_uniform=True)[1] | |
camera_matrices[1][camera_mask] = image_RT[camera_mask] # replacing with inferred cameras | |
camera_matrices[2][camera_mask] = image_UV[camera_mask] # replacing with inferred uvs | |
camera_matrices = tuple(camera_matrices) | |
return camera_matrices | |
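# Standalone sketch of the camera-mixing step in get_camera_samples (made-up
# tensors): with probability alpha, a randomly sampled extrinsic matrix is
# replaced by the one inferred from the image, so predicted cameras are phased
# in gradually during training.
def _example_camera_mixing(alpha=0.5, batch_size=4):
    sampled_RT = torch.eye(4).repeat(batch_size, 1, 1)
    inferred_RT = torch.randn(batch_size, 4, 4)
    mask = torch.rand(batch_size) < alpha
    sampled_RT[mask] = inferred_RT[mask]
    return sampled_RT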
class Discriminator(torch.nn.Module): | |
def __init__(self, | |
c_dim, # Conditioning label (C) dimensionality. | |
img_resolution, # Input resolution. | |
img_channels, # Number of input color channels. | |
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. | |
channel_base = 1, # Overall multiplier for the number of channels. | |
channel_max = 512, # Maximum number of channels in any layer. | |
num_fp16_res = 0, # Use FP16 for the N highest resolutions. | |
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. | |
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. | |
lowres_head = None, # add a low-resolution discriminator head | |
dual_discriminator = False, # add low-resolution (NeRF) image | |
block_kwargs = {}, # Arguments for DiscriminatorBlock. | |
mapping_kwargs = {}, # Arguments for MappingNetwork. | |
epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. | |
camera_kwargs = {}, # Arguments for Camera predictor and condition (optional, refactoring) | |
upsample_type = 'default', | |
progressive = False, | |
resize_real_early = False, # Perform resizing before the training loop
enable_ema = False, # Additionally save an EMA checkpoint | |
**unused | |
): | |
super().__init__() | |
# setup parameters | |
self.img_resolution = img_resolution | |
self.img_resolution_log2 = int(np.log2(img_resolution)) | |
self.img_channels = img_channels | |
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] | |
self.architecture = architecture | |
self.lowres_head = lowres_head | |
self.dual_discriminator = dual_discriminator | |
self.upsample_type = upsample_type | |
self.progressive = progressive | |
self.resize_real_early = resize_real_early | |
self.enable_ema = enable_ema | |
if self.progressive: | |
assert self.architecture == 'skip', "not supporting other types for now." | |
channel_base = int(channel_base * 32768) | |
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} | |
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) | |
# camera prediction module | |
self.camera_kwargs = EasyDict( | |
predict_camera=False, | |
predict_styles=False, | |
camera_type='3d', | |
camera_encoder=True, | |
camera_encoder_progressive=False, | |
camera_disc=True) | |
## ------ for compatibility ------- #
self.camera_kwargs.predict_camera = unused.get('predict_camera', False) | |
self.camera_kwargs.camera_type = '9d' if unused.get('predict_9d_camera', False) else '3d' | |
self.camera_kwargs.camera_disc = not unused.get('no_camera_condition', False) | |
self.camera_kwargs.camera_encoder = unused.get('saperate_camera', False) | |
self.camera_kwargs.update(camera_kwargs) | |
## ------ for compatibility ------- #
self.c_dim = c_dim | |
if self.camera_kwargs.predict_camera: | |
if self.camera_kwargs.camera_type == '3d': | |
self.c_dim = out_dim = 3 # (u, v) on the sphere | |
elif self.camera_kwargs.camera_type == '9d': | |
self.c_dim, out_dim = 16, 9 | |
elif self.camera_kwargs.camera_type == '16d': | |
self.c_dim = out_dim = 16 | |
else: | |
raise NotImplementedError('Wrong camera type') | |
if not self.camera_kwargs.camera_disc: | |
self.c_dim = c_dim | |
self.projector = EqualConv2d(channels_dict[4], out_dim, 4, padding=0, bias=False) | |
if cmap_dim is None: | |
cmap_dim = channels_dict[4] | |
if self.c_dim == 0: | |
cmap_dim = 0 | |
if self.c_dim > 0: | |
self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) | |
if self.camera_kwargs.predict_styles: | |
self.w_dim, self.num_ws = self.camera_kwargs.w_dim, self.camera_kwargs.num_ws | |
self.projector_styles = EqualConv2d(channels_dict[4], self.w_dim * self.num_ws, 4, padding=0, bias=False) | |
self.mapping_styles = MappingNetwork(z_dim=0, c_dim=self.w_dim * self.num_ws, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) | |
# main discriminator blocks | |
common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp) | |
def build_blocks(layer_name='b', low_resolution=False): | |
cur_layer_idx = 0 | |
block_resolutions = self.block_resolutions | |
if low_resolution: | |
block_resolutions = [r for r in self.block_resolutions if r <= self.lowres_head] | |
for res in block_resolutions: | |
in_channels = channels_dict[res] if res < img_resolution else 0 | |
tmp_channels = channels_dict[res] | |
out_channels = channels_dict[res // 2] | |
use_fp16 = (res >= fp16_resolution) | |
block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, | |
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) | |
setattr(self, f'{layer_name}{res}', block) | |
cur_layer_idx += block.num_layers | |
build_blocks(layer_name='b') # main blocks | |
if self.dual_discriminator: | |
build_blocks(layer_name='dual', low_resolution=True) | |
if self.camera_kwargs.camera_encoder: | |
build_blocks(layer_name='c', low_resolution=(not self.camera_kwargs.camera_encoder_progressive)) | |
# final output module | |
self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) | |
self.register_buffer("alpha", torch.scalar_tensor(-1)) | |
def set_alpha(self, alpha): | |
if alpha is not None: | |
self.alpha = self.alpha * 0 + alpha | |
def set_resolution(self, res): | |
self.curr_status = res | |
def forward_blocks_progressive(self, img, mode="disc", **block_kwargs): | |
# mode from ['disc', 'dual_disc', 'cam_enc'] | |
if isinstance(img, dict): | |
img = img['img'] | |
block_resolutions, alpha, lowres_head = self.get_block_resolutions(img) | |
layer_name, progressive = 'b', self.progressive | |
if mode == "cam_enc": | |
assert self.camera_kwargs.predict_camera and self.camera_kwargs.camera_encoder | |
layer_name = 'c' | |
if not self.camera_kwargs.camera_encoder_progressive: | |
block_resolutions, progressive = [r for r in self.block_resolutions if r <= self.lowres_head], False | |
img = downsample(img, self.lowres_head) | |
elif mode == 'dual_disc': | |
layer_name = 'dual' | |
block_resolutions, progressive = [r for r in self.block_resolutions if r <= self.lowres_head], False | |
img0 = downsample(img, img.size(-1) // 2) if \ | |
progressive and (self.lowres_head is not None) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0) \ | |
else None | |
x = None if (not progressive) or (block_resolutions[0] == self.img_resolution) \ | |
else getattr(self, f'{layer_name}{block_resolutions[0]}').fromrgb(img) | |
for res in block_resolutions: | |
block = getattr(self, f'{layer_name}{res}') | |
if (lowres_head == res) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0): | |
if progressive: | |
if self.architecture == 'skip': | |
img = img * alpha + img0 * (1 - alpha) | |
x = x * alpha + block.fromrgb(img0) * (1 - alpha) | |
x, img = block(x, img, **block_kwargs) | |
output = {} | |
if (mode == 'cam_enc') or \ | |
(mode == 'disc' and self.camera_kwargs.predict_camera and (not self.camera_kwargs.camera_encoder)): | |
c = self.projector(x)[:,:,0,0] | |
if self.camera_kwargs.camera_type == '9d': | |
c = camera_9d_to_16d(c) | |
output['cam'] = c | |
if self.camera_kwargs.predict_styles: | |
w = self.projector_styles(x)[:,:,0,0] | |
output['styles'] = w | |
return output, x, img | |
def get_camera_loss(self, RT=None, UV=None, c=None): | |
if (RT is None) or (UV is None): | |
return None | |
if self.camera_kwargs.camera_type == '3d': # UV has higher priority? | |
return F.mse_loss(UV, c) | |
else: | |
return F.smooth_l1_loss(RT.reshape(RT.size(0), -1), c) * 10 | |
def get_styles_loss(self, WS=None, w=None): | |
if WS is None: | |
return None | |
return F.mse_loss(WS, w) * 0.1 | |
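# Hedged illustration of the two camera-loss branches above (made-up tensors,
# not the training data): a '3d' camera is supervised with MSE against the
# sampled low-dimensional parameters, while other parameterizations compare the
# flattened 4x4 extrinsics with a scaled smooth-L1 loss.
def _example_camera_losses(batch=4):
    UV, c3 = torch.rand(batch, 3), torch.rand(batch, 3)
    RT, c16 = torch.randn(batch, 4, 4), torch.randn(batch, 16)
    loss_uv = F.mse_loss(UV, c3)
    loss_rt = F.smooth_l1_loss(RT.reshape(batch, -1), c16) * 10
    return loss_uv, loss_rt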
def get_block_resolutions(self, input_img): | |
block_resolutions = self.block_resolutions | |
lowres_head = self.lowres_head | |
alpha = self.alpha | |
img_res = input_img.size(-1) | |
if self.progressive and (self.lowres_head is not None) and (self.alpha > -1): | |
if (self.alpha < 1) and (self.alpha > 0): | |
try: | |
n_levels, _, before_res, target_res = self.curr_status | |
alpha, index = math.modf(self.alpha * n_levels) | |
index = int(index) | |
except Exception as e: # TODO: this is a hack, better to save status as buffers. | |
before_res = target_res = img_res | |
if before_res == target_res: # no upsampling was used in generator, do not increase the discriminator | |
alpha = 0 | |
block_resolutions = [res for res in self.block_resolutions if res <= target_res] | |
lowres_head = before_res | |
elif self.alpha == 0: | |
block_resolutions = [res for res in self.block_resolutions if res <= lowres_head] | |
return block_resolutions, alpha, lowres_head | |
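# Minimal sketch (random tensors) of the progressive fade-in used in
# forward_blocks_progressive: at the resolution where new blocks are being
# introduced, both the skip image and the fromrgb features are blended between
# the directly-downsampled input (img0) and the current path with weight alpha.
def _example_progressive_fade_in(alpha=0.3):
    img_cur = torch.randn(2, 3, 32, 32)    # skip-path image, already at the lowres head
    img_prev = torch.randn(2, 3, 32, 32)   # input image downsampled directly (img0 in the code)
    x_cur = torch.randn(2, 64, 32, 32)     # features coming from the newly added blocks
    x_prev = torch.randn(2, 64, 32, 32)    # stand-in for block.fromrgb(img_prev)
    img = img_cur * alpha + img_prev * (1 - alpha)
    x = x_cur * alpha + x_prev * (1 - alpha)
    return x, img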
def forward(self, inputs, c=None, aug_pipe=None, return_camera=False, **block_kwargs): | |
if not isinstance(inputs, dict): | |
inputs = {'img': inputs} | |
img = inputs['img'] | |
# this is to handle real images | |
block_resolutions, alpha, _ = self.get_block_resolutions(img) | |
if img.size(-1) > block_resolutions[0]: | |
img = downsample(img, block_resolutions[0]) | |
if self.dual_discriminator and ('img_nerf' not in inputs): | |
inputs['img_nerf'] = downsample(img, self.lowres_head) | |
RT = inputs['camera_matrices'][1].detach() if 'camera_matrices' in inputs else None | |
UV = inputs['camera_matrices'][2].detach() if 'camera_matrices' in inputs else None | |
WS = inputs['ws_detach'].reshape(inputs['batch_size'], -1) if 'ws_detach' in inputs else None | |
no_condition = (c.size(-1) == 0)
camera_loss, styles_loss = None, None # defaults in case no prediction branch below assigns them
# forward separate camera encoder, which can also be progressive... | |
if self.camera_kwargs.camera_encoder: | |
out_camenc, _, _ = self.forward_blocks_progressive(img, mode='cam_enc', **block_kwargs) | |
if no_condition and ('cam' in out_camenc): | |
c, camera_loss = out_camenc['cam'], self.get_camera_loss(RT, UV, out_camenc['cam']) | |
if 'styles' in out_camenc: | |
w, styles_loss = out_camenc['styles'], self.get_styles_loss(WS, out_camenc['styles']) | |
no_condition = False | |
# forward another dual discriminator only for low resolution images | |
if self.dual_discriminator: | |
_, x_nerf, img_nerf = self.forward_blocks_progressive(inputs['img_nerf'], mode='dual_disc', **block_kwargs) | |
# apply data augmentation to the discriminator input (if provided)
if aug_pipe is not None: | |
img = aug_pipe(img) | |
# perform main discriminator block | |
out_disc, x, img = self.forward_blocks_progressive(img, mode='disc', **block_kwargs) | |
if no_condition and ('cam' in out_disc): | |
c, camera_loss = out_disc['cam'], self.get_camera_loss(RT, UV, out_disc['cam']) | |
if 'styles' in out_disc: | |
w, styles_loss = out_disc['styles'], self.get_styles_loss(WS, out_disc['styles']) | |
no_condition = False | |
# camera conditional discriminator | |
cmap = None | |
if self.c_dim > 0: | |
cc = c.clone().detach() | |
cmap = self.mapping(None, cc) | |
if self.camera_kwargs.predict_styles: | |
ww = w.clone().detach() | |
cmap = [cmap] + [self.mapping_styles(None, ww)] | |
logits = self.b4(x, img, cmap) | |
if self.dual_discriminator: | |
logits = torch.cat([logits, self.b4(x_nerf, img_nerf, cmap)], 0) | |
outputs = {'logits': logits} | |
if self.camera_kwargs.predict_camera and (camera_loss is not None): | |
outputs['camera_loss'] = camera_loss | |
if self.camera_kwargs.predict_styles and (styles_loss is not None): | |
outputs['styles_loss'] = styles_loss | |
if return_camera: | |
outputs['camera'] = c | |
return outputs | |
class Encoder(torch.nn.Module): | |
def __init__(self, | |
img_resolution, # Input resolution. | |
img_channels, # Number of input color channels. | |
bottleneck_factor = 2, # by default, use 4x4 bottleneck features (same as the discriminator)
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. | |
channel_base = 1, # Overall multiplier for the number of channels. | |
channel_max = 512, # Maximum number of channels in any layer. | |
num_fp16_res = 0, # Use FP16 for the N highest resolutions. | |
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping | |
lowres_head = None, # add a low-resolution discriminator head | |
block_kwargs = {}, # Arguments for DiscriminatorBlock. | |
model_kwargs = {}, | |
upsample_type = 'default', | |
progressive = False, | |
**unused | |
): | |
super().__init__() | |
self.img_resolution = img_resolution | |
self.img_resolution_log2 = int(np.log2(img_resolution)) | |
self.img_channels = img_channels | |
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, bottleneck_factor, -1)] | |
self.architecture = architecture | |
self.lowres_head = lowres_head | |
self.upsample_type = upsample_type | |
self.progressive = progressive | |
self.model_kwargs = model_kwargs | |
self.output_mode = model_kwargs.get('output_mode', 'styles') | |
if self.progressive: | |
assert self.architecture == 'skip', "not supporting other types for now." | |
self.predict_camera = model_kwargs.get('predict_camera', False) | |
channel_base = int(channel_base * 32768) | |
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} | |
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) | |
common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp) | |
cur_layer_idx = 0 | |
for res in self.block_resolutions: | |
in_channels = channels_dict[res] if res < img_resolution else 0 | |
tmp_channels = channels_dict[res] | |
out_channels = channels_dict[res // 2] | |
use_fp16 = (res >= fp16_resolution) | |
block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, | |
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) | |
setattr(self, f'b{res}', block) | |
cur_layer_idx += block.num_layers | |
# this is an encoder | |
if self.output_mode in ['W', 'W+', 'None']: | |
self.num_ws = self.model_kwargs.get('num_ws', 0) | |
self.n_latents = self.num_ws if self.output_mode == 'W+' else (0 if self.output_mode == 'None' else 1) | |
self.w_dim = self.model_kwargs.get('w_dim', 512) | |
self.add_dim = self.model_kwargs.get('add_dim', 0) if not self.predict_camera else 9 | |
self.out_dim = self.w_dim * self.n_latents + self.add_dim | |
assert self.out_dim > 0, 'output dimension has to be larger than 0'
assert self.block_resolutions[-1] // 2 == 4, "make sure the last resolution is 4x4" | |
self.projector = EqualConv2d(channels_dict[4], self.out_dim, 4, padding=0, bias=False) | |
else: | |
raise NotImplementedError | |
self.register_buffer("alpha", torch.scalar_tensor(-1)) | |
def set_alpha(self, alpha): | |
if alpha is not None: | |
self.alpha.fill_(alpha) | |
def set_resolution(self, res): | |
self.curr_status = res | |
def get_block_resolutions(self, input_img): | |
block_resolutions = self.block_resolutions | |
lowres_head = self.lowres_head | |
alpha = self.alpha | |
img_res = input_img.size(-1) | |
if self.progressive and (self.lowres_head is not None) and (self.alpha > -1): | |
if (self.alpha < 1) and (self.alpha > 0): | |
try: | |
n_levels, _, before_res, target_res = self.curr_status | |
alpha, index = math.modf(self.alpha * n_levels) | |
index = int(index) | |
except Exception as e: # TODO: this is a hack, better to save status as buffers. | |
before_res = target_res = img_res | |
if before_res == target_res: | |
# no upsampling was used in generator, do not increase the discriminator | |
alpha = 0 | |
block_resolutions = [res for res in self.block_resolutions if res <= target_res] | |
lowres_head = before_res | |
elif self.alpha == 0: | |
block_resolutions = [res for res in self.block_resolutions if res <= lowres_head] | |
return block_resolutions, alpha, lowres_head | |
def forward(self, inputs, **block_kwargs): | |
if isinstance(inputs, dict): | |
img = inputs['img'] | |
else: | |
img = inputs | |
block_resolutions, alpha, lowres_head = self.get_block_resolutions(img) | |
if img.size(-1) > block_resolutions[0]: | |
img = downsample(img, block_resolutions[0]) | |
if self.progressive and (self.lowres_head is not None) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0): | |
img0 = downsample(img, img.size(-1) // 2) | |
x = None if (not self.progressive) or (block_resolutions[0] == self.img_resolution) \ | |
else getattr(self, f'b{block_resolutions[0]}').fromrgb(img) | |
for res in block_resolutions: | |
block = getattr(self, f'b{res}') | |
if (lowres_head == res) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0): | |
if self.architecture == 'skip': | |
img = img * alpha + img0 * (1 - alpha) | |
if self.progressive: | |
x = x * alpha + block.fromrgb(img0) * (1 - alpha) # combine from img0 | |
x, img = block(x, img, **block_kwargs) | |
outputs = {} | |
if self.output_mode in ['W', 'W+', 'None']: | |
out = self.projector(x)[:,:,0,0] | |
if self.predict_camera: | |
out, out_cam_9d = out[:, 9:], out[:, :9] | |
outputs['camera'] = camera_9d_to_16d(out_cam_9d) | |
if self.output_mode == 'W+': | |
out = rearrange(out, 'b (n s) -> b n s', n=self.num_ws, s=self.w_dim) | |
elif self.output_mode == 'W': | |
out = repeat(out, 'b s -> b n s', n=self.num_ws) | |
else: | |
out = None | |
outputs['ws'] = out | |
return outputs | |
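# Sketch of the two style-output modes above (hypothetical sizes): 'W+' predicts
# a separate w per synthesis layer and is reshaped with rearrange, while 'W'
# predicts a single w that is repeated across all num_ws layers.
def _example_encoder_output_modes(num_ws=10, w_dim=512, batch=2):
    out_wplus = torch.randn(batch, num_ws * w_dim)
    ws_plus = rearrange(out_wplus, 'b (n s) -> b n s', n=num_ws, s=w_dim)   # B x num_ws x w_dim
    out_w = torch.randn(batch, w_dim)
    ws = repeat(out_w, 'b s -> b n s', n=num_ws)                            # B x num_ws x w_dim
    return ws_plus.shape, ws.shape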
# ------------------------------------------------------------------------------------------- # | |
class CameraQueriedSampler(torch.utils.data.Sampler): | |
def __init__(self, dataset, camera_module, nearest_neighbors=400, rank=0, num_replicas=1, device='cpu', seed=0): | |
assert len(dataset) > 0 | |
super().__init__(dataset) | |
self.dataset = dataset | |
self.dataset_cameras = None | |
self.seed = seed | |
self.rank = rank | |
self.device = device | |
self.num_replicas = num_replicas | |
self.C = camera_module | |
self.K = nearest_neighbors | |
self.B = 1000 | |
def update_dataset_cameras(self, estimator): | |
import tqdm | |
from torch_utils.distributed_utils import gather_list_and_concat | |
output = torch.ones(len(self.dataset), 16).to(self.device) | |
with torch.no_grad(): | |
predicted_cameras, image_indices, bsz = [], [], 64 | |
item_subset = [(i * self.num_replicas + self.rank) % len(self.dataset) for i in range((len(self.dataset) - 1) // self.num_replicas + 1)] | |
for _, (images, _, indices) in tqdm.tqdm(enumerate(torch.utils.data.DataLoader( | |
dataset=copy.deepcopy(self.dataset), sampler=item_subset, batch_size=bsz)), | |
total=len(item_subset)//bsz+1, colour='red', desc='Estimating camera poses for the training set'):
predicted_cameras += [estimator(images.to(self.device).to(torch.float32) / 127.5 - 1)] | |
image_indices += [indices.to(self.device).long()] | |
predicted_cameras = torch.cat(predicted_cameras, 0) | |
image_indices = torch.cat(image_indices, 0) | |
if self.num_replicas > 1: | |
predicted_cameras = gather_list_and_concat(predicted_cameras) | |
image_indices = gather_list_and_concat(image_indices) | |
output[image_indices] = predicted_cameras | |
self.dataset_cameras = output | |
def get_knn_cameras(self): | |
return torch.norm( | |
self.dataset_cameras.unsqueeze(1) - | |
self.C.get_camera(self.B, self.device)[0].reshape(1,self.B,16), dim=2, p=None | |
).topk(self.K, largest=False, dim=0)[1] # K x B | |
def __iter__(self): | |
order = np.arange(len(self.dataset)) | |
rnd = np.random.RandomState(self.seed+self.rank) | |
while True: | |
if self.dataset_cameras is None: | |
rand_idx = rnd.randint(order.size) | |
yield rand_idx | |
else: | |
knn_idxs = self.get_knn_cameras() | |
for i in range(self.B): | |
rand_idx = rnd.randint(self.K) | |
yield knn_idxs[rand_idx, i].item() | |
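# Self-contained sketch (random data) of the camera-queried sampling idea above:
# for each freshly sampled camera, find the K dataset images whose estimated
# cameras are closest in flattened-matrix distance, then yield a random one of
# those K neighbours. All sizes are made up.
def _example_camera_queried_sampling(n_dataset=1000, n_query=8, K=5, seed=0):
    rnd = np.random.RandomState(seed)
    dataset_cams = torch.randn(n_dataset, 16)
    query_cams = torch.randn(n_query, 16)
    dists = torch.norm(dataset_cams.unsqueeze(1) - query_cams.unsqueeze(0), dim=2)  # N x B
    knn = dists.topk(K, largest=False, dim=0)[1]                                     # K x B
    picks = [knn[rnd.randint(K), b].item() for b in range(n_query)]
    return picks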