# Copyright (c) OpenMMLab. All rights reserved. from typing import List, Optional, Union import torch from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.drop import DropPath from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging from mmengine.model import BaseModule from torch import nn from torch.utils.checkpoint import checkpoint from mmcls.models.backbones.base_backbone import BaseBackbone from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer from mmcls.models.utils.attention import WindowMSA from mmcls.models.utils.helpers import to_2tuple from mmcls.registry import MODELS class MixMIMWindowAttention(WindowMSA): """MixMIM Window Attention. Compared with WindowMSA, we add some modifications in ``forward`` to meet the requirement of MixMIM during pretraining. Implements one windown attention in MixMIM. Args: embed_dims (int): The feature dimension. window_size (list): The height and width of the window. num_heads (int): The number of head in attention. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. qk_scale (float, optional): Override default qk scale of ``head_dim ** -0.5`` if set. Defaults to None. attn_drop_rate (float): attention drop rate. Defaults to 0. proj_drop_rate (float): Probability of an element to be zeroed. Defaults to 0. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ def __init__(self, embed_dims, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop_rate=0., proj_drop_rate=0., init_cfg=None): super().__init__( embed_dims=embed_dims, window_size=window_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, init_cfg=init_cfg) def forward(self, x, mask=None): B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[ 2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: mask = mask.reshape(B_, 1, 1, N) mask_new = mask * mask.transpose( 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3) mask_new = 1 - mask_new if mask_new.dtype == torch.float16: attn = attn - 65500 * mask_new else: attn = attn - 1e30 * mask_new attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MixMIMBlock(TransformerEncoderLayer): """MixMIM Block. Implements one block in MixMIM. Args: embed_dims (int): The feature dimension. input_resolution (tuple): Input resolution of this layer. num_heads (int): The number of head in attention, window_size (list): The height and width of the window. mlp_ratio (int): The MLP ration in FFN. num_fcs (int): The number of linear layers in a block. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. proj_drop_rate (float): Probability of an element to be zeroed. Defaults to 0. attn_drop_rate (float): attention drop rate. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ def __init__(self, embed_dims, input_resolution, num_heads, window_size=7, mlp_ratio=4., num_fcs=2, qkv_bias=True, proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), init_cfg: Optional[Union[List[dict], dict]] = None) -> None: super().__init__( embed_dims=embed_dims, num_heads=num_heads, feedforward_channels=int(mlp_ratio * embed_dims), drop_rate=proj_drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, num_fcs=num_fcs, qkv_bias=qkv_bias, act_cfg=act_cfg, norm_cfg=norm_cfg, init_cfg=init_cfg) self.embed_dims = embed_dims self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: self.window_size = min(self.input_resolution) self.attn = MixMIMWindowAttention( embed_dims=embed_dims, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, attn_drop_rate=attn_drop_rate, proj_drop_rate=proj_drop_rate) self.drop_path = DropPath( drop_path_rate) if drop_path_rate > 0. else nn.Identity() @staticmethod def window_reverse(windows, H, W, window_size): B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @staticmethod def window_partition(x, window_size): B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() windows = windows.view(-1, window_size, window_size, C) return windows def forward(self, x, attn_mask=None): H, W = self.input_resolution B, L, C = x.shape shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # partition windows x_windows = self.window_partition( x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C if attn_mask is not None: attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1 attn_mask = attn_mask.view(B, H, W, 1) attn_mask = self.window_partition(attn_mask, self.window_size) attn_mask = attn_mask.view(-1, self.window_size * self.window_size, 1) # W-MSA/SW-MSA attn_windows = self.attn( x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) x = self.window_reverse(attn_windows, H, W, self.window_size) # B H' W' C x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath return x class MixMIMLayer(BaseModule): """Implements one MixMIM layer, which may contains several MixMIM blocks. Args: embed_dims (int): The feature dimension. input_resolution (tuple): Input resolution of this layer. depth (int): The number of blocks in this layer. num_heads (int): The number of head in attention, window_size (list): The height and width of the window. mlp_ratio (int): The MLP ration in FFN. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. proj_drop_rate (float): Probability of an element to be zeroed. Defaults to 0. attn_drop_rate (float): attention drop rate. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. downsample (class, optional): Downsample the output of blocks b y patch merging.Defaults to None. use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory cost. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ def __init__(self, embed_dims: int, input_resolution: int, depth: int, num_heads: int, window_size: int, mlp_ratio=4., qkv_bias=True, proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=[0.], norm_cfg=dict(type='LN'), downsample=None, use_checkpoint=False, init_cfg: Optional[Union[List[dict], dict]] = None) -> None: super().__init__(init_cfg=init_cfg) self.embed_dims = embed_dims self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList() for i in range(depth): self.blocks.append( MixMIMBlock( embed_dims=embed_dims, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop_rate=proj_drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate[i], norm_cfg=norm_cfg)) # patch merging layer if downsample is not None: self.downsample = downsample( in_channels=embed_dims, out_channels=2 * embed_dims, norm_cfg=norm_cfg) else: self.downsample = None def forward(self, x, attn_mask=None): for blk in self.blocks: if self.use_checkpoint: x = checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask=attn_mask) if self.downsample is not None: x, _ = self.downsample(x, self.input_resolution) return x def extra_repr(self) -> str: return f'dim={self.embed_dims}, \ input_resolution={self.input_resolution}, depth={self.depth}' @MODELS.register_module() class MixMIMTransformer(BaseBackbone): """MixMIM backbone. A PyTorch implement of : ` MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning `_ Args: arch (str | dict): MixMIM architecture. If use string, choose from 'base','large' and 'huge'. If use dict, it should have below keys: - **embed_dims** (int): The dimensions of embedding. - **depths** (int): The number of transformer encoder layers. - **num_heads** (int): The number of heads in attention modules. Defaults to 'base'. mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. img_size (int | tuple): The expected input image shape. Because we support dynamic input shape, just set the argument to mlp_ratio the most common input image shape. Defaults to 224. patch_size (int | tuple): The patch size in patch embedding. Defaults to 16. in_channels (int): The num of input channels. Defaults to 3. window_size (list): The height and width of the window. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. patch_cfg (dict): Extra config dict for patch embedding. Defaults to an empty dict. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. attn_drop_rate (float): attention drop rate. Defaults to 0. use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory cost. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ arch_zoo = { **dict.fromkeys( ['b', 'base'], { 'embed_dims': 128, 'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32] }), **dict.fromkeys( ['l', 'large'], { 'embed_dims': 192, 'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48] }), **dict.fromkeys( ['h', 'huge'], { 'embed_dims': 352, 'depths': [2, 2, 18, 2], 'num_heads': [11, 22, 44, 88] }), } def __init__( self, arch='base', mlp_ratio=4, img_size=224, patch_size=4, in_channels=3, window_size=[14, 14, 14, 7], qkv_bias=True, patch_cfg=dict(), norm_cfg=dict(type='LN'), drop_rate=0.0, drop_path_rate=0.0, attn_drop_rate=0.0, use_checkpoint=False, init_cfg: Optional[dict] = None, ) -> None: super(MixMIMTransformer, self).__init__(init_cfg=init_cfg) if isinstance(arch, str): arch = arch.lower() assert arch in set(self.arch_zoo), \ f'Arch {arch} is not in default archs {set(self.arch_zoo)}' self.arch_settings = self.arch_zoo[arch] else: essential_keys = { 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' } assert isinstance(arch, dict) and essential_keys <= set(arch), \ f'Custom arch needs a dict with keys {essential_keys}' self.arch_settings = arch self.embed_dims = self.arch_settings['embed_dims'] self.depths = self.arch_settings['depths'] self.num_heads = self.arch_settings['num_heads'] self.encoder_stride = 32 self.num_layers = len(self.depths) self.qkv_bias = qkv_bias self.drop_rate = drop_rate self.attn_drop_rate = attn_drop_rate self.use_checkpoint = use_checkpoint self.mlp_ratio = mlp_ratio self.window_size = window_size _patch_cfg = dict( in_channels=in_channels, input_size=img_size, embed_dims=self.embed_dims, conv_type='Conv2d', kernel_size=patch_size, stride=patch_size, norm_cfg=dict(type='LN'), ) _patch_cfg.update(patch_cfg) self.patch_embed = PatchEmbed(**_patch_cfg) self.patch_resolution = self.patch_embed.init_out_size self.dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) ] self.layers = nn.ModuleList() for i_layer in range(self.num_layers): self.layers.append( MixMIMLayer( embed_dims=int(self.embed_dims * 2**i_layer), input_resolution=(self.patch_resolution[0] // (2**i_layer), self.patch_resolution[1] // (2**i_layer)), depth=self.depths[i_layer], num_heads=self.num_heads[i_layer], window_size=self.window_size[i_layer], mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias, proj_drop_rate=self.drop_rate, attn_drop_rate=self.attn_drop_rate, drop_path_rate=self.dpr[sum(self.depths[:i_layer] ):sum(self.depths[:i_layer + 1])], norm_cfg=norm_cfg, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=self.use_checkpoint)) self.num_features = int(self.embed_dims * 2**(self.num_layers - 1)) self.drop_after_pos = nn.Dropout(p=self.drop_rate) self.avgpool = nn.AdaptiveAvgPool1d(1) self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] self.absolute_pos_embed = nn.Parameter( torch.zeros(1, self.num_patches, self.embed_dims), requires_grad=False) _, self.norm = build_norm_layer(norm_cfg, self.num_features) def forward(self, x: torch.Tensor): x, _ = self.patch_embed(x) x = x + self.absolute_pos_embed x = self.drop_after_pos(x) for layer in self.layers: x = layer(x, attn_mask=None) x = self.norm(x) x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return (x, )