Spaces:
Sleeping
Sleeping
import inspect | |
from functools import partial | |
class Registry(object):
    """A named lookup table mapping string keys to build functions.

    Functions are stored in a plain dict keyed by name. ``register`` returns
    the function it was given, so it (or the partial returned by
    ``register_with_name``) can be used as a decorator. Registering a name
    that is already present raises ``KeyError`` unless ``force=True``.
    """

    def __init__(self, name):
        """Create an empty registry.

        Args:
            name: Human-readable registry name, shown by ``repr`` and in
                duplicate-registration error messages.
        """
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    def __len__(self):
        """Return the number of registered entries."""
        return len(self._module_dict)

    def name(self):
        """Return the registry name given at construction."""
        return self._name

    def module_dict(self):
        """Return the underlying name -> function dict (the live dict, not a copy)."""
        return self._module_dict

    def get(self, key):
        """Return the function registered under ``key``, or ``None`` if absent."""
        return self._module_dict.get(key, None)

    def register_with_name(self, module_name=None, force=False):
        """Return a decorator that registers a function under ``module_name``.

        Args:
            module_name: Key to register under; ``None`` means use the
                decorated function's ``__name__``.
            force: If True, silently overwrite an existing entry.

        Returns:
            A callable taking the function to register and returning it.
        """
        return partial(self.register, module_name=module_name, force=force)

    # Original (misspelled) name, kept so existing callers keep working.
    registe_with_name = register_with_name

    def register(self, module_build_function, module_name=None, force=False):
        """Register a module build function.

        Args:
            module_build_function: The function to store. Must be a plain
                Python function (``inspect.isfunction`` is True); builtins,
                classes, bound methods and ``functools.partial`` objects are
                rejected.
            module_name: Key to register under. Defaults to
                ``module_build_function.__name__``.
            force: If True, overwrite an existing entry with the same key
                instead of raising.

        Returns:
            ``module_build_function`` itself, unchanged, so this method can be
            used as a decorator.

        Raises:
            TypeError: If ``module_build_function`` is not a plain function.
            KeyError: If ``module_name`` is already registered and ``force``
                is False.
        """
        if not inspect.isfunction(module_build_function):
            raise TypeError(
                'module_build_function must be a function, but got {}'.format(
                    type(module_build_function)))
        if module_name is None:
            module_name = module_build_function.__name__
        if not force and module_name in self._module_dict:
            # Use the stored string: ``self.name`` is a method on this class,
            # so formatting it would print a bound-method repr, not the name.
            raise KeyError('{} is already registered in {}'.format(
                module_name, self._name))
        self._module_dict[module_name] = module_build_function
        return module_build_function
# Module-level registry instance; build functions are added to it via
# MODULE_BUILD_FUNCS.register(...) or the decorator form.
MODULE_BUILD_FUNCS = Registry(name='model build functions')