|
|
|
import inspect |
|
import warnings |
|
from functools import partial |
|
|
|
from .misc import is_seq_of |
|
|
|
|
|
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    The ``type`` entry selects what to build: a str is looked up in
    ``registry``, a class is used directly. Every remaining entry of the
    merged config is passed to that class as a keyword argument.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type"
            unless ``default_args`` provides it.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            An entry is only used when ``cfg`` has no entry with that key.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict, ``default_args`` is neither a
            dict nor None, ``registry`` is not a :obj:`Registry`, or the
            resolved ``type`` is neither a str nor a class.
        KeyError: If "type" is found in neither ``cfg`` nor
            ``default_args``, or a str ``type`` is not in ``registry``.
        Exception: An exception raised while constructing the object is
            re-raised as the same exception type with the class name
            prefixed to its message.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Check the type of ``default_args`` before it is searched for "type"
    # below, so that e.g. an int is reported with this message instead of
    # an unrelated "argument of type 'int' is not iterable".
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    # Work on a copy so that the caller's ``cfg`` is left untouched.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Prefix the class name so the failing module is easy to identify.
        # Some exception classes cannot be constructed from one message
        # string (e.g. ``UnicodeDecodeError``); for those, keep the
        # original exception rather than masking it with a constructor
        # error.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
|
|
|
|
|
class Registry:
    """A registry to map strings to classes.

    Registered objects can be built from the registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct instance
            from Registry, func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will
            be inherited from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered
            in children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to
            search for children registry. If not specified, scope will be
            the name of the package where class is defined, e.g. mmdet,
            mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Validate ``parent`` before any of its attributes are read below.
        if parent is not None:
            assert isinstance(parent, Registry)

        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # ``infer_scope`` inspects the call stack and must therefore be
        # called directly from ``__init__``.
        self._scope = self.infer_scope() if scope is None else scope

        # ``self.build_func`` is chosen with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        if parent is not None:
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        """Return the number of names registered in this registry."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return whether ``key`` resolves to a class via :meth:`get`."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # Frame 0 is ``infer_scope`` itself, frame 1 is the ``__init__``
        # that called it and frame 2 is the code that created the registry.
        frame = inspect.stack()[2][0]
        module = inspect.getmodule(frame)
        if module is not None:
            module_name = module.__name__
        else:
            # ``inspect.getmodule`` may return None when the module cannot
            # be determined; fall back to the ``__name__`` global of the
            # calling frame instead of failing on ``None.__name__``.
            module_name = frame.f_globals.get('__name__', '__main__')
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        scope, separator, remainder = key.partition('.')
        if separator:
            return scope, remainder
        return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format. It may be prefixed
                with a scope, e.g. ``'mmdet.ResNet'``.

        Returns:
            class: The corresponding class, or None if ``key`` cannot be
                resolved.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # No scope or our own scope: look up in this registry only.
            return self._module_dict.get(real_key)

        if scope in self._children:
            # The scope names a direct child: delegate to it.
            return self._children[scope].get(real_key)

        # Otherwise retry from the root of the hierarchy.
        root = self
        while root.parent is not None:
            root = root.parent
        if root is self:
            # Already at the root and the scope is unknown: report "not
            # found" instead of dereferencing the missing parent.
            return None
        return root.get(key)

    def build(self, *args, **kwargs):
        """Call ``self.build_func`` with the given arguments.

        This registry is passed to the build function as the ``registry``
        keyword argument.
        """
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` under one or several names.

        Args:
            module_class (type): The class to register.
            module_name (str | sequence of str, optional): Name(s) to
                register the class under. The class name is used if None.
            force (bool): Whether to overwrite a name that is already
                registered.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If a name is already registered and ``force`` is
                False.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Register ``cls`` through the old API and emit a warning.

        When ``cls`` is None, a partial of this method with ``force`` bound
        is returned so that it can be applied to a class later.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            ``module`` itself when it is given; otherwise a decorator that
            registers the decorated class and returns it unchanged.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is not None,
                a str or a sequence of str.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # NOTE: This is a workaround to stay compatible with the old API,
        # where the class was passed as the first positional argument.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # Raise the error ahead of time.
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # Used as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Used as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
|
|