from __future__ import print_function, absolute_import import json import os import sys # import moxing as mox import os.path as osp import shutil import torch from torch.nn import Parameter from .osutils import mkdir_if_missing from config import get_args global_args = get_args(sys.argv[1:]) if global_args.run_on_remote: import moxing as mox def read_json(fpath): with open(fpath, 'r') as f: obj = json.load(f) return obj def write_json(obj, fpath): mkdir_if_missing(osp.dirname(fpath)) with open(fpath, 'w') as f: json.dump(obj, f, indent=4, separators=(',', ': ')) def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): print('=> saving checkpoint ', fpath) if global_args.run_on_remote: dir_name = osp.dirname(fpath) if not mox.file.exists(dir_name): mox.file.make_dirs(dir_name) print('=> makding dir ', dir_name) local_path = "local_checkpoint.pth.tar" torch.save(state, local_path) mox.file.copy(local_path, fpath) if is_best: mox.file.copy(local_path, osp.join(dir_name, 'model_best.pth.tar')) else: mkdir_if_missing(osp.dirname(fpath)) torch.save(state, fpath) if is_best: shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) def load_checkpoint(fpath): if global_args.run_on_remote: mox.file.shift('os', 'mox') checkpoint = torch.load(fpath) print("=> Loaded checkpoint '{}'".format(fpath)) return checkpoint else: load_path = fpath if osp.isfile(load_path): checkpoint = torch.load(load_path) print("=> Loaded checkpoint '{}'".format(load_path)) return checkpoint else: raise ValueError("=> No checkpoint found at '{}'".format(load_path)) def copy_state_dict(state_dict, model, strip=None): tgt_state = model.state_dict() copied_names = set() for name, param in state_dict.items(): if strip is not None and name.startswith(strip): name = name[len(strip):] if name not in tgt_state: continue if isinstance(param, Parameter): param = param.data if param.size() != tgt_state[name].size(): print('mismatch:', name, param.size(), tgt_state[name].size()) continue tgt_state[name].copy_(param) copied_names.add(name) missing = set(tgt_state.keys()) - copied_names if len(missing) > 0: print("missing keys in state_dict:", missing) return model