# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from itertools import chain
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS

from .base_backbone import BaseBackbone


@MODELS.register_module('LN2d')
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm on channels for 2d images.

    Args:
        num_channels (int): The number of channels of the input tensor.
        eps (float): a value added to the denominator for numerical stability.
            Defaults to 1e-5.
        elementwise_affine (bool): a boolean value that when set to ``True``,
            this module has learnable per-element affine parameters
            initialized to ones (for weights) and zeros (for biases).
            Defaults to True.
    """

    def __init__(self, num_channels: int, **kwargs) -> None:
        super().__init__(num_channels, **kwargs)
        self.num_channels = self.normalized_shape[0]

    def forward(self, x, data_format='channel_first'):
        """Normalize a 4-dim tensor over its channel dimension.

        Args:
            x (torch.Tensor): A 4-dim input tensor.
            data_format (str): Layout of ``x``. ``'channel_first'`` means
                (N, C, H, W) and ``'channel_last'`` means (N, H, W, C).
                Defaults to ``'channel_first'``.

        Returns:
            torch.Tensor: The normalized tensor in the same layout as the
            input.

        Raises:
            ValueError: If ``data_format`` is not one of the two supported
                layouts.
        """
        assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
            f'(N, C, H, W), but got tensor with shape {x.shape}'
        if data_format == 'channel_last':
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
                             self.eps)
        elif data_format == 'channel_first':
            x = x.permute(0, 2, 3, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
                             self.eps)
            # If the output is discontiguous, it may cause some unexpected
            # problem in the downstream tasks
            x = x.permute(0, 3, 1, 2).contiguous()
        else:
            # Fail loudly instead of silently returning the input
            # un-normalized (e.g. on a typo such as 'channels_last').
            raise ValueError(
                "data_format must be 'channel_first' or 'channel_last', "
                f'but got {data_format!r}.')
        return x


class ConvNeXtBlock(BaseModule):
    """ConvNeXt Block.

    Args:
        in_channels (int): The number of input channels.
        dw_conv_cfg (dict): Config of depthwise convolution.
            Defaults to ``dict(kernel_size=7, padding=3)``.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        mlp_ratio (float): The expansion ratio in both pointwise convolution.
            Defaults to 4.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. More details can be found in the note.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Layer Scale is disabled when the value is not positive.
            Defaults to 1e-6.
        with_cp (bool): Use checkpoint or not. Defaults to False.

    Note:
        There are two equivalent implementations:

        1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
           all outputs are in (N, C, H, W).
        2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
           -> Linear; Permute back

        As default, we use the second to align with the official repository.
        And it may be slightly faster.
    """

    def __init__(self,
                 in_channels,
                 dw_conv_cfg=dict(kernel_size=7, padding=3),
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 mlp_ratio=4.,
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, groups=in_channels, **dw_conv_cfg)

        self.linear_pw_conv = linear_pw_conv
        self.norm = build_norm_layer(norm_cfg, in_channels)[1]

        mid_channels = int(mlp_ratio * in_channels)
        if self.linear_pw_conv:
            # Use linear layer to do pointwise conv.
            pw_conv = nn.Linear
        else:
            pw_conv = partial(nn.Conv2d, kernel_size=1)

        self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
        self.act = build_activation_layer(act_cfg)
        self.pointwise_conv2 = pw_conv(mid_channels, in_channels)

        # Per-channel learnable scale applied to the residual branch.
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones((in_channels)),
            requires_grad=True) if layer_scale_init_value > 0 else None

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        """Apply the residual ConvNeXt block to ``x``.

        The input is fed to a depthwise ``nn.Conv2d`` and the result is
        added back to the input as a residual connection.
        """

        def _inner_forward(x):
            shortcut = x
            x = self.depthwise_conv(x)

            if self.linear_pw_conv:
                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
                x = self.norm(x, data_format='channel_last')
            else:
                # The 1x1 conv path stays in (N, C, H, W). ``LN2d``
                # normalizes channel-first by default, and calling the norm
                # without extra kwargs keeps other norm layers usable here.
                x = self.norm(x)

            x = self.pointwise_conv1(x)
            x = self.act(x)
            x = self.pointwise_conv2(x)

            if self.linear_pw_conv:
                x = x.permute(0, 3, 1, 2)  # permute back

            if self.gamma is not None:
                x = x.mul(self.gamma.view(1, -1, 1, 1))

            x = shortcut + self.drop_path(x)
            return x

        # Checkpointing is only used when the input requires grad.
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


@MODELS.register_module()
class ConvNeXt(BaseBackbone):
    """ConvNeXt.

    A PyTorch implementation of : `A ConvNet for the 2020s
    <https://arxiv.org/abs/2201.03545>`_

    Modified from the `official repo
    <https://github.com/facebookresearch/ConvNeXt>`_ and timm.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
            should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage. Defaults to -1, means the
            last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): Initialization config dict
    """
    arch_settings = {
        'tiny': {
            'depths': [3, 3, 9, 3],
            'channels': [96, 192, 384, 768]
        },
        'small': {
            'depths': [3, 3, 27, 3],
            'channels': [96, 192, 384, 768]
        },
        'base': {
            'depths': [3, 3, 27, 3],
            'channels': [128, 256, 512, 1024]
        },
        'large': {
            'depths': [3, 3, 27, 3],
            'channels': [192, 384, 768, 1536]
        },
        'xlarge': {
            'depths': [3, 3, 27, 3],
            'channels': [256, 512, 1024, 2048]
        },
    }

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 stem_patch_size=4,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should be both sequence with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a copy: the caller's sequence must not be mutated, and
        # tuples (which pass the Sequence check) do not support item
        # assignment.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                # Resolve against the real number of stages so custom
                # ``arch`` dicts with other than 4 stages index correctly.
                out_indices[i] = self.num_stages + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        block_idx = 0

        # One downsample layer in front of every stage, including the stem
        # layer.
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        # Feature resolution stages, each consisting of multiple residual
        # blocks
        self.stages = nn.ModuleList()

        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1])[1],
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                ConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value,
                    with_cp=with_cp) for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

            # Every output stage gets its own norm layer, registered as
            # ``norm{i}``.
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

        self._freeze_stages()

    def forward(self, x):
        """Run the backbone and collect the outputs of the selected stages.

        Returns:
            tuple: One entry per stage listed in ``out_indices``. With
            ``gap_before_final_norm`` the stage output is averaged over its
            last two dimensions, normalized and flattened from dim 1;
            otherwise the normalized feature map is returned.
        """
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    outs.append(norm_layer(x))

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages (and their downsample
        layers) in eval mode and stop gradients for their parameters."""
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        """Switch the training mode, then re-apply stage freezing so frozen
        stages stay in eval mode."""
        super(ConvNeXt, self).train(mode)
        self._freeze_stages()