|
|
|
import io |
|
import os |
|
import os.path as osp |
|
import pkgutil |
|
import time |
|
import warnings |
|
from collections import OrderedDict |
|
from importlib import import_module |
|
from tempfile import TemporaryDirectory |
|
|
|
import torch |
|
import torchvision |
|
from torch.optim import Optimizer |
|
from torch.utils import model_zoo |
|
from torch.nn import functional as F |
|
|
|
import annotator.uniformer.mmcv as mmcv |
|
from annotator.uniformer.mmcv.fileio import FileClient |
|
from annotator.uniformer.mmcv.fileio import load as load_file |
|
from annotator.uniformer.mmcv.parallel import is_module_wrapper |
|
from annotator.uniformer.mmcv.utils import mkdir_or_exist |
|
from annotator.uniformer.mmcv.runner import get_dist_info |
|
|
|
# Name of the environment variable that, when set, supplies the mmcv home
# directory used by ``_get_mmcv_home``.
ENV_MMCV_HOME = 'MMCV_HOME'
# Name of the environment variable that supplies the cache root when
# ``MMCV_HOME`` is not set.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Cache root used when neither environment variable is set; ``~`` is expanded
# by ``_get_mmcv_home``.
DEFAULT_CACHE_DIR = '~/.cache'
|
|
|
|
|
def _get_mmcv_home():
    """Return the mmcv home directory, creating it if necessary.

    The directory is taken from the ``MMCV_HOME`` environment variable. When
    that is unset, it is ``<cache root>/mmcv`` where the cache root comes from
    ``XDG_CACHE_HOME`` or, failing that, ``DEFAULT_CACHE_DIR``. A leading
    ``~`` is expanded before the directory is passed to ``mkdir_or_exist``.

    Returns:
        str: Path of the mmcv home directory.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'mmcv')
    home_dir = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))

    mkdir_or_exist(home_dir)
    return home_dir
|
|
|
|
|
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is True, at least one mismatch message was
            collected, and the rank reported by ``get_dist_info`` is 0.
    """
    unexpected = []
    missing_all = []
    problems = []

    # Work on a copy of the mapping, carrying over ``_metadata`` when the
    # source has it, so the caller's object is not the one handed to
    # ``_load_from_state_dict``.
    source_metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if source_metadata is not None:
        state_dict._metadata = source_metadata

    def _load_into(target, prefix=''):
        # Unwrap module wrappers at every level of the recursion.
        if is_module_wrapper(target):
            target = target.module
        if source_metadata is None:
            local_metadata = {}
        else:
            local_metadata = source_metadata.get(prefix[:-1], {})
        target._load_from_state_dict(state_dict, prefix, local_metadata,
                                     True, missing_all, unexpected, problems)
        for child_name, child in target._modules.items():
            if child is None:
                continue
            _load_into(child, prefix + child_name + '.')

    _load_into(module)
    # Drop the self-referencing closure.
    _load_into = None

    # Missing ``num_batches_tracked`` entries are not reported.
    missing = [
        name for name in missing_all if 'num_batches_tracked' not in name
    ]

    if unexpected:
        problems.append('unexpected key in source '
                        f'state_dict: {", ".join(unexpected)}\n')
    if missing:
        problems.append(
            f'missing keys in source state_dict: {", ".join(missing)}\n')

    rank, _ = get_dist_info()
    # Only rank 0 reports (or raises), and only when there is something to say.
    if not problems or rank != 0:
        return

    problems.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    report = '\n'.join(problems)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)
|
|
|
|
|
def load_url_dist(url, model_dir=None):
    """Load a checkpoint from ``url``, letting local rank 0 go first.

    The rank comes from the ``LOCAL_RANK`` environment variable when set,
    otherwise from ``get_dist_info``. Rank 0 calls ``model_zoo.load_url``
    first; when ``world_size > 1`` every rank then waits at a barrier, after
    which the remaining ranks call ``model_zoo.load_url`` themselves.

    Args:
        url (str): URL passed to ``model_zoo.load_url``.
        model_dir (str, optional): Forwarded to ``model_zoo.load_url``.

    Returns:
        The object returned by ``model_zoo.load_url``.
    """

    def _fetch():
        return model_zoo.load_url(url, model_dir=model_dir)

    dist_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', dist_rank))
    is_distributed = world_size > 1

    if local_rank == 0:
        checkpoint = _fetch()
    if is_distributed:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
