# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.utils import is_tuple_of

from .make_divisible import make_divisible


class SELayer(BaseModule):
    """Squeeze-and-Excitation Module.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        squeeze_channels (None or int): The intermediate channel number of
            the SE layer. Default: None, which means ``squeeze_channels``
            is ``make_divisible(channels // ratio, divisor)``.
        ratio (int): Squeeze ratio of the SE layer; the intermediate channel
            number will be ``make_divisible(channels // ratio, divisor)``.
            Only used when ``squeeze_channels`` is None. Default: 16.
        divisor (int): The divisor used to round the intermediate channel
            number. Only used when ``squeeze_channels`` is None. Default: 8.
        bias (bool or str): Whether the two convolution layers use a bias
            term. If 'auto', the bias is decided by the norm layer of each
            ``ConvModule``. Default: 'auto'.
        conv_cfg (None or dict): Config dict for the convolution layers.
            Default: None, which means using conv2d.
        act_cfg (dict or Sequence[dict]): Config dict(s) for the activation
            layers. If ``act_cfg`` is a dict, both activation layers are
            configured by it. If it is a sequence of two dicts, the first
            dict configures the first activation layer and the second dict
            configures the second one.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
        return_weight (bool): Whether to return the channel attention weight
            instead of the recalibrated input. Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 squeeze_channels=None,
                 ratio=16,
                 divisor=8,
                 bias='auto',
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
                 return_weight=False,
                 init_cfg=None):
        super(SELayer, self).__init__(init_cfg)
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert is_tuple_of(act_cfg, dict)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        if squeeze_channels is None:
            squeeze_channels = make_divisible(channels // ratio, divisor)
        assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \
            '"squeeze_channels" should be a positive integer, but got ' + \
            f'{squeeze_channels} instead.'
        self.return_weight = return_weight
        # Squeeze: 1x1 conv reducing channels to ``squeeze_channels``.
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeeze_channels,
            kernel_size=1,
            stride=1,
            bias=bias,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        # Excitation: 1x1 conv restoring the original channel number.
        self.conv2 = ConvModule(
            in_channels=squeeze_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            bias=bias,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        # Global spatial context: (N, C, H, W) -> (N, C, 1, 1).
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        if self.return_weight:
            return out
        else:
            return x * out
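

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). Because this file uses a relative import, run it as a module,
# e.g. ``python -m <package>.se_layer`` -- the package path is an assumption
# here. Requires torch, mmcv, and mmengine to be installed.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    # A dummy batch of feature maps: (batch, channels, height, width).
    x = torch.randn(2, 64, 32, 32)

    # Default config: squeeze_channels = make_divisible(64 // 16, 8),
    # which works out to 8 with mmcv's make_divisible.
    se = SELayer(channels=64)
    out = se(x)
    assert out.shape == x.shape  # per-channel rescaling preserves the shape

    # With return_weight=True the layer returns the channel attention weight
    # itself; it has shape (batch, channels, 1, 1) and broadcasts over H, W.
    se_w = SELayer(channels=64, return_weight=True)
    assert se_w(x).shape == (2, 64, 1, 1)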