import os.path as osp
import warnings

warnings.filterwarnings("ignore")

from typing import Optional
from pathlib import Path

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities import rank_zero_only

from dataset import UavMapDatasetModule
from logger import logger, pl_logger, EXPERIMENTS_PATH
from models.maplocnet import MapLocNet
from module import GenericModule


class CleanProgressBar(pl.callbacks.TQDMProgressBar):
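    """Progress bar that hides the noisy `v_num` and running `loss` entries."""
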
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        items.pop("loss", None)
        return items


class SeedingCallback(pl.callbacks.Callback):
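    """Re-seed all RNGs at the start of every train/val/test epoch.

    During training (unless overfitting on a fixed subset), the seed is offset by the
    epoch index so that each epoch draws different samples but stays reproducible.
    """
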
    def on_epoch_start_(self, trainer, module):
        seed = module.cfg.experiment.seed
        is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        if trainer.training and not is_overfit:
            seed = seed + trainer.current_epoch
        # Temporarily silence the Lightning logger so seed_everything does not spam the log.
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)


class ConsoleLogger(pl.callbacks.Callback):
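    """Log the experiment name and epoch index at each training epoch start (rank 0 only)."""
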
    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        logger.info(
            "New training epoch %d for experiment '%s'.",
            module.current_epoch,
            module.cfg.experiment.name,
        )


def find_last_checkpoint_path(experiment_dir):
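    """Return the path of the last checkpoint in `experiment_dir`, or None if there is none."""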
    cls = pl.callbacks.ModelCheckpoint
    path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
    if osp.exists(path):
        return path
    return None


def prepare_experiment_dir(experiment_dir, cfg, rank):
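    """Create the experiment directory, save the config, and return a resumable checkpoint path.

    If a last checkpoint already exists, the stored config must match the current one on the
    experiment, data, model, and training keys; otherwise a ValueError is raised.
    """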
    config_path = osp.join(experiment_dir, "config.yaml")
    last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
    if last_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Resuming the training from checkpoint %s", last_checkpoint_path
            )
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                cfg_prev = OmegaConf.create(fp.read())
            compare_keys = ["experiment", "data", "model", "training"]
            if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
                cfg_prev, compare_keys
            ):
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
                    f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
                )
    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return last_checkpoint_path


def train(cfg: DictConfig) -> None:
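    """Build the model, datamodule, callbacks, and Lightning trainer, then run the fit loop."""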
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Give the step-based checkpointer its own "last" filename so it does not
    # overwrite the epoch-based checkpointer's last checkpoint.
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"
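
    # Early stopping on the same metric that the checkpointers monitor (patience of 5 epochs).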
    early_stopping_callback = EarlyStopping(
        monitor=cfg.training.checkpointing.monitor, patience=5
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        # With DDP, split the configured batch size and data workers across GPUs.
        for split in ["train", "val"]:
            cfg.data[split].batch_size = (
                cfg.data[split].batch_size // cfg.experiment.gpus
            )
            cfg.data[split].num_workers = int(
                (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
                / cfg.experiment.gpus
            )

    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        early_stopping_callback,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        accelerator="gpu" if cfg.experiment.gpus > 0 else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)


@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
)
def main(cfg: DictConfig) -> None:
    # Dump the composed config to maplocnet.yaml in the working directory for reference.
    OmegaConf.save(config=cfg, f="maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    main()