# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
import time
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

logger = logging.getLogger(__name__)


# pyre-fixme[13]: Attribute `evaluator` is never initialized.
class TrainingLoopBase(ReplaceableBase):
    """
    Members:
        evaluator: An EvaluatorBase instance, used to evaluate training results.
    """

    evaluator: Optional[EvaluatorBase]
    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        raise NotImplementedError()
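
# `TrainingLoopBase` is a replaceable component: a different loop can be plugged in
# through the Implicitron config/registry machinery. A minimal sketch; the subclass
# name and its extra field are hypothetical, only `registry.register` and the base
# class come from this file:
#
#     @registry.register
#     class MyTrainingLoop(TrainingLoopBase):
#         my_extra_option: int = 1
#
#         def run(self, train_loader, val_loader, test_loader, train_dataset,
#                 model, optimizer, scheduler, **kwargs) -> None:
#             ...  # custom epoch loop goes here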


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):
""" | |
Members: | |
eval_only: If True, only run evaluation using the test dataloader. | |
max_epochs: Train for this many epochs. Note that if the model was | |
loaded from a checkpoint, we will restart training at the appropriate | |
epoch and run for (max_epochs - checkpoint_epoch) epochs. | |
store_checkpoints: If True, store model and optimizer state checkpoints. | |
store_checkpoints_purge: If >= 0, remove any checkpoints older or equal | |
to this many epochs. | |
test_interval: Evaluate on a test dataloader each `test_interval` epochs. | |
test_when_finished: If True, evaluate on a test dataloader when training | |
completes. | |
validation_interval: Validate each `validation_interval` epochs. | |
clip_grad: Optionally clip the gradient norms. | |
If set to a value <=0.0, no clipping | |
metric_print_interval: The batch interval at which the stats should be | |
logged. | |
visualize_interval: The batch interval at which the visualizations | |
should be plotted | |
visdom_env: The name of the Visdom environment to use for plotting. | |
visdom_port: The Visdom port. | |
visdom_server: Address of the Visdom server. | |
""" | |
# Parameters of the outer training loop. | |
eval_only: bool = False | |
max_epochs: int = 1000 | |
store_checkpoints: bool = True | |
store_checkpoints_purge: int = 1 | |
test_interval: int = -1 | |
test_when_finished: bool = False | |
validation_interval: int = 1 | |
# Gradient clipping. | |
clip_grad: float = 0.0 | |
# Visualization/logging parameters. | |
metric_print_interval: int = 5 | |
visualize_interval: int = 1000 | |
visdom_env: str = "" | |
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097)) | |
visdom_server: str = "http://127.0.0.1" | |
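
    # These fields are populated from the experiment config by the Implicitron
    # config system. A hedged sketch of a yaml override; the exact key prefix
    # depends on how the surrounding experiment wires in the training loop, and
    # the name below assumes the usual `<field>_<ClassType>_args` convention:
    #
    #     training_loop_ImplicitronTrainingLoop_args:
    #       max_epochs: 2000
    #       test_interval: 10
    #       clip_grad: 1.0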

    def __post_init__(self):
        run_auto_creation(self)

    # pyre-fixme[14]: `run` overrides method defined in `TrainingLoopBase`
    #  inconsistently.
    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        seed: int,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        based on the specified config file.
        """
        start_epoch = stats.epoch + 1
        assert scheduler.last_epoch == stats.epoch + 1
        assert scheduler.last_epoch == start_epoch

        # only run evaluation on the test dataloader
        if self.eval_only:
            if test_loader is not None:
                # pyre-fixme[16]: `Optional` has no attribute `run`.
                self.evaluator.run(
                    dataloader=test_loader,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    model=model,
                )
                return
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

        # loop through epochs
        for epoch in range(start_epoch, self.max_epochs):
            # automatic new_epoch and plotting of stats at every epoch start
            with stats:
                # Make sure to re-seed random generators to ensure reproducibility
                # even after restart.
                seed_all_random_engines(seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")

                # train loop
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )

                # val loop (optional)
                if val_loader is not None and epoch % self.validation_interval == 0:
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )

                # eval loop (optional)
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self.evaluator.run(
                        device=device,
                        dataloader=test_loader,
                        model=model,
                    )

                assert stats.epoch == epoch, "inconsistent stats!"

                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        if self.test_when_finished:
            if test_loader is not None:
                self.evaluator.run(
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    dataloader=test_loader,
                    model=model,
                )
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from the checkpoint specified
                by resume_epoch; instead, create a fresh stats object.
            resume_epoch: The epoch to resume from. If -1, resume from the
                last available checkpoint.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).
        """
        # Init the stats struct
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
            visdom_env=visdom_env_charts,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )

        model_path = None
        if resume:
            if resume_epoch > 0:
                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
                if not os.path.isfile(model_path):
                    raise FileNotFoundError(
                        f"Cannot find stats from epoch {resume_epoch}."
                    )
            else:
                model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine if stats should be reset
            if resume:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
                    stats = stats_load

                # Update stats properties in case stats were reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("Clearing stats")

        return stats
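
    # A hedged usage sketch of this loop. How the model, loaders and optimizer are
    # built lives outside this file, so the construction below is only illustrative;
    # `load_stats` and `run` are the methods defined here:
    #
    #     loop = ImplicitronTrainingLoop(...)  # typically built by the config system
    #     stats = loop.load_stats(log_vars=["objective"], exp_dir=exp_dir, resume=True)
    #     loop.run(
    #         train_loader=train_loader, val_loader=val_loader, test_loader=None,
    #         train_dataset=train_dataset, model=model, optimizer=optimizer,
    #         scheduler=scheduler, accelerator=None, device=torch.device("cuda"),
    #         exp_dir=exp_dir, stats=stats, seed=0,
    #     )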

    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        This is the main loop for training and evaluation including:
        model forward pass, loss computation, backward pass and visualization.

        Args:
            epoch: The index of the current epoch
            loader: The dataloader to use for the loop
            model: The model module optionally loaded from checkpoint
            optimizer: The optimizer module optionally loaded from checkpoint
            stats: The stats struct, also optionally loaded from checkpoint
            validation: If true, run the loop with the model in eval mode
                and skip the backward pass
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """
        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        # get the visdom env name
        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )

        # Iterate through the batches
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1

            # move to gpu where possible (in place)
            net_input = net_input.to(device)

            # run the forward pass
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )

            # make sure we don't overwrite something
            assert all(k not in preds for k in net_input.keys())

            # merge everything into one big dict
            preds.update(net_input)

            # update the stats logger
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            # pyre-ignore [16]
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # print textual status update
            if it % self.metric_print_interval == 0 or last_iter:
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
                logger.info(std_out)

            # visualize results
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )

            # optimizer step
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"

                # backprop
                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)

                if self.clip_grad > 0.0:
                    # Optionally clip the gradient norms.
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            f"Clipping gradient: {total_norm}"
                            + f" with coef {self.clip_grad / float(total_norm)}."
                        )

                optimizer.step()
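
    # Note on the loop above: each `net_input` batch is expected to provide an
    # in-place `.to(device)` and mapping-style expansion (e.g. Implicitron's
    # FrameData), and the model's `preds` dict must contain the `bp_var` key
    # ("objective" by default) holding the loss that drives the backward pass.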

    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save the model and its corresponding Stats object to a file, if
        `self.store_checkpoints` is True. In addition, if
        `self.store_checkpoints_purge` is positive, remove any checkpoints
        more than `self.store_checkpoints_purge` epochs old.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )
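
    # `model_io.safe_save_model` stores the (unwrapped) model checkpoint together
    # with the Stats object; `load_stats` above recovers the stats from the same
    # checkpoint via `model_io.get_stats_path`, so training curves survive a resume.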