# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os import shutil from enum import Enum from typing import Any import torch from torch import nn from torch.nn import Module from torch.optim import Optimizer __all__ = [ "load_state_dict", "make_directory", "save_checkpoint", "Summary", "AverageMeter", "ProgressMeter" ] def load_state_dict( model: nn.Module, model_weights_path: str, ema_model: nn.Module = None, optimizer: torch.optim.Optimizer = None, scheduler: torch.optim.lr_scheduler = None, load_mode: str = None, ) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module: # Load model weights checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage) if load_mode == "resume": # Restore the parameters in the training node to this point start_epoch = checkpoint["epoch"] best_psnr = checkpoint["best_psnr"] best_ssim = checkpoint["best_ssim"] # Load model state dict. Extract the fitted model weights model_state_dict = model.state_dict() state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()} # Overwrite the model weights to the current model (base model) model_state_dict.update(state_dict) model.load_state_dict(model_state_dict) # Load the optimizer model optimizer.load_state_dict(checkpoint["optimizer"]) if scheduler is not None: # Load the scheduler model scheduler.load_state_dict(checkpoint["scheduler"]) if ema_model is not None: # Load ema model state dict. Extract the fitted model weights ema_model_state_dict = ema_model.state_dict() ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()} # Overwrite the model weights to the current model (ema model) ema_model_state_dict.update(ema_state_dict) ema_model.load_state_dict(ema_model_state_dict) return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler else: # Load model state dict. Extract the fitted model weights model_state_dict = model.state_dict() state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys() and v.size() == model_state_dict[k].size()} # Overwrite the model weights to the current model model_state_dict.update(state_dict) model.load_state_dict(model_state_dict) return model def make_directory(dir_path: str) -> None: if not os.path.exists(dir_path): os.makedirs(dir_path) def save_checkpoint( state_dict: dict, file_name: str, samples_dir: str, results_dir: str, best_file_name: str, last_file_name: str, is_best: bool = False, is_last: bool = False, ) -> None: checkpoint_path = os.path.join(samples_dir, file_name) torch.save(state_dict, checkpoint_path) if is_best: shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name)) if is_last: shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name)) class Summary(Enum): NONE = 0 AVERAGE = 1 SUM = 2 COUNT = 3 class AverageMeter(object): def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): self.name = name self.fmt = fmt self.summary_type = summary_type self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format(**self.__dict__) def summary(self): if self.summary_type is Summary.NONE: fmtstr = "" elif self.summary_type is Summary.AVERAGE: fmtstr = "{name} {avg:.2f}" elif self.summary_type is Summary.SUM: fmtstr = "{name} {sum:.2f}" elif self.summary_type is Summary.COUNT: fmtstr = "{name} {count:.2f}" else: raise ValueError(f"Invalid summary type {self.summary_type}") return fmtstr.format(**self.__dict__) class ProgressMeter(object): def __init__(self, num_batches, meters, prefix=""): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display(self, batch): entries = [self.prefix + self.batch_fmtstr.format(batch)] entries += [str(meter) for meter in self.meters] print("\t".join(entries)) def display_summary(self): entries = [" *"] entries += [meter.summary() for meter in self.meters] print(" ".join(entries)) def _get_batch_fmtstr(self, num_batches): num_digits = len(str(num_batches // 1)) fmt = "{:" + str(num_digits) + "d}" return "[" + fmt + "/" + fmt.format(num_batches) + "]"