# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


class SSHContextModule(BaseModule):
    """Context module of `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    Two branches share a first 3x3 convolution. The ``5x5`` branch stacks
    one more 3x3 convolution on top of it (two 3x3 convs in total) and the
    ``7x7`` branch stacks two more (three 3x3 convs in total). Each branch
    produces ``out_channels // 4`` channels, and the last convolution of
    each branch has no activation.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for
            normalization layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0, (
            f'out_channels must be divisible by 4, but got {out_channels}')

        self.in_channels = in_channels
        self.out_channels = out_channels

        branch_channels = self.out_channels // 4
        # Settings shared by every 3x3 convolution in this module.
        conv_kwargs = dict(
            stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # NOTE: attribute names and creation order are kept as-is so that
        # checkpoint keys stay unchanged.
        self.conv5x5_1 = ConvModule(self.in_channels, branch_channels, 3,
                                    **conv_kwargs)
        # Branch outputs skip the activation (act_cfg=None).
        self.conv5x5_2 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None, **conv_kwargs)
        self.conv7x7_2 = ConvModule(branch_channels, branch_channels, 3,
                                    **conv_kwargs)
        self.conv7x7_3 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute both context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Outputs of the ``5x5`` and
            ``7x7`` branches, in that order.
        """
        shared = self.conv5x5_1(x)
        conv5x5 = self.conv5x5_2(shared)
        conv7x7 = self.conv7x7_3(self.conv7x7_2(shared))
        return conv5x5, conv7x7


class SSHDetModule(BaseModule):
    """Detection module of `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    A plain 3x3 convolution producing ``out_channels // 2`` channels is
    concatenated along the channel dimension with the two outputs of
    :class:`SSHContextModule` (``out_channels // 4`` channels each), and a
    single ReLU is applied to the concatenated result.

    Args:
        in_channels (int): Number of input channels used at each scale.
        out_channels (int): Number of output channels used at each scale.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for
            normalization layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0, (
            f'out_channels must be divisible by 4, but got {out_channels}')

        self.in_channels = in_channels
        self.out_channels = out_channels

        # No activation here: ReLU is applied once after concatenation.
        self.conv3x3 = ConvModule(
            self.in_channels,
            self.out_channels // 2,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.context_module = SSHContextModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the 3x3 branch with the two context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU-activated concatenation (along ``dim=1``) of
            the 3x3, ``5x5`` and ``7x7`` branch outputs.
        """
        conv3x3 = self.conv3x3(x)
        conv5x5, conv7x7 = self.context_module(x)
        fused = torch.cat([conv3x3, conv5x5, conv7x7], dim=1)
        return F.relu(fused)


@MODELS.register_module()
class SSH(BaseModule):
    """`SSH Neck` used in `SSH: Single Stage Headless Face Detector
    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    One :class:`SSHDetModule` is created per scale and registered as
    ``ssh_module{idx}``; each input feature map is processed by the module
    with the matching index.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (list[int]): The number of input channels per scale.
        out_channels (list[int]): The number of output channels per scale.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for
            normalization layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict. Defaults to
            dict(type='Xavier', layer='Conv2d', distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [8, 16, 32, 64]
        >>> out_channels = [16, 32, 64, 128]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = SSH(num_scales=4, in_channels=in_channels,
        ...            out_channels=out_channels)
        >>> outputs = self.forward(inputs)
        >>> for i, out in enumerate(outputs):
        ...     print(f'outputs[{i}].shape = {out.shape}')
        outputs[0].shape = torch.Size([1, 16, 340, 340])
        outputs[1].shape = torch.Size([1, 32, 170, 170])
        outputs[2].shape = torch.Size([1, 64, 84, 84])
        outputs[3].shape = torch.Size([1, 128, 43, 43])
    """

    def __init__(self,
                 num_scales: int,
                 in_channels: List[int],
                 out_channels: List[int],
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert num_scales == len(in_channels) == len(out_channels), (
            'num_scales, len(in_channels) and len(out_channels) must be '
            f'equal, but got {num_scales}, {len(in_channels)} and '
            f'{len(out_channels)}')

        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Modules are registered by name (not in a ModuleList) so that the
        # parameter keys remain ``ssh_module{idx}.*``.
        for idx in range(self.num_scales):
            self.add_module(
                f'ssh_module{idx}',
                SSHDetModule(
                    in_channels=self.in_channels[idx],
                    out_channels=self.out_channels[idx],
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

    def forward(self,
                inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Apply the per-scale detection module to each feature map.

        Args:
            inputs (tuple[torch.Tensor]): One feature map per scale; its
                length must equal ``num_scales``.

        Returns:
            tuple[torch.Tensor]: One output per input, in the same order.
        """
        assert len(inputs) == self.num_scales, (
            f'Expected {self.num_scales} feature maps, '
            f'but got {len(inputs)}')
        return tuple(
            getattr(self, f'ssh_module{idx}')(feat)
            for idx, feat in enumerate(inputs))