# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
    """A hub to host the meta information of all pre-defined models."""

    # Registry shared by the whole process: lower-cased model name -> metainfo.
    _models_dict = {}
    # Whether the model-index shipped with mmcls has been registered yet.
    __mmcls_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike,
                                                  None] = None):
        """Parse the model-index file and register all models.

        Model names are stored lower-cased, so two names that differ only
        in letter case are treated as the same model.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file. If None, the directory of
                each model's own metafile is used. Defaults to None.

        Raises:
            ValueError: If a model name is already registered.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            # The registry is keyed by the lower-cased name, so the conflict
            # check must use the lower-cased name too. Checking the raw name
            # would let mixed-case duplicates silently overwrite each other.
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} is conflict in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        The lookup is case-insensitive. A deep copy of the registered
        metainfo is returned, so changes made by the caller do not leak back
        into the registry.

        Args:
            model_name (str): The name of model.

        Returns:
            modelindex.models.Model: The metainfo of the specified model,
            whose ``config`` is loaded into a ``Config`` object if it was
            stored as a file path.

        Raises:
            ValueError: If no model with the given name is registered.
        """
        cls._register_mmcls_models()

        registered = cls._models_dict.get(model_name.lower())
        if registered is None:
            raise ValueError(f'Failed to find model {model_name}.')

        metainfo = copy.deepcopy(registered)
        # Lazy load config: the registry only keeps the config file path and
        # the file is parsed when the model is actually requested.
        if isinstance(metainfo.config, str):
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: Model,
                            config_prefix: Union[str, PathLike, None] = None):
        """Resolve the config path of a model to an absolute path.

        Args:
            metainfo (Model): The metainfo of one model.
            config_prefix (str | PathLike | None): The directory that
                relative config paths are joined onto. If None, the directory
                of ``metainfo.filepath`` is used. Defaults to None.

        Returns:
            str | None: ``metainfo.config`` unchanged if it is None or already
            absolute, otherwise the absolute path under ``config_prefix``.
        """
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
        return config_path

    @classmethod
    def _register_mmcls_models(cls):
        """Register the model-index bundled with mmcls, at most once."""
        # register models in mmcls
        if not cls.__mmcls_registered:
            # Imported here rather than at module level, as in the original.
            from mmengine.utils import get_installed_path
            mmcls_root = Path(get_installed_path('mmcls'))
            model_index_path = mmcls_root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=mmcls_root / '.mim')
            # Only flagged after a successful registration.
            cls.__mmcls_registered = True


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Initialize a classifier from config file.

    Note:
        When ``config`` is a ``Config`` object it is modified in place
        (``kwargs`` are merged into ``config.model``, and the
        ``data_preprocessor`` default and the ``_scope_`` key are set on it),
        and the same object is attached to the model as ``model.cfg``.

    Args:
        config (str | :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        nn.Module: The constructed model, switched to eval mode.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, PathLike)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if kwargs:
        config.merge_from_dict({'model': kwargs})
    # Fall back to the top-level ``data_preprocessor`` setting when the model
    # config does not define its own.
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    import mmcls.models  # noqa: F401
    from mmcls.registry import MODELS

    config.model._scope_ = 'mmcls'
    model = MODELS.build(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if not model.with_head:
            # Don't set CLASSES if the model is headless.
            pass
        elif 'dataset_meta' in checkpoint.get('meta', {}):
            # mmcls 1.x
            model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # mmcls < 1.x
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            from mmcls.datasets.categories import IMAGENET_CATEGORIES

            # NOTE: this changes the process-wide warning filter, not only
            # the warning emitted below.
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = IMAGENET_CATEGORIES
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def get_model(model_name, pretrained=False, device=None, **kwargs):
    """Get a pre-defined model by the name of model.

    Args:
        model_name (str): The name of model.
        pretrained (bool | str): If True, load the pre-defined pretrained
            weights. If a string, load the weights from it. Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Raises:
        ValueError: If the model name is unknown, if ``pretrained`` is True
            but the model has no pretrained weights, or if the model has no
            config to build from.

    Examples:
        Get a ResNet-50 model and extract images feature:

        >>> import torch
        >>> from mmcls import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get Swin-Transformer model with pre-trained weights and inference:

        >>> from mmcls import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
    """  # noqa: E501
    metainfo = ModelHub.get(model_name)

    if isinstance(pretrained, str):
        # An explicit checkpoint path overrides the weights in the metainfo.
        ckpt = pretrained
    elif pretrained:
        if metainfo.weights is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")
        ckpt = metainfo.weights
    else:
        ckpt = None

    if metainfo.config is None:
        raise ValueError(
            f"The model {model_name} doesn't support building by now.")
    model = init_model(metainfo.config, ckpt, device=device, **kwargs)
    return model


def list_models(pattern=None) -> List[str]:
    """List all models available in MMClassification.

    Args:
        pattern (str | None): A wildcard pattern to match model names. A
            trailing ``*`` is always appended, so the pattern only has to
            match a prefix of the name. Defaults to None, which lists every
            model.

    Returns:
        List[str]: a sorted list of model names.

    Examples:
        List all models:

        >>> from mmcls import list_models
        >>> print(list_models())

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmcls import list_models
        >>> print(list_models('resnet*in1k'))
        ['resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb32_in1k']
    """
    ModelHub._register_mmcls_models()
    if pattern is None:
        return sorted(ModelHub._models_dict)
    # Always match keys with any postfix.
    matches = fnmatch.filter(ModelHub._models_dict, pattern + '*')
    # Sorted so the order matches the ``pattern=None`` branch instead of
    # depending on registration order.
    return sorted(matches)