import logging
import os
import time
import weakref
from collections import OrderedDict
from typing import Any, Dict, List

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, build_detection_train_loader
from detectron2.engine import (
    DefaultTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
)
from detectron2.evaluation import print_csv_format
from detectron2.evaluation.testing import flatten_results_dict
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import EventStorage
from detectron2.utils.logger import setup_logger

from train_net import build_evaluator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("detectron2")


class TrainingModule(LightningModule):
    def __init__(self, cfg):
        super().__init__()
        if not logger.isEnabledFor(logging.INFO):
            setup_logger()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        self.storage: EventStorage = None
        self.model = build_model(self.cfg)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER

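    # Persist detectron2's iteration counter in the Lightning checkpoint so that
    # EventStorage resumes from the correct step after a restart.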
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["iteration"] = self.storage.iter

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        self.start_iter = checkpointed_state["iteration"]
        if self.storage is not None:
            self.storage.iter = self.start_iter

    def setup(self, stage: str):
        self.checkpointer = DetectionCheckpointer(
            self.model,
            self.cfg.OUTPUT_DIR,
        )
        if self.cfg.MODEL.WEIGHTS:
            logger.info(f"Loading model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.")
            self.checkpointer.load(self.cfg.MODEL.WEIGHTS)

        self.iteration_timer = hooks.IterationTimer()
        self.iteration_timer.before_train()
        self.data_start = time.perf_counter()
        self.writers = None

    def training_step(self, batch, batch_idx):
        data_time = time.perf_counter() - self.data_start
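
        # The EventStorage, timer hookup and writers are created lazily on the first
        # step rather than in setup(), since worker processes may be spawned after
        # setup() has already run and each process needs its own storage context.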
        if self.storage is None:
            self.storage = EventStorage(self.start_iter)
            self.storage.__enter__()
            self.iteration_timer.trainer = weakref.proxy(self)
            self.iteration_timer.before_step()
            self.writers = (
                default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
                if comm.is_main_process()
                else []
            )

        loss_dict = self.model(batch)
        SimpleTrainer.write_metrics(loss_dict, data_time)

        opt = self.optimizers()
        self.storage.put_scalar(
            "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
        )
        self.iteration_timer.after_step()
        self.storage.step()
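
        # before_step() is called here, at the end of the step, so that the
        # IterationTimer also captures the data-loading time of the next iteration.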
        self.iteration_timer.before_step()

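        # Periodically flush the metric writers (every 20 iterations).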
        if self.storage.iter % 20 == 0:
            for writer in self.writers:
                writer.write()
        return sum(loss_dict.values())

    def training_step_end(self, training_step_outputs):
        self.data_start = time.perf_counter()
        return training_step_outputs

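    # Training is bounded by max_steps (see train()), so this hook fires once when
    # training stops: save the final model and tear down the writers and EventStorage.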
    def training_epoch_end(self, training_step_outputs):
        self.iteration_timer.after_train()
        if comm.is_main_process():
            self.checkpointer.save("model_final")
        for writer in self.writers:
            writer.write()
            writer.close()
        self.storage.__exit__(None, None, None)

    def _process_dataset_evaluation_results(self) -> OrderedDict:
        results = OrderedDict()
        for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST):
            results[dataset_name] = self._evaluators[idx].evaluate()
            if comm.is_main_process():
                print_csv_format(results[dataset_name])

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    def _reset_dataset_evaluators(self):
        self._evaluators = []
        for dataset_name in self.cfg.DATASETS.TEST:
            evaluator = build_evaluator(self.cfg, dataset_name)
            evaluator.reset()
            self._evaluators.append(evaluator)

    def on_validation_epoch_start(self):
        self._reset_dataset_evaluators()

    def validation_epoch_end(self, _outputs):
        results = self._process_dataset_evaluation_results()

        flattened_results = flatten_results_dict(results)
        for k, v in flattened_results.items():
            try:
                v = float(v)
            except Exception as e:
                raise ValueError(
                    "[EvalHook] eval_function should return a nested dict of float. "
                    "Got '{}: {}' instead.".format(k, v)
                ) from e
        self.storage.put_scalars(**flattened_results, smoothing_hint=False)

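    # Each validation dataloader maps to one entry in cfg.DATASETS.TEST; predictions
    # are routed to the matching evaluator via dataloader_idx.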
    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.model(batch)
        self._evaluators[dataloader_idx].process(batch, outputs)

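    # The optimizer and LR scheduler are built from the detectron2 config;
    # "interval": "step" makes Lightning step the scheduler every iteration,
    # matching detectron2's iteration-based schedules.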
    def configure_optimizers(self):
        optimizer = build_optimizer(self.cfg, self.model)
        self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


class DataModule(LightningDataModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

    def train_dataloader(self):
        return build_detection_train_loader(self.cfg)

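    # One test loader per dataset in cfg.DATASETS.TEST, in the same order as the
    # evaluators built in TrainingModule._reset_dataset_evaluators().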
    def val_dataloader(self):
        dataloaders = []
        for dataset_name in self.cfg.DATASETS.TEST:
            dataloaders.append(build_detection_test_loader(self.cfg, dataset_name))
        return dataloaders


def main(args):
    cfg = setup(args)
    train(cfg, args)


def train(cfg, args):
    trainer_params = {
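        # Training length is controlled by max_steps (cfg.SOLVER.MAX_ITER); max_epochs
        # is set to a very large value so that the step limit is what stops training.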
"max_epochs": 10**8, |
|
"max_steps": cfg.SOLVER.MAX_ITER, |
|
"val_check_interval": cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else 10**8, |
|
"num_nodes": args.num_machines, |
|
"gpus": args.num_gpus, |
|
"num_sanity_val_steps": 0, |
|
} |
|
if cfg.SOLVER.AMP.ENABLED: |
|
trainer_params["precision"] = 16 |

    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
    if args.resume:
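        # Resume the full training state from the last Lightning checkpoint.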
trainer_params["resume_from_checkpoint"] = last_checkpoint |
|
logger.info(f"Resuming training from checkpoint: {last_checkpoint}.") |
|
|
|
trainer = pl.Trainer(**trainer_params) |
|
logger.info(f"start to train with {args.num_machines} nodes and {args.num_gpus} GPUs") |
|
|
|
module = TrainingModule(cfg) |
|
data_module = DataModule(cfg) |
|
if args.eval_only: |
|
logger.info("Running inference") |
|
trainer.validate(module, data_module) |
|
else: |
|
logger.info("Running training") |
|
trainer.fit(module, data_module) |
|
|
|
|
|
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


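# Example invocation (illustrative; the script and config names are placeholders for
# whatever this file and your detectron2 config are actually called):
#   python lightning_train_net.py --config-file configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml --num-gpus 2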
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    logger.info("Command Line Args: %s", args)
    main(args)