File size: 1,446 Bytes
45b4aa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from .registry import CONV_LAYERS
# Register the stock PyTorch convolutions under their config names.
# 'Conv' is a generic alias that maps to the 2D convolution.
for _alias, _conv_cls in (
    ('Conv1d', nn.Conv1d),
    ('Conv2d', nn.Conv2d),
    ('Conv3d', nn.Conv3d),
    ('Conv', nn.Conv2d),
):
    CONV_LAYERS.register_module(_alias, module=_conv_cls)
# Drop the loop variables so they do not linger as module attributes.
del _alias, _conv_cls
def build_conv_layer(cfg, *args, **kwargs):
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a conv layer.

            If ``None``, ``dict(type='Conv2d')`` is used.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict, or if the same
            keyword is supplied both in ``kwargs`` and in ``cfg``.
        KeyError: If ``cfg`` has no ``"type"`` key, or if the requested
            type is not registered in ``CONV_LAYERS``.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Shallow copy so that popping "type" leaves the caller's dict intact.
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in CONV_LAYERS:
        # This builder handles conv layers, so name the conv type here
        # (the message previously said "norm type").
        raise KeyError(f'Unrecognized conv type {layer_type}')

    conv_layer = CONV_LAYERS.get(layer_type)
    # Whatever is left in the config is forwarded as extra keyword arguments.
    return conv_layer(*args, **kwargs, **cfg_)
|