# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging
from mmengine.model import BaseModule
from torch import nn
from torch.utils.checkpoint import checkpoint
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from mmcls.models.utils.attention import WindowMSA
from mmcls.models.utils.helpers import to_2tuple
from mmcls.registry import MODELS
class MixMIMWindowAttention(WindowMSA):
"""MixMIM Window Attention.
Compared with WindowMSA, we add some modifications
in ``forward`` to meet the requirement of MixMIM during
pretraining.
Implements one windown attention in MixMIM.
Args:
embed_dims (int): The feature dimension.
window_size (list): The height and width of the window.
        num_heads (int): The number of heads in attention.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop_rate (float): Attention dropout rate.
            Defaults to 0.
proj_drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__(
embed_dims=embed_dims,
window_size=window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop_rate,
proj_drop=proj_drop_rate,
init_cfg=init_cfg)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
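        # During MixMIM pretraining each window mixes tokens from two images.
        # `mask` marks the source image of every token; the expression below
        # keeps attention between tokens of the same image and suppresses
        # cross-image pairs with a large negative bias before the softmax.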
if mask is not None:
mask = mask.reshape(B_, 1, 1, N)
mask_new = mask * mask.transpose(
2, 3) + (1 - mask) * (1 - mask).transpose(2, 3)
mask_new = 1 - mask_new
if mask_new.dtype == torch.float16:
attn = attn - 65500 * mask_new
else:
attn = attn - 1e30 * mask_new
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MixMIMBlock(TransformerEncoderLayer):
"""MixMIM Block. Implements one block in MixMIM.
Args:
embed_dims (int): The feature dimension.
input_resolution (tuple): Input resolution of this layer.
        num_heads (int): The number of heads in attention.
        window_size (int): The height and width of the attention window.
            Defaults to 7.
        mlp_ratio (int): The MLP ratio in FFN.
num_fcs (int): The number of linear layers in a block.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
proj_drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        attn_drop_rate (float): Attention dropout rate.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate.
            Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
num_fcs=2,
qkv_bias=True,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=int(mlp_ratio * embed_dims),
drop_rate=proj_drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
self.embed_dims = embed_dims
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
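        # if the feature map is smaller than the window, attend over the
        # whole feature map instead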
if min(self.input_resolution) <= self.window_size:
self.window_size = min(self.input_resolution)
self.attn = MixMIMWindowAttention(
embed_dims=embed_dims,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate)
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
@staticmethod
def window_reverse(windows, H, W, window_size):
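        """Merge per-window tokens back into a feature map.

        (num_windows * B, window_size, window_size, C) -> (B, H, W, C)
        """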
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@staticmethod
def window_partition(x, window_size):
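        """Split a feature map into non-overlapping windows.

        (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
        """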
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
def forward(self, x, attn_mask=None):
H, W = self.input_resolution
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# partition windows
x_windows = self.window_partition(
x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size,
C) # nW*B, window_size*window_size, C
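        # partition the token-mixing mask in the same way as the tokens, so
        # each window carries the per-token source-image labels it needs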
if attn_mask is not None:
attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1
attn_mask = attn_mask.view(B, H, W, 1)
attn_mask = self.window_partition(attn_mask, self.window_size)
attn_mask = attn_mask.view(-1, self.window_size * self.window_size,
1)
# W-MSA/SW-MSA
attn_windows = self.attn(
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
x = self.window_reverse(attn_windows, H, W,
self.window_size) # B H' W' C
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath
return x
class MixMIMLayer(BaseModule):
"""Implements one MixMIM layer, which may contains several MixMIM blocks.
Args:
embed_dims (int): The feature dimension.
input_resolution (tuple): Input resolution of this layer.
depth (int): The number of blocks in this layer.
        num_heads (int): The number of heads in attention.
        window_size (int): The height and width of the attention window.
        mlp_ratio (int): The MLP ratio in FFN.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
proj_drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        attn_drop_rate (float): Attention dropout rate.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate.
            Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        downsample (class, optional): Downsample the output of blocks by
            patch merging. Defaults to None.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost. Defaults to False.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
                 input_resolution: Tuple[int, int],
depth: int,
num_heads: int,
window_size: int,
mlp_ratio=4.,
qkv_bias=True,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=[0.],
norm_cfg=dict(type='LN'),
downsample=None,
use_checkpoint=False,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList()
for i in range(depth):
self.blocks.append(
MixMIMBlock(
embed_dims=embed_dims,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop_rate=proj_drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
norm_cfg=norm_cfg))
# patch merging layer
if downsample is not None:
self.downsample = downsample(
in_channels=embed_dims,
out_channels=2 * embed_dims,
norm_cfg=norm_cfg)
else:
self.downsample = None
def forward(self, x, attn_mask=None):
for blk in self.blocks:
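            # gradient checkpointing: recompute block activations in the
            # backward pass to trade extra compute for lower GPU memory use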
if self.use_checkpoint:
x = checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask=attn_mask)
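        # PatchMerging halves the spatial resolution and doubles the channel
        # dimension between stages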
if self.downsample is not None:
x, _ = self.downsample(x, self.input_resolution)
return x
    def extra_repr(self) -> str:
        return (f'dim={self.embed_dims}, '
                f'input_resolution={self.input_resolution}, '
                f'depth={self.depth}')
@MODELS.register_module()
class MixMIMTransformer(BaseBackbone):
"""MixMIM backbone.
A PyTorch implement of : ` MixMIM: Mixed and Masked Image
Modeling for Efficient Visual Representation Learning
<https://arxiv.org/abs/2205.13137>`_
Args:
        arch (str | dict): MixMIM architecture. If a string is used, choose
            from 'base', 'large' and 'huge'. If a dict is used, it should
            have the following keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of attention heads in
              each stage.

            Defaults to 'base'.
        mlp_ratio (int): The MLP ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The number of input channels. Defaults to 3.
        window_size (list): The window size of each stage.
            Defaults to [14, 14, 14, 7].
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention dropout rate. Defaults to 0.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost. Defaults to False.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32]
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48]
}),
**dict.fromkeys(
['h', 'huge'], {
'embed_dims': 352,
'depths': [2, 2, 18, 2],
'num_heads': [11, 22, 44, 88]
}),
}
def __init__(
self,
arch='base',
mlp_ratio=4,
img_size=224,
patch_size=4,
in_channels=3,
window_size=[14, 14, 14, 7],
qkv_bias=True,
patch_cfg=dict(),
norm_cfg=dict(type='LN'),
drop_rate=0.0,
drop_path_rate=0.0,
attn_drop_rate=0.0,
use_checkpoint=False,
init_cfg: Optional[dict] = None,
) -> None:
        super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
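        # overall downsampling factor: patch embedding (x4 with the default
        # patch_size) followed by three PatchMerging stages (x2 each)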
self.encoder_stride = 32
self.num_layers = len(self.depths)
self.qkv_bias = qkv_bias
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.use_checkpoint = use_checkpoint
self.mlp_ratio = mlp_ratio
self.window_size = window_size
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
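        # stochastic depth decay rule: the drop path rate grows linearly from
        # 0 to drop_path_rate across all blocks of all stages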
self.dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
]
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
self.layers.append(
MixMIMLayer(
embed_dims=int(self.embed_dims * 2**i_layer),
input_resolution=(self.patch_resolution[0] // (2**i_layer),
self.patch_resolution[1] //
(2**i_layer)),
depth=self.depths[i_layer],
num_heads=self.num_heads[i_layer],
window_size=self.window_size[i_layer],
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias,
proj_drop_rate=self.drop_rate,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=self.dpr[sum(self.depths[:i_layer]
):sum(self.depths[:i_layer +
1])],
norm_cfg=norm_cfg,
downsample=PatchMerging if
(i_layer < self.num_layers - 1) else None,
use_checkpoint=self.use_checkpoint))
self.num_features = int(self.embed_dims * 2**(self.num_layers - 1))
self.drop_after_pos = nn.Dropout(p=self.drop_rate)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
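        # fixed (non-learnable) absolute position embedding, initialized to
        # zeros here; presumably filled by the pretraining pipeline or a
        # loaded checkpoint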
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, self.embed_dims),
requires_grad=False)
_, self.norm = build_norm_layer(norm_cfg, self.num_features)
def forward(self, x: torch.Tensor):
x, _ = self.patch_embed(x)
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
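        # attn_mask is None outside pretraining, so every block runs plain
        # window attention without the MixMIM cross-image masking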
for layer in self.layers:
x = layer(x, attn_mask=None)
x = self.norm(x)
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return (x, )
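

# A minimal usage sketch (not part of the original module), assuming `mmcls`
# and its dependencies are installed. It builds the default 'base' variant
# directly and runs a dummy forward pass; for the base arch the pooled
# feature dimension is embed_dims * 2**3 = 1024.
if __name__ == '__main__':
    model = MixMIMTransformer(arch='base', img_size=224)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    print(feats[0].shape)  # expected: torch.Size([1, 1024])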