|
|
|
|
|
def load_pavimodel_dist(model_path, map_location=None):
    """Load a checkpoint from pavi modelcloud, letting local rank 0 go first.

    The model is downloaded into a temporary directory and read with
    ``torch.load``. Rank 0 does this first; when ``world_size > 1`` every rank
    then waits at a barrier, after which the remaining ranks perform the same
    download and load themselves.

    Args:
        model_path (str): Path handed to ``modelcloud.get``.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        ImportError: If ``pavi`` cannot be imported.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    def _download_and_load():
        cloud_model = modelcloud.get(model_path)
        with TemporaryDirectory() as scratch_dir:
            local_file = osp.join(scratch_dir, cloud_model.name)
            cloud_model.download(local_file)
            # Read the file before the temporary directory is removed.
            return torch.load(local_file, map_location=map_location)

    dist_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', dist_rank))

    if local_rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
|
|
|
|
|
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through ``FileClient``, letting local rank 0 go first.

    The file content is fetched with ``FileClient(backend=backend).get`` and
    deserialized from an in-memory buffer with ``torch.load``. Rank 0 does
    this first; when ``world_size > 1`` every rank then waits at a barrier,
    after which the remaining ranks do the same themselves.

    Args:
        filename (str): Location passed to ``FileClient.get``.
        backend (str): File client backend; only ``'ceph'`` is accepted.
        map_location (str | None): Same as :func:`torch.load`.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        ValueError: If ``backend`` is not an accepted backend.
    """
    dist_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', dist_rank))

    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _read():
        client = FileClient(backend=backend)
        payload = io.BytesIO(client.get(filename))
        return torch.load(payload, map_location=map_location)

    if local_rank == 0:
        checkpoint = _read()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _read()
    return checkpoint
|
|
|
|
|
def get_torchvision_models():
    """Collect the ``model_urls`` mappings of ``torchvision.models`` modules.

    Every non-package module found under ``torchvision.models`` is imported;
    modules that define a ``model_urls`` attribute contribute its entries.

    Returns:
        dict: Merged ``model_urls`` entries; later modules overwrite earlier
        ones on duplicate keys.
    """
    collected = {}
    for _, mod_name, is_pkg in pkgutil.walk_packages(
            torchvision.models.__path__):
        if is_pkg:
            continue
        zoo_module = import_module(f'torchvision.models.{mod_name}')
        if not hasattr(zoo_module, 'model_urls'):
            continue
        collected.update(zoo_module.model_urls)
    return collected
|
|
|
|
|
def get_external_models():
    """Load the open-mmlab model URL table.

    The table bundled with the package (``model_zoo/open_mmlab.json``) is
    loaded first. If ``open_mmlab.json`` exists in the mmcv home directory,
    its entries are merged on top and override bundled entries.

    Returns:
        dict: Mapping loaded from the JSON file(s).
    """
    home_dir = _get_mmcv_home()

    bundled_json = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    urls = load_file(bundled_json)
    assert isinstance(urls, dict)

    user_json = osp.join(home_dir, 'open_mmlab.json')
    if not osp.exists(user_json):
        return urls

    user_urls = load_file(user_json)
    assert isinstance(user_urls, dict)
    urls.update(user_urls)
    return urls
|
|
|
|
|
def get_mmcls_models():
    """Load the bundled ``model_zoo/mmcls.json`` table.

    Returns:
        The object produced by ``load_file`` for that JSON file.
    """
    package_root = mmcv.__path__[0]
    return load_file(osp.join(package_root, 'model_zoo/mmcls.json'))
