# zdou0830's picture
# desco
# 749745d
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import torch
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format
from maskrcnn_benchmark.utils.big_model_loading import load_big_format
from maskrcnn_benchmark.utils.pretrain_model_loading import load_pretrain_format
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.model_zoo import cache_url
class Checkpointer(object):
    """Save and restore model / optimizer / scheduler state under ``save_dir``.

    Checkpoints are written with ``torch.save`` to ``<save_dir>/<name>.pth``.
    The text file ``<save_dir>/last_checkpoint`` records the name of the most
    recently saved checkpoint (relative to ``save_dir``). ``load`` uses it to
    resume automatically.
    """

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        """
        Args:
            model: object whose ``state_dict()`` is saved and into which the
                loaded weights are placed.
            optimizer: optional object with ``state_dict``/``load_state_dict``.
            scheduler: optional scheduler, or a list of schedulers, each with
                ``state_dict``/``load_state_dict``.
            save_dir: directory holding checkpoints; if empty, ``save`` is a no-op.
            save_to_disk: if falsy, ``save`` is a no-op (presumably False on
                non-main ranks in distributed runs -- confirm with callers).
            logger: logger to use; defaults to this module's logger.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        """Write ``<save_dir>/<name>.pth`` and tag it as the latest checkpoint.

        Does nothing when ``save_dir`` is empty or ``save_to_disk`` is falsy.
        Extra keyword arguments are stored in the checkpoint dict next to the
        ``model`` / ``optimizer`` / ``scheduler`` entries.
        """
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            if isinstance(self.scheduler, list):
                data["scheduler"] = [sched.state_dict() for sched in self.scheduler]
            else:
                data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        # Tag with the path relative to save_dir; load() joins it back onto
        # save_dir when resuming.
        self.tag_last_checkpoint("{}.pth".format(name))

    def load(self, f=None, force=False, keyword="model", skip_optimizer=False, skip_scheduler=False):
        """Load model weights, plus optimizer/scheduler state when resuming.

        If ``save_dir`` has a non-empty ``last_checkpoint`` tag and ``force``
        is False, that checkpoint overrides ``f`` and the load is treated as a
        resume. Otherwise ``f`` is loaded as initial weights (fine-tuning), and
        the LR settings from the config are left untouched.

        Args:
            f: checkpoint to load when not resuming (anything ``_load_file``
                accepts); may be None/empty to start from scratch.
            force: ignore ``last_checkpoint`` and load ``f``.
            keyword: key of the model state inside the checkpoint dict.
            skip_optimizer: when resuming, restore neither optimizer nor
                scheduler state, and return ``{}``.
            skip_scheduler: when resuming, restore the optimizer but not the
                scheduler.

        Returns:
            When resuming (and not ``skip_optimizer``), the checkpoint entries
            that remain after the model/optimizer/scheduler entries have been
            popped, e.g. extra data passed to ``save``. Otherwise ``{}``.
        """
        resume = False
        if self.has_checkpoint() and not force:
            # override argument with existing checkpoint
            last_saved = self.get_checkpoint_file()
            # get_checkpoint_file() returns "" when the tag is empty or has just
            # been deleted; joining "" would yield save_dir itself (a directory),
            # so only override f for a real tag.
            if last_saved:
                f = os.path.join(self.save_dir, last_saved)
                resume = True
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint, keyword=keyword)

        # Not resuming: use the LR settings from the config for fine-tuning.
        if not resume or skip_optimizer:
            return {}

        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler and not skip_scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            saved = checkpoint.pop("scheduler")
            if isinstance(self.scheduler, list):
                # zip() silently stops at the shorter sequence; make a mismatch visible.
                if len(saved) != len(self.scheduler):
                    self.logger.warning(
                        "Checkpoint has {} scheduler states but {} schedulers are configured".format(
                            len(saved), len(self.scheduler)
                        )
                    )
                for sched, state_dict in zip(self.scheduler, saved):
                    sched.load_state_dict(state_dict)
            else:
                self.scheduler.load_state_dict(saved)
        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        """Return True if a ``last_checkpoint`` tag file exists in ``save_dir``."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        """Return the stripped contents of ``last_checkpoint``, or "" if unreadable."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        """Overwrite ``last_checkpoint`` in ``save_dir`` with ``last_filename``."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        """Deserialize ``f`` with ``torch.load``, mapping all tensors to CPU."""
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint, keyword="model"):
        """Pop ``checkpoint[keyword]`` and load it into ``self.model``.

        Uses the project's ``load_state_dict`` helper; its matching rules
        (e.g. strictness, key renaming) are defined elsewhere -- see
        ``maskrcnn_benchmark.utils.model_serialization``.
        """
        load_state_dict(self.model, checkpoint.pop(keyword))
class DetectronCheckpointer(Checkpointer):
    """Checkpointer whose ``_load_file`` also understands non-local sources.

    In addition to native ``torch.load`` files, ``_load_file`` accepts:
      * ``catalog://<name>``: resolved through ``ModelCatalog`` from the module
        at ``cfg.PATHS_CATALOG``;
      * ``http://`` / ``https://`` URLs: downloaded and cached via ``cache_url``;
      * ``.pkl`` (Caffe2), ``.big`` and ``.pretrain`` files: converted by
        the matching project loader.
    """

    def __init__(
        self,
        cfg,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        """Same as ``Checkpointer.__init__`` plus ``cfg``, whose clone is kept
        for the catalog lookup and the format-specific loaders."""
        super(DetectronCheckpointer, self).__init__(model, optimizer, scheduler, save_dir, save_to_disk, logger)
        self.cfg = cfg.clone()

    def _load_file(self, f):
        """Resolve ``f`` (catalog name / URL / path) and return a checkpoint dict.

        Native checkpoints without a top-level ``"model"`` key are treated as a
        bare state dict and wrapped as ``{"model": loaded}``.
        """
        # catalog lookup
        if f.startswith("catalog://"):
            paths_catalog = import_file("maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True)
            catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :])
            self.logger.info("{} points to {}".format(f, catalog_f))
            f = catalog_f
        # Download URL files. Match the scheme explicitly so a local path that
        # merely starts with "http" (e.g. "http_models/x.pth") is not downloaded.
        if f.startswith(("http://", "https://")):
            cached_f = cache_url(f)
            self.logger.info("url {} cached in {}".format(f, cached_f))
            f = cached_f
        # convert Caffe2 checkpoint from pkl
        if f.endswith(".pkl"):
            return load_c2_format(self.cfg, f)
        if f.endswith(".big"):
            return load_big_format(self.cfg, f)
        if f.endswith(".pretrain"):
            return load_pretrain_format(self.cfg, f)
        # load native detectron.pytorch checkpoint
        loaded = super(DetectronCheckpointer, self)._load_file(f)
        if "model" not in loaded:
            # bare state dict: wrap it so _load_model can pop "model"
            loaded = dict(model=loaded)
        return loaded