#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import os
import pickle
import warnings

import hydra
import numpy as np
import torch
from nerf.dataset import get_nerf_datasets, trivial_collate
from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs
from nerf.stats import Stats
from omegaconf import DictConfig
from visdom import Visdom

CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
# NOTE(review): the decorator was missing, making the bare `main()` call at the
# bottom of the file a TypeError and leaving CONFIG_DIR unused. `config_name`
# is assumed to be "lego" (the default scene in this codebase) -- confirm
# against the files actually present in `configs/`.
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """Train a NeRF radiance-field model according to the hydra config `cfg`.

    Seeds RNGs, builds the `RadianceFieldRenderer`, optionally resumes from a
    checkpoint, then runs the train / validate / checkpoint loop for
    `cfg.optimizer.max_epochs` epochs.

    Args:
        cfg: Hydra config providing (at least) `seed`, `data`, `raysampler`,
            `implicit_function`, `optimizer`, `visualization`,
            `checkpoint_path`, `resume`, `stats_print_interval`,
            `validation_epoch_interval` and `checkpoint_epoch_interval`.
    """
    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported,"
            + "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        # BUGFIX: previously passed `n_pts_per_ray` here, which silently
        # ignored the dedicated fine-pass sampling count in the config.
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    # Resolve relative to the original cwd (hydra changes cwd per run).
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # map_location lets CUDA-trained checkpoints load on a CPU host.
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # NOTE: unpickling is only safe because this checkpoint is
            # produced by this very script; never point it at untrusted files.
            stats = pickle.loads(loaded_data["stats"])
            print(f" => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        # NOTE(review): Adam defines no `last_epoch`; this only stores an ad-hoc
        # attribute on the optimizer. Kept for parity with the original code --
        # confirm whether anything downstream reads it before removing.
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays so per-iteration sampling is cheap.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model. Ray caches are keyed by
            # camera index, so the index is only passed when rays are cached.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # Adjust the learning rate.
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
            # Sample a validation camera/image.
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (lets us do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)
if __name__ == "__main__":
    # NOTE(review): main() takes a `cfg` argument yet is called without one;
    # presumably it is meant to be decorated with @hydra.main(...), which
    # injects the config parsed from CONFIG_DIR -- confirm, since CONFIG_DIR
    # is otherwise unused in this file.
    main()