|
|
|
|
|
def get_deprecated_model_names():
    """Load the bundled ``model_zoo/deprecated.json`` table.

    ``_load_checkpoint`` uses this table to map a deprecated ``open-mmlab://``
    model name to its replacement.

    Returns:
        dict: Mapping loaded from the JSON file.
    """
    package_root = mmcv.__path__[0]
    renamed = load_file(osp.join(package_root, 'model_zoo/deprecated.json'))
    assert isinstance(renamed, dict)
    return renamed
|
|
|
|
|
def _process_mmcls_checkpoint(checkpoint): |
|
state_dict = checkpoint['state_dict'] |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
if k.startswith('backbone.'): |
|
new_state_dict[k[9:]] = v |
|
new_checkpoint = dict(state_dict=new_state_dict) |
|
|
|
return new_checkpoint |
|
|
|
|
|
def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    The source is chosen from the prefix of ``filename``: ``modelzoo://``
    (deprecated) and ``torchvision://`` look the name up in the torchvision
    URL table, ``open-mmlab://`` and ``mmcls://`` in the bundled JSON tables,
    ``http(s)://`` is fetched directly, ``pavi://`` goes through pavi
    modelcloud, ``s3://`` through the ceph file client, and anything else is
    treated as a local file path.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.

    Raises:
        IOError: If a local path (given directly or resolved from an
            ``open-mmlab://`` entry) is not an existing file.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        torchvision_urls = get_torchvision_models()
        zoo_name = filename[len('modelzoo://'):]
        return load_url_dist(torchvision_urls[zoo_name])

    if filename.startswith('torchvision://'):
        torchvision_urls = get_torchvision_models()
        zoo_name = filename[len('torchvision://'):]
        return load_url_dist(torchvision_urls[zoo_name])

    if filename.startswith('open-mmlab://'):
        external_urls = get_external_models()
        zoo_name = filename[len('open-mmlab://'):]
        renamed = get_deprecated_model_names()
        if zoo_name in renamed:
            warnings.warn(f'open-mmlab://{zoo_name} is deprecated in favor '
                          f'of open-mmlab://{renamed[zoo_name]}')
            zoo_name = renamed[zoo_name]
        target = external_urls[zoo_name]
        if target.startswith(('http://', 'https://')):
            return load_url_dist(target)
        # Non-URL entries are paths relative to the mmcv home directory.
        local_path = osp.join(_get_mmcv_home(), target)
        if not osp.isfile(local_path):
            raise IOError(f'{local_path} is not a checkpoint file')
        return torch.load(local_path, map_location=map_location)

    if filename.startswith('mmcls://'):
        mmcls_urls = get_mmcls_models()
        zoo_name = filename[len('mmcls://'):]
        raw_checkpoint = load_url_dist(mmcls_urls[zoo_name])
        return _process_mmcls_checkpoint(raw_checkpoint)

    if filename.startswith(('http://', 'https://')):
        return load_url_dist(filename)

    if filename.startswith('pavi://'):
        return load_pavimodel_dist(
            filename[len('pavi://'):], map_location=map_location)

    if filename.startswith('s3://'):
        return load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)

    if not osp.isfile(filename):
        raise IOError(f'{filename} is not a checkpoint file')
    return torch.load(filename, map_location=map_location)
