# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.logging import MMLogger

from mmcls.registry import MODELS
from .base_backbone import BaseBackbone


def print_timm_feature_info(feature_info):
    """Print feature_info of timm backbone to help development and debug.

    Logs through the current ``MMLogger`` instance:

    - ``None``: a warning that the backbone has no feature_info.
    - ``list``: one info line per entry, prefixed with its index.
    - anything else: its ``out_indices``, ``channels()`` and ``reduction()``;
      if any of these is missing, a warning is logged instead (lines logged
      before the missing attribute are kept).

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    logger = MMLogger.get_current_instance()
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
    elif isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
    else:
        # Duck-typed access: anything exposing these members is accepted.
        try:
            logger.info(f'backbone out_indices: {feature_info.out_indices}')
            logger.info(f'backbone out_channels: {feature_info.channels()}')
            logger.info(f'backbone out_strides: {feature_info.reduction()}')
        except AttributeError:
            logger.warning('Unexpected format of backbone feature_info')


@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/huggingface/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://huggingface.co/docs/timm/feature_extraction>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments. If ``norm_layer`` is
            given as a string, it is resolved to a class through the
            ``MODELS`` registry before being passed to timm; any other value
            is forwarded to timm unchanged.

    Raises:
        ImportError: If timm cannot be imported.
        TypeError: If ``pretrained`` is not a bool.
        KeyError: If ``norm_layer`` is a string that is not registered in
            ``MODELS``.
    """

    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        # timm is an optional dependency, so import it lazily.
        try:
            import timm
        except ImportError as err:
            raise ImportError(
                'Failed to import timm. Please run "pip install timm". '
                '"pip install dataclasses" may also be needed for Python 3.6.'
            ) from err
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super().__init__(init_cfg)

        # Config files name norm layers by registry key (e.g. a string);
        # translate it to the registered class that timm expects. Values that
        # are already callables/None are passed through untouched.
        norm_layer = kwargs.get('norm_layer')
        if isinstance(norm_layer, str):
            norm_cls = MODELS.get(norm_layer)
            if norm_cls is None:
                raise KeyError(
                    f'norm_layer {norm_layer!r} is not registered in MODELS')
            kwargs['norm_layer'] = norm_cls

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # reset classifier: in timm, (num_classes=0, global_pool='') is
        # expected to drop the head and pooling so raw features come out —
        # NOTE(review): confirm for the specific timm model in use.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm: mark the module as already
        # initialized, presumably so the OpenMMLab init_weights pass does not
        # overwrite the loaded weights — confirm against BaseModule.
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        """Run the wrapped timm model and normalize its output to a tuple.

        Args:
            x: Input forwarded unchanged to the timm model (presumably an
                image tensor of shape (N, C, H, W) — confirm).

        Returns:
            tuple: A list/tuple output (e.g. with ``features_only=True``) is
            converted to a tuple; any other single output is wrapped in a
            1-tuple.
        """
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            return tuple(features)
        return (features, )