# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
"""A hub to host the meta information of all pre-defined models."""
_models_dict = {}
__mmcls_registered = False

    @classmethod
def register_model_index(cls,
model_index_path: Union[str, PathLike],
config_prefix: Union[str, PathLike, None] = None):
"""Parse the model-index file and register all models.
Args:
model_index_path (str | PathLike): The path of the model-index
file.
config_prefix (str | PathLike | None): The prefix of all config
file paths in the model-index file.
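
        Examples:
            A minimal sketch; the file name and prefix below are placeholders
            for your own model-index file and config directory:

            >>> ModelHub.register_model_index(
            ...     'my_model_index.yml', config_prefix='my_configs/')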
"""
model_index = load(str(model_index_path))
model_index.build_models_with_collections()
for metainfo in model_index.models:
model_name = metainfo.name.lower()
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} conflicts between {} and {}.'.format(
model_name, osp.abspath(metainfo.filepath),
osp.abspath(cls._models_dict[model_name].filepath)))
metainfo.config = cls._expand_config_path(metainfo, config_prefix)
cls._models_dict[model_name] = metainfo

    @classmethod
def get(cls, model_name):
"""Get the model's metainfo by the model name.
Args:
model_name (str): The name of model.
Returns:
modelindex.models.Model: The metainfo of the specified model.
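
        Examples:
            A minimal sketch, assuming ``ModelHub`` is importable from
            ``mmcls.apis`` and the model name below has been registered:

            >>> from mmcls.apis import ModelHub
            >>> metainfo = ModelHub.get('resnet50_8xb32_in1k')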
"""
cls._register_mmcls_models()
        # Return a deep copy so callers cannot mutate the registry entry;
        # the config file itself is only loaded lazily below.
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
if metainfo is None:
raise ValueError(f'Failed to find model {model_name}.')
if isinstance(metainfo.config, str):
metainfo.config = Config.fromfile(metainfo.config)
return metainfo

    @staticmethod
def _expand_config_path(metainfo: Model,
                            config_prefix: Union[str, PathLike, None] = None):
if config_prefix is None:
config_prefix = osp.dirname(metainfo.filepath)
if metainfo.config is None or osp.isabs(metainfo.config):
config_path: str = metainfo.config
else:
config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
return config_path

    @classmethod
def _register_mmcls_models(cls):
# register models in mmcls
if not cls.__mmcls_registered:
from mmengine.utils import get_installed_path
mmcls_root = Path(get_installed_path('mmcls'))
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=mmcls_root / '.mim')
cls.__mmcls_registered = True


def init_model(config, checkpoint=None, device=None, **kwargs):
"""Initialize a classifier from config file.
Args:
config (str | :obj:`mmengine.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
**kwargs: Other keyword arguments of the model config.

    Returns:
nn.Module: The constructed model.
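
    Examples:
        A minimal sketch; the config path below is illustrative and should
        point to an existing config file in your checkout:

        >>> from mmcls.apis import init_model
        >>> model = init_model(
        ...     'configs/resnet/resnet50_8xb32_in1k.py', device='cpu')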
"""
if isinstance(config, (str, PathLike)):
config = Config.fromfile(config)
elif not isinstance(config, Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if kwargs:
config.merge_from_dict({'model': kwargs})
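    # If the model config does not define its own data_preprocessor, fall
    # back to the top-level one in the config (may be None).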
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
import mmcls.models # noqa: F401
from mmcls.registry import MODELS
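    # Build under the 'mmcls' registry scope so model components registered
    # in mmcls are resolvable even when called from another default scope.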
config.model._scope_ = 'mmcls'
model = MODELS.build(config.model)
if checkpoint is not None:
        # Mapping the weights to GPU may cause an unexpected video memory
        # leak; see https://github.com/open-mmlab/mmdetection/pull/6405
from mmengine.runner import load_checkpoint
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
if not model.with_head:
# Don't set CLASSES if the model is headless.
pass
elif 'dataset_meta' in checkpoint.get('meta', {}):
# mmcls 1.x
model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls < 1.x
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets.categories import IMAGENET_CATEGORIES
warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data; using ImageNet class names by default.')
model.CLASSES = IMAGENET_CATEGORIES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model


def get_model(model_name, pretrained=False, device=None, **kwargs):
"""Get a pre-defined model by the name of model.
Args:
model_name (str): The name of model.
pretrained (bool | str): If True, load the pre-defined pretrained
weights. If a string, load the weights from it. Defaults to False.
device (str | torch.device | None): Transfer the model to the target
device. Defaults to None.
**kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Examples:
        Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
... print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])

        Get a Swin Transformer model with pre-trained weights and run inference:

>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
'sea snake'
""" # noqa: E501
metainfo = ModelHub.get(model_name)
if isinstance(pretrained, str):
ckpt = pretrained
elif pretrained:
if metainfo.weights is None:
raise ValueError(
f"The model {model_name} doesn't have pretrained weights.")
ckpt = metainfo.weights
else:
ckpt = None
if metainfo.config is None:
raise ValueError(
f"The model {model_name} doesn't support building by now.")
model = init_model(metainfo.config, ckpt, device=device, **kwargs)
return model


def list_models(pattern=None) -> List[str]:
"""List all models available in MMClassification.
Args:
pattern (str | None): A wildcard pattern to match model names.
Returns:
List[str]: a list of model names.

    Examples:
        List all models:

>>> from mmcls import list_models
>>> print(list_models())

        List ResNet-50 models on ImageNet-1k dataset:

>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
'resnet50_8xb32-fp16_in1k',
'resnet50_8xb256-rsb-a1-600e_in1k',
'resnet50_8xb256-rsb-a2-300e_in1k',
'resnet50_8xb256-rsb-a3-100e_in1k']
"""
ModelHub._register_mmcls_models()
if pattern is None:
return sorted(list(ModelHub._models_dict.keys()))
    # Append '*' so the pattern also matches model names with any suffix.
matches = fnmatch.filter(ModelHub._models_dict.keys(), pattern + '*')
return matches