|
|
|
|
|
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    After the checkpoint is read, the state dict is taken from its
    ``'state_dict'`` or ``'model'`` entry (or the checkpoint itself) and
    adapted before loading:

    * a leading ``module.`` is stripped from keys when the first key has it;
    * when the smallest key starts with ``encoder``, only ``encoder.`` keys
      are kept and that prefix is stripped;
    * ``absolute_pos_embed`` is reshaped from ``(N, L, C)`` to the model's
      ``(N, C, H, W)`` layout when the sizes agree;
    * ``relative_position_bias_table`` entries whose length differs from the
      model's are resized with bicubic interpolation.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
            When None, messages are printed instead.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """

    def _warn(message):
        # ``logger`` defaults to None, so fall back to ``print`` the same way
        # ``load_state_dict`` does instead of failing on ``None.warning``.
        if logger is not None:
            logger.warning(message)
        else:
            print(message)

    checkpoint = _load_checkpoint(filename, map_location)

    # OrderedDict is a subclass of dict, so both pass this check.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # Locate the weights inside the checkpoint.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Strip the ``module.`` prefix. Guard against an empty state dict and
    # only shorten keys that really carry the prefix.
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

    # Keep only the ``encoder.`` branch. The prefix is sliced off rather than
    # replaced so an inner 'encoder.' substring in a key is left intact.
    if state_dict and min(state_dict).startswith('encoder'):
        encoder_prefix = 'encoder.'
        state_dict = {
            k[len(encoder_prefix):]: v
            for k, v in state_dict.items() if k.startswith(encoder_prefix)
        }

    # Reshape the absolute position embedding to the model's layout.
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            _warn("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2)

    # Interpolate relative position bias tables whose length differs.
    table_keys = [
        k for k in state_dict if "relative_position_bias_table" in k
    ]
    # Build the model's state dict once instead of once per table.
    model_state = model.state_dict() if table_keys else {}
    for table_key in table_keys:
        if table_key not in model_state:
            # Left untouched; load_state_dict reports it as unexpected.
            continue
        table_pretrained = state_dict[table_key]
        table_current = model_state[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            _warn(f"Error in loading {table_key}, pass")
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2), mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(
                nH2, L2).permute(1, 0)

    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
|
|
|
|
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights; each value must provide a
            ``cpu()`` method.

    Returns:
        OrderedDict: Model weights on CPU, in the same key order. When the
        input carries a ``_metadata`` attribute, it is attached to the result
        as well so that it is not lost when the result is saved.
    """
    state_dict_cpu = OrderedDict(
        (key, val.cpu()) for key, val in state_dict.items())
    # Keep the ``_metadata`` attribute; a freshly built OrderedDict would
    # otherwise silently drop it.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        state_dict_cpu._metadata = metadata
    return state_dict_cpu
|
|
|
|
|
def _save_to_state_dict(module, destination, prefix, keep_vars): |
|
"""Saves module state to `destination` dictionary. |
|
|
|
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. |
|
|
|
Args: |
|
module (nn.Module): The module to generate state_dict. |
|
destination (dict): A dict where state will be stored. |
|
prefix (str): The prefix for parameters and buffers used in this |
|
module. |
|
""" |
|
for name, param in module._parameters.items(): |
|
if param is not None: |
|
destination[prefix + name] = param if keep_vars else param.detach() |
|
for name, buf in module._buffers.items(): |
|
|
|
if buf is not None: |
|
destination[prefix + name] = buf if keep_vars else buf.detach() |
|
|
|
|
|
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module. A new one with an empty ``_metadata`` is created when
            None.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # Unwrap a parallel wrapper at this level; children are unwrapped by the
    # recursive calls below.
    if is_module_wrapper(module):
        module = module.module

    # Only the outermost call starts from a fresh dictionary.
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)

    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, prefix + child_name + '.',
            keep_vars=keep_vars)

    # A hook may return a replacement for the destination.
    for hook in module._state_dict_hooks.values():
        replaced = hook(module, destination, prefix, local_metadata)
        if replaced is not None:
            destination = replaced
    return destination
|
|
|
|
|
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint always has the fields ``meta`` and ``state_dict``; an
    ``optimizer`` field is added when an optimizer (or a dict of optimizers)
    is given. ``meta`` is extended with the mmcv version, the current time
    and, when the model defines a non-None ``CLASSES`` attribute, that value.
    The caller's ``meta`` dict is not modified.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads the
            checkpoint through pavi modelcloud instead of writing it locally.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
            mapping names to optimizers is also accepted.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither None nor a dict.
        ImportError: If ``filename`` starts with ``pavi://`` and pavi cannot
            be imported.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Work on a shallow copy so the entries added below do not leak into
        # the dict owned by the caller.
        meta = meta.copy()
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # Store the class names alongside the weights.
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }

    # Store the optimizer state, for a single optimizer or a named set.
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            cloud_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            cloud_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            # Upload after the file is closed and before the temporary
            # directory is removed.
            cloud_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # Flush right after writing.
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()