# Copyright (c) OpenMMLab. All rights reserved.
import math
from itertools import chain
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS

from ..utils import ChannelMultiheadAttention, PositionEncodingFourier
from .base_backbone import BaseBackbone
from .convnext import ConvNeXtBlock


class SDTAEncoder(BaseModule):
    """A PyTorch implementation of the split depth-wise transpose attention
    (SDTA) encoder.

    Inspiration from
    https://github.com/mmaaz60/EdgeNeXt

    Args:
        in_channel (int): Number of input channels.
        drop_path_rate (float): Stochastic depth dropout rate.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale.
            Defaults to 1e-6.
        mlp_ratio (int): Ratio of hidden channels to input channels in the
            MLP. Defaults to 4.
        use_pos_emb (bool): Whether to use position encoding.
            Defaults to True.
        num_heads (int): Number of heads in the multihead attention.
            Defaults to 8.
        qkv_bias (bool): Whether to use bias in the multihead attention.
            Defaults to True.
        attn_drop (float): Dropout rate of the attention.
            Defaults to 0.
        proj_drop (float): Dropout rate of the projection.
            Defaults to 0.
        norm_cfg (dict): Dictionary to construct normalization layer.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): Dictionary to construct activation layer.
            Defaults to ``dict(type='GELU')``.
        scales (int): Number of scales. Defaults to 1.
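
    Example:
        A minimal usage sketch. It assumes the ``ChannelMultiheadAttention``
        and ``PositionEncodingFourier`` utilities imported above are
        available, and uses ``scales=2`` so the split depth-wise convolutions
        receive at least two channel groups.

        >>> import torch
        >>> encoder = SDTAEncoder(in_channel=48, scales=2, num_heads=4)
        >>> feat = torch.rand(1, 48, 56, 56)
        >>> # the residual design preserves the input shape
        >>> encoder(feat).shape
        torch.Size([1, 48, 56, 56])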
    """

    def __init__(self,
                 in_channel,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 mlp_ratio=4,
                 use_pos_emb=True,
                 num_heads=8,
                 qkv_bias=True,
                 attn_drop=0.,
                 proj_drop=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 scales=1,
                 init_cfg=None):
        super(SDTAEncoder, self).__init__(init_cfg=init_cfg)
        conv_channels = max(
            int(math.ceil(in_channel / scales)),
            int(math.floor(in_channel // scales)))
        self.conv_channels = conv_channels
        self.num_convs = scales if scales == 1 else scales - 1

        self.conv_modules = ModuleList()
        for i in range(self.num_convs):
            self.conv_modules.append(
                nn.Conv2d(
                    conv_channels,
                    conv_channels,
                    kernel_size=3,
                    padding=1,
                    groups=conv_channels))

        self.pos_embed = PositionEncodingFourier(
            embed_dims=in_channel) if use_pos_emb else None

        self.norm_csa = build_norm_layer(norm_cfg, in_channel)[1]
        self.gamma_csa = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.csa = ChannelMultiheadAttention(
            embed_dims=in_channel,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

        self.norm = build_norm_layer(norm_cfg, in_channel)[1]
        self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel)
        self.act = build_activation_layer(act_cfg)
        self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel)
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x

        spx = torch.split(x, self.conv_channels, dim=1)
        for i in range(self.num_convs):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.conv_modules[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        x = torch.cat((out, spx[self.num_convs]), 1)

        # Channel Self-attention
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W).permute(0, 2, 1)

        if self.pos_embed:
            pos_encoding = self.pos_embed((B, H, W))
            pos_encoding = pos_encoding.reshape(B, -1,
                                                x.shape[1]).permute(0, 2, 1)
            x += pos_encoding

        x = x + self.drop_path(self.gamma_csa * self.csa(self.norm_csa(x)))
        x = x.reshape(B, H, W, C)

        # Inverted Bottleneck
        x = self.norm(x)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        x = shortcut + self.drop_path(x)

        return x


@MODELS.register_module()
class EdgeNeXt(BaseBackbone):
    """EdgeNeXt.

    A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated
    CNN-Transformer Architecture for Mobile Vision Applications
    <https://arxiv.org/abs/2206.10589>`_

    Inspiration from
    https://github.com/mmaaz60/EdgeNeXt

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of the architectures in ``EdgeNeXt.arch_settings``. If dict,
            it should include the following keys:

            - channels (list[int]): The number of channels at each stage.
            - depths (list[int]): The number of blocks at each stage.
            - num_heads (list[int]): The number of heads at each stage.

            Defaults to 'xxsmall'.
        in_channels (int): The number of input channels.
            Defaults to 3.
        global_blocks (list[int]): The number of global (SDTA) blocks at each
            stage. Defaults to [0, 1, 1, 1].
        global_block_type (list[str]): The type of global blocks at each
            stage. Defaults to ['None', 'SDTA', 'SDTA', 'SDTA'].
        drop_path_rate (float): Stochastic depth dropout rate.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale.
            Defaults to 1e-6.
        linear_pw_conv (bool): Whether to use a linear layer to do pointwise
            convolution. Defaults to True.
        mlp_ratio (int): The expansion ratio of channels in MLP layers.
            Defaults to 4.
        conv_kernel_sizes (list[int]): The kernel size of convolutional
            layers at each stage. Defaults to [3, 5, 7, 9].
        use_pos_embd_csa (list[bool]): Whether to use positional embedding in
            Channel Self-Attention at each stage.
            Defaults to [False, True, False, False].
        use_pos_embd_global (bool): Whether to use positional embedding for
            the whole network. Defaults to False.
        d2_scales (list[int]): The number of channel groups used for SDTA at
            each stage. Defaults to [2, 2, 3, 4].
        norm_cfg (dict): The config of normalization layer.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. Defaults to True.
        act_cfg (dict): The config of activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): Config for initialization.
            Defaults to None.
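
    Example:
        A minimal usage sketch. It assumes this module is imported as part of
        its parent package, so the ``LN2d`` normalization layer used by the
        default ``norm_cfg`` is registered and resolvable by
        ``build_norm_layer``.

        >>> import torch
        >>> model = EdgeNeXt(arch='xxsmall', out_indices=(0, 1, 2, 3))
        >>> model = model.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> # With ``gap_before_final_norm=True`` (the default), each output
        >>> # is a globally pooled (N, C) feature vector.
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 24)
        (1, 48)
        (1, 88)
        (1, 168)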
""" arch_settings = { 'xxsmall': { # parameters: 1.3M 'channels': [24, 48, 88, 168], 'depths': [2, 2, 6, 2], 'num_heads': [4, 4, 4, 4] }, 'xsmall': { # parameters: 2.3M 'channels': [32, 64, 100, 192], 'depths': [3, 3, 9, 3], 'num_heads': [4, 4, 4, 4] }, 'small': { # parameters: 5.6M 'channels': [48, 96, 160, 304], 'depths': [3, 3, 9, 3], 'num_heads': [8, 8, 8, 8] }, 'base': { # parameters: 18.51M 'channels': [80, 160, 288, 584], 'depths': [3, 3, 9, 3], 'num_heads': [8, 8, 8, 8] }, } def __init__(self, arch='xxsmall', in_channels=3, global_blocks=[0, 1, 1, 1], global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], drop_path_rate=0., layer_scale_init_value=1e-6, linear_pw_conv=True, mlp_ratio=4, conv_kernel_sizes=[3, 5, 7, 9], use_pos_embd_csa=[False, True, False, False], use_pos_embd_global=False, d2_scales=[2, 2, 3, 4], norm_cfg=dict(type='LN2d', eps=1e-6), out_indices=-1, frozen_stages=0, gap_before_final_norm=True, act_cfg=dict(type='GELU'), init_cfg=None): super(EdgeNeXt, self).__init__(init_cfg=init_cfg) if isinstance(arch, str): arch = arch.lower() assert arch in self.arch_settings, \ f'Arch {arch} is not in default archs ' \ f'{set(self.arch_settings)}' self.arch_settings = self.arch_settings[arch] elif isinstance(arch, dict): essential_keys = {'channels', 'depths', 'num_heads'} assert isinstance(arch, dict) and set(arch) == essential_keys, \ f'Custom arch needs a dict with keys {essential_keys}' self.arch_settings = arch self.channels = self.arch_settings['channels'] self.depths = self.arch_settings['depths'] self.num_heads = self.arch_settings['num_heads'] self.num_layers = len(self.depths) self.use_pos_embd_global = use_pos_embd_global for g in global_block_type: assert g in ['None', 'SDTA'], f'Global block type {g} is not supported' self.num_stages = len(self.depths) if isinstance(out_indices, int): out_indices = [out_indices] assert isinstance(out_indices, Sequence), \ f'"out_indices" must by a sequence or int, ' \ f'get {type(out_indices)} instead.' 
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        if self.use_pos_embd_global:
            self.pos_embed = PositionEncodingFourier(
                embed_dims=self.channels[0])
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        self.stages = ModuleList()
        block_idx = 0
        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1])[1],
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2,
                    ))
                self.downsample_layers.append(downsample_layer)

            stage_blocks = []
            for j in range(depth):
                if j > depth - global_blocks[i] - 1:
                    stage_blocks.append(
                        SDTAEncoder(
                            in_channel=channels,
                            drop_path_rate=dpr[block_idx + j],
                            mlp_ratio=mlp_ratio,
                            scales=d2_scales[i],
                            use_pos_emb=use_pos_embd_csa[i],
                            num_heads=self.num_heads[i],
                        ))
                else:
                    dw_conv_cfg = dict(
                        kernel_size=conv_kernel_sizes[i],
                        padding=conv_kernel_sizes[i] // 2,
                    )
                    stage_blocks.append(
                        ConvNeXtBlock(
                            in_channels=channels,
                            dw_conv_cfg=dw_conv_cfg,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg,
                            linear_pw_conv=linear_pw_conv,
                            drop_path_rate=dpr[block_idx + j],
                            layer_scale_init_value=layer_scale_init_value,
                        ))
            block_idx += depth

            stage_blocks = Sequential(*stage_blocks)
            self.stages.append(stage_blocks)

            if i in self.out_indices:
                out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \
                    else norm_cfg
                norm_layer = build_norm_layer(out_norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def init_weights(self) -> None:
        # TODO: need to be implemented in the future
        return super().init_weights()

    def forward(self, x):
        outs = []

        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if self.pos_embed and i == 0:
                B, _, H, W = x.shape
                x += self.pos_embed((B, H, W))
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap.flatten(1)))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause problems in downstream tasks.
                    outs.append(norm_layer(x).contiguous())

        return tuple(outs)

    def _freeze_stages(self):
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        super(EdgeNeXt, self).train(mode)
        self._freeze_stages()
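

# A minimal smoke-test sketch (illustrative only, not part of the backbone
# API). It assumes this module is run inside its parent package, e.g. via
# ``python -m <package>.edgenext``, so the relative imports resolve and the
# ``LN2d`` norm layer referenced by the default ``norm_cfg`` is registered.
if __name__ == '__main__':
    # Because the class is registered above, it can also be built from a
    # config dict through the ``MODELS`` registry.
    cfg = dict(type='EdgeNeXt', arch='small', out_indices=(3, ))
    model = MODELS.build(cfg).eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    # ``gap_before_final_norm=True`` (default) pools the last stage into an
    # (N, C) vector, e.g. (1, 304) for the 'small' architecture.
    print(tuple(feats[-1].shape))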