# Spaces: Runtime error (page status text captured alongside this file; not part of the module)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings | |
from mmengine.logging import MMLogger | |
from mmcls.registry import MODELS | |
from .base_backbone import BaseBackbone | |
def print_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone for development and debug.

    Nothing is returned; all output goes to the current ``MMLogger``
    instance. A missing ``feature_info`` and an object lacking the expected
    attributes are both reported as warnings rather than raised.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    log = MMLogger.get_current_instance()

    # Case 1: the backbone exposes no feature_info at all.
    if feature_info is None:
        log.warning('This backbone does not have feature_info')
        return

    # Case 2: a plain list - log one line per entry.
    if isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            log.info(f'backbone feature_info[{feat_idx}]: {each_info}')
        return

    # Case 3: any other object - expected to offer ``out_indices``,
    # ``channels()`` and ``reduction()``. Lines are logged one at a time, so
    # a partially matching object still logs what it can before the warning.
    try:
        log.info(f'backbone out_indices: {feature_info.out_indices}')
        log.info(f'backbone out_channels: {feature_info.channels()}')
        log.info(f'backbone out_strides: {feature_info.reduction()}')
    except AttributeError:
        log.warning('Unexpected format of backbone feature_info')
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments. If ``norm_layer``
            is present, it is looked up in the ``MODELS`` registry before
            being forwarded to timm.

    Raises:
        ImportError: If timm is not installed. The original import failure
            is attached as the cause.
        TypeError: If ``pretrained`` is not a bool.
    """

    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        # timm is an optional dependency, so it is imported lazily here.
        try:
            import timm
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible
            # as the direct cause in the traceback.
            raise ImportError(
                'Failed to import timm. Please run "pip install timm". '
                '"pip install dataclasses" may also be needed for Python 3.6.'
            ) from err

        # A checkpoint path must go through ``checkpoint_path``; reject
        # anything that is not a real bool (e.g. a path string) early.
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super().__init__(init_cfg)

        # Replace the configured ``norm_layer`` value with whatever the
        # MODELS registry returns for it.
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = MODELS.get(kwargs['norm_layer'])

        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # reset classifier
        # NOTE(review): presumably (num_classes=0, global_pool='') drops the
        # classification head and pooling - confirm against the installed
        # timm version.
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm
        # NOTE(review): presumably marks the module as already initialized so
        # the base class does not re-initialize the loaded weights - confirm
        # against BaseBackbone.
        if pretrained or checkpoint_path:
            self._is_init = True

        # Log the feature layout (or its absence) to help debugging.
        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        """Run the wrapped timm model and return its outputs as a tuple.

        Args:
            x: Input passed unchanged to the timm model.

        Returns:
            tuple: A list or tuple output is converted to a tuple; any other
            output is wrapped in a one-element tuple.
        """
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features