ttxskk
update
d7e58f0
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
from .positional_encoding import (
LearnedPositionalEncoding,
SinePositionalEncoding,
)
# Registries consulted by the builder functions defined in this module.
TRANSFORMER = Registry(name='Transformer')
LINEAR_LAYERS = Registry(name='linear layers')
POSITIONAL_ENCODING = Registry(name='position encoding')

# Register the implementations that are available by default, each under an
# explicit name that configs refer to.
LINEAR_LAYERS.register_module(name='Linear', module=nn.Linear)
POSITIONAL_ENCODING.register_module(
    name='SinePositionalEncoding', module=SinePositionalEncoding)
POSITIONAL_ENCODING.register_module(
    name='LearnedPositionalEncoding', module=LearnedPositionalEncoding)
def build_transformer(cfg, default_args=None):
    """Build a transformer through the ``TRANSFORMER`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build_from_cfg``
            (presumably a dict with a ``type`` key -- confirm against
            ``mmcv.utils.build_from_cfg``).
        default_args: Optional defaults forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(cfg, TRANSFORMER, default_args=default_args)
def build_linear_layer(cfg, *args, **kwargs):
    """Instantiate a linear layer registered in ``LINEAR_LAYERS``.

    Args:
        cfg (None or dict): Config of the linear layer. ``None`` selects
            the ``'Linear'`` type. A dict must contain:

            - type (str): Registered layer type.
            - layer args: Remaining entries, passed as keyword arguments
              when instantiating the layer.
        args (argument list): Positional arguments passed to the
            ``__init__`` method of the selected linear layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the selected linear layer.

    Returns:
        nn.Module: Created linear layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``cfg`` has no ``'type'`` key, or if the requested
            type is not registered in ``LINEAR_LAYERS``.
    """
    if cfg is None:
        layer_cfg = {'type': 'Linear'}
    elif not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    elif 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    else:
        # Work on a copy so that popping 'type' leaves the caller's cfg intact.
        layer_cfg = cfg.copy()

    layer_type = layer_cfg.pop('type')
    if layer_type not in LINEAR_LAYERS:
        raise KeyError(f'Unrecognized linear type {layer_type}')

    layer_cls = LINEAR_LAYERS.get(layer_type)
    # Whatever is left in the config is forwarded as extra keyword arguments.
    return layer_cls(*args, **kwargs, **layer_cfg)
def build_positional_encoding(cfg, default_args=None):
    """Build a positional encoding through ``POSITIONAL_ENCODING``.

    Args:
        cfg: Config forwarded unchanged to ``build_from_cfg``
            (presumably a dict with a ``type`` key -- confirm against
            ``mmcv.utils.build_from_cfg``).
        default_args: Optional defaults forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(
        cfg, POSITIONAL_ENCODING, default_args=default_args)