AiOS / pytorch3d /projects /nerf /train_nerf.py
ttxskk
update
d7e58f0
raw
history blame
9.5 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import os
import pickle
import warnings
import hydra
import numpy as np
import torch
from nerf.dataset import get_nerf_datasets, trivial_collate
from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs
from nerf.stats import Stats
from omegaconf import DictConfig
from visdom import Visdom
# Absolute path of the "configs" directory located next to this script;
# symlinks in the script path are resolved first.
CONFIG_DIR = os.path.join(
    os.path.split(os.path.realpath(__file__))[0],
    "configs",
)
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """
    Train a `RadianceFieldRenderer` according to the hydra config `cfg`.

    The function seeds the RNGs, builds the model and the Adam optimizer,
    optionally resumes from `cfg.checkpoint_path`, and then runs the training
    loop for `cfg.optimizer.max_epochs` epochs. The learning rate is decayed
    once per epoch. Every `cfg.validation_epoch_interval` epochs a single
    randomly drawn validation sample is rendered, and every
    `cfg.checkpoint_epoch_interval` epochs the model, optimizer and stats are
    stored to the checkpoint path (if one is configured). Epoch 0 is skipped
    for both validation and checkpointing.

    Args:
        cfg: The experiment configuration. The sections read here are
            `seed`, `data`, `raysampler`, `implicit_function`, `optimizer`,
            `visualization`, `checkpoint_path`, `resume`,
            `stats_print_interval`, `validation_epoch_interval` and
            `checkpoint_epoch_interval`.
    """
    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            + "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        # NOTE(review): the fine pass reuses the coarse `n_pts_per_ray` value.
        # If the config defines a separate `n_pts_per_ray_fine` entry, this
        # probably should read it instead - confirm against the config files.
        n_pts_per_ray_fine=cfg.raysampler.n_pts_per_ray,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    # The checkpoint path is interpreted relative to the directory the script
    # was launched from (hydra changes the working directory).
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # Map the stored tensors onto the current device so that
            # a checkpoint written on a GPU can be resumed on a CPU-only host.
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # NOTE: unpickling executes arbitrary code; only resume from
            # checkpoints that come from a trusted source.
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        # NOTE(review): kept from the original code; presumably the optimizer
        # itself never reads this attribute - confirm before removing.
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    # The `verbose` keyword is intentionally not passed: its default was
    # already False and newer PyTorch releases no longer accept it.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            # The unpacking relies on the insertion order of the sample dict.
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # Adjust the learning rate (once per epoch).
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
            # Sample a validation camera/image. A fresh iterator is created
            # each time, so exactly one random sample is drawn per validation.
            val_batch = next(iter(val_dataloader))
            # Separate name so the training loop's `camera_idx` is not clobbered.
            val_image, val_camera, val_camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (lets us do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    val_camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot the loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)
# Run the hydra-decorated `main` only when this file is executed as a script.
if __name__ == "__main__":
    main()