# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.logging import MMLogger

from mmcls.registry import MODELS
from .base_backbone import BaseBackbone


def print_timm_feature_info(feature_info):
    """Print feature_info of timm backbone to help development and debug.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    logger = MMLogger.get_current_instance()
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
    elif isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
    else:
        try:
            logger.info(f'backbone out_indices: {feature_info.out_indices}')
            logger.info(f'backbone out_channels: {feature_info.channels()}')
            logger.info(f'backbone out_strides: {feature_info.reduction()}')
        except AttributeError:
            logger.warning('Unexpected format of backbone feature_info')


@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
"""Wrapper to use backbones from timm library.
More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_.
See especially the document for `feature extraction
<https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.
Args:
model_name (str): Name of timm model to instantiate.
features_only (bool): Whether to extract feature pyramid (multi-scale
feature maps from the deepest layer at each stride). For Vision
Transformer models that do not support this argument,
set this False. Defaults to False.
pretrained (bool): Whether to load pretrained weights.
Defaults to False.
checkpoint_path (str): Path of checkpoint to load at the last of
``timm.create_model``. Defaults to empty string, which means
not loading.
in_channels (int): Number of input image channels. Defaults to 3.
init_cfg (dict or list[dict], optional): Initialization config dict of
OpenMMLab projects. Defaults to None.
**kwargs: Other timm & model specific arguments.
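
    Example:
        A minimal usage sketch: ``resnet18`` is only an illustrative model
        name, any architecture provided by the installed timm version can be
        used, and the import path assumes this class is exposed through
        ``mmcls.models``.

        >>> import torch
        >>> from mmcls.models import TIMMBackbone
        >>> model = TIMMBackbone(model_name='resnet18', features_only=True)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> # ``forward`` always returns a tuple of feature tensors.
        >>> assert isinstance(outputs, tuple)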
"""

    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        try:
            import timm
        except ImportError:
            raise ImportError(
                'Failed to import timm. Please run "pip install timm". '
                '"pip install dataclasses" may also be needed for Python 3.6.')
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)
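        # In configs, ``norm_layer`` is usually given as the string name of a
        # norm layer registered in MODELS (e.g. 'BN'); resolve it to the
        # actual class before passing it on to timm.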
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = MODELS.get(kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # Reset the classifier: remove timm's pooling and classification head
        # so the wrapped model is used purely as a feature extractor.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm: mark the module as already
        # initialized so that ``init_weights`` will not override them later.
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        features = self.timm_model(x)
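        # With ``features_only=True`` timm returns a list of feature maps,
        # while a plain model returns a single tensor; normalize both cases
        # to a tuple so downstream modules get a consistent interface.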
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features