|
import inspect |
|
import platform |
|
|
|
from .registry import PLUGIN_LAYERS |
|
|
|
if platform.system() == 'Windows': |
|
import regex as re |
|
else: |
|
import re |
|
|
|
|
|
def infer_abbr(class_type):
    """Infer the abbreviation used to name a plugin layer of ``class_type``.

    Rule 1: If the class has (defines or inherits) the attribute ``_abbr_``,
    return that attribute's value.
    Rule 2: Otherwise, fall back to the snake case of the class name,
    e.g. the abbreviation of ``FancyBlock`` is ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert a CamelCase word into snake_case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Separate an acronym from a following capitalised word,
        # e.g. "HTTPRequest" -> "HTTP_Request".
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Separate a lowercase letter or digit from a following capital,
        # e.g. "FancyBlock" -> "Fancy_Block".
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    # hasattr also sees an `_abbr_` inherited from a base class.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    return camel2snake(class_type.__name__)
|
|
|
|
|
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build a plugin layer from a config dict.

    Args:
        cfg (dict): Config dict. It must contain:
            type (str | type): The plugin layer type: either a name
                registered in ``PLUGIN_LAYERS`` or the layer class itself.
            layer args: Any remaining keys are forwarded to the layer
                constructor as keyword arguments.
        postfix (int | str): Appended to the inferred plugin abbreviation to
            create the layer name. Default: ''.
        **kwargs: Extra keyword arguments forwarded to the layer
            constructor. A key present both here and in ``cfg`` makes the
            constructor call raise ``TypeError`` (duplicate keyword).

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` lacks the ``"type"`` key, or the type name is
            not registered in ``PLUGIN_LAYERS``.
        AssertionError: If ``postfix`` is neither an int nor a str.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    # Work on a copy so popping 'type' does not mutate the caller's config.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        # The layer class was given directly; no registry lookup needed.
        plugin_layer = layer_type
    else:
        if layer_type not in PLUGIN_LAYERS:
            raise KeyError(f'Unrecognized plugin type {layer_type}')
        plugin_layer = PLUGIN_LAYERS.get(layer_type)

    abbr = infer_abbr(plugin_layer)

    # Explicit check instead of a bare `assert`, which `python -O` strips.
    # AssertionError is kept so existing callers/tests that expect it still
    # work.
    if not isinstance(postfix, (int, str)):
        raise AssertionError(
            f'postfix must be an int or str, but got {type(postfix)}')
    name = abbr + str(postfix)

    layer = plugin_layer(**kwargs, **cfg_)

    return name, layer
|
|