import os.path as osp |
import warnings |
warnings.filterwarnings('ignore') |
from typing import Optional |
from pathlib import Path |
from models.maplocnet import MapLocNet |
import hydra |
import pytorch_lightning as pl |
import torch |
from omegaconf import DictConfig, OmegaConf |
from pytorch_lightning.utilities import rank_zero_only |
from module import GenericModule |
from logger import logger, pl_logger, EXPERIMENTS_PATH |
from module import GenericModule |
from dataset import UavMapDatasetModule |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    """TQDM progress bar that hides the ``v_num`` and ``loss`` entries."""

    def get_metrics(self, trainer, model):
        """Return the parent's progress-bar metrics minus ``v_num``/``loss``."""
        shown = super().get_metrics(trainer, model)
        for hidden_key in ("v_num", "loss"):
            # Missing keys are fine: pop with a default never raises.
            shown.pop(hidden_key, None)
        return shown
class SeedingCallback(pl.callbacks.Callback):
    """Re-seed all RNGs at the start of every train/validation/test epoch.

    The base seed is ``cfg.experiment.seed``. During training (unless the
    trainer config enables ``overfit_batches``), the current epoch index is
    added to it, so each training epoch uses a different but reproducible
    seed. Evaluation epochs and overfitting runs reuse the base seed.
    """

    def on_epoch_start_(self, trainer, module):
        """Seed everything for the epoch that is about to start."""
        cfg = module.cfg
        base_seed = cfg.experiment.seed
        overfitting = cfg.training.trainer.get("overfit_batches", 0) > 0
        if not trainer.training or overfitting:
            seed = base_seed
        else:
            seed = base_seed + trainer.current_epoch

        # Mute pl_logger while seeding (presumably to hide the per-epoch
        # seeding message); always re-enable it, even if seeding fails.
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        return self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        return self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        return self.on_epoch_start_(*args, **kwargs)
class ConsoleLogger(pl.callbacks.Callback):
    """Print a console message whenever a training epoch begins."""

    # rank_zero_only: only the rank-0 process emits the message.
    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        epoch = module.current_epoch
        experiment_name = module.cfg.experiment.name
        logger.info(
            "New training epoch %d for experiment '%s'.", epoch, experiment_name
        )
def find_last_checkpoint_path(experiment_dir):
    """Locate the rolling "last" checkpoint inside *experiment_dir*.

    The file name is built from ``ModelCheckpoint``'s class-level
    ``CHECKPOINT_NAME_LAST`` and ``FILE_EXTENSION`` attributes.

    Returns:
        The checkpoint path if it exists on disk, otherwise ``None``.
    """
    checkpoint_cls = pl.callbacks.ModelCheckpoint
    file_name = checkpoint_cls.CHECKPOINT_NAME_LAST + checkpoint_cls.FILE_EXTENSION
    candidate = osp.join(experiment_dir, file_name)
    return candidate if osp.exists(candidate) else None
def prepare_experiment_dir(experiment_dir, cfg, rank):
    """Prepare *experiment_dir* for a fresh or resumed run.

    If a last checkpoint exists, the run is treated as a resume. Rank 0 logs
    it, and when a previous ``config.yaml`` is present, the ``experiment``,
    ``data``, ``model`` and ``training`` sections of *cfg* must equal the
    saved ones. Afterwards rank 0 creates the directory if needed and
    (over)writes ``config.yaml`` with *cfg*.

    Args:
        experiment_dir: Directory holding checkpoints and ``config.yaml``.
        cfg: Current OmegaConf config.
        rank: Process rank; only rank 0 logs and writes to disk.

    Returns:
        Path of the checkpoint to resume from, or ``None`` for a fresh run.

    Raises:
        ValueError: If resuming with a config whose compared sections differ
            from the saved ``config.yaml``.
    """
    config_path = osp.join(experiment_dir, "config.yaml")
    resume_from = find_last_checkpoint_path(experiment_dir)

    if resume_from is not None:
        if rank == 0:
            logger.info("Resuming the training from checkpoint %s", resume_from)
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                saved_cfg = OmegaConf.create(fp.read())
            keys = ["experiment", "data", "model", "training"]
            current_subset = OmegaConf.masked_copy(cfg, keys)
            saved_subset = OmegaConf.masked_copy(saved_cfg, keys)
            if current_subset != saved_subset:
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{current_subset} vs "
                    f"{saved_subset}"
                )

    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return resume_from
def train(cfg: DictConfig) -> None:
    """Train a ``GenericModule`` as described by the experiment config.

    Steps: resolve the config, validate the GPU request, seed RNGs, build or
    fine-tune the model, and prepare the experiment directory (resuming from
    the last checkpoint when present). It then sets up checkpointing, logging
    and monitoring callbacks and runs ``trainer.fit``.

    Args:
        cfg: Full experiment config. It is resolved and mutated in place.
            ``experiment.gpus`` is normalised to 0 for CPU runs. With more
            than one GPU, ``batch_size`` and ``num_workers`` of ``data.train``
            and ``data.val`` are divided across GPUs.

    Raises:
        ValueError: If GPUs are requested but CUDA is not available, or if
            resuming with a config that differs from the saved one.
    """
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    use_gpu = cfg.experiment.gpus > 0
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    # This runs before the per-GPU split below, so the config.yaml that is
    # saved (and compared against on resume) holds the undivided data settings.
    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)

    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Give the step-based callback its own "last" file name so it does not
    # overwrite the epoch callback's file, whose class-level default name is
    # what find_last_checkpoint_path resumes from.
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    # NOTE(review): this callback is constructed but never added to
    # `callbacks` below, so early stopping is currently inactive. Append it
    # to `callbacks` if early stopping is actually wanted.
    early_stopping_callback = EarlyStopping(monitor=cfg.training.checkpointing.monitor, patience=5)

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        n_gpus = cfg.experiment.gpus
        for split in ["train", "val"]:
            # Split the batch size across GPUs (floor) and round the worker
            # count up, so every GPU keeps at least one worker when the
            # original count was non-zero.
            cfg.data[split].batch_size = cfg.data[split].batch_size // n_gpus
            cfg.data[split].num_workers = (
                cfg.data[split].num_workers + n_gpus - 1
            ) // n_gpus

    # Built after the split above, so it sees the per-GPU data settings.
    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if use_gpu:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        # Honour the CPU fallback above instead of always requesting a GPU.
        accelerator="gpu" if use_gpu else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"),
    config_name="maplocnet.yaml",
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: snapshot the composed config, then run training.

    Args:
        cfg: Config composed by Hydra from ``conf/maplocnet.yaml`` plus any
            command-line overrides.
    """
    # Writes the composed, not-yet-resolved config to ``maplocnet.yaml``,
    # relative to the current working directory. Depending on the Hydra
    # version and config, that may be Hydra's run directory; confirm.
    OmegaConf.save(cfg, "maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    # Hydra supplies `cfg` when the decorated `main` is called without args.
    main()