# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import to_2tuple
from ..builder import BACKBONES
from ..utils import resize_pos_embed
from .base_backbone import BaseBackbone
def resize_decomposed_rel_pos(rel_pos, q_size, k_size):
"""Get relative positional embeddings according to the relative positions
of query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
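
    Examples:
        >>> # A minimal sketch (illustrative sizes, not from the original
        >>> # file): a 5-entry table is interpolated to the required
        >>> # 2 * max(4, 6) - 1 = 11 entries, then gathered into a
        >>> # (q_size, k_size, C) lookup.
        >>> rel_pos = torch.randn(5, 8)
        >>> resize_decomposed_rel_pos(rel_pos, q_size=4, k_size=6).shape
        torch.Size([4, 6, 8])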
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate the relative position table if its length doesn't match.
    if rel_pos.shape[0] != max_rel_dist:
resized = F.interpolate(
# (L, C) -> (1, C, L)
rel_pos.transpose(0, 1).unsqueeze(0),
size=max_rel_dist,
mode='linear',
)
# (1, C, L) -> (L, C)
resized = resized.squeeze(0).transpose(0, 1)
else:
resized = rel_pos
    # Scale the coordinates by the shorter side if q and k sizes differ.
    q_ratio = max(k_size / q_size, 1.0)
    k_ratio = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_ratio
    k_coords = torch.arange(k_size)[None, :] * k_ratio
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_ratio
return resized[relative_coords.long()]
def add_decomposed_rel_pos(attn,
q,
q_shape,
k_shape,
rel_pos_h,
rel_pos_w,
has_cls_token=False):
"""Spatial Relative Positional Embeddings."""
sp_idx = 1 if has_cls_token else 0
B, num_heads, _, C = q.shape
q_h, q_w = q_shape
k_h, k_w = k_shape
Rh = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h)
Rw = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w)
r_q = q[:, :, sp_idx:].reshape(B, num_heads, q_h, q_w, C)
rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh)
rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw)
rel_pos_embed = rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]
attn_map = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
attn_map += rel_pos_embed
attn[:, :, sp_idx:, sp_idx:] = attn_map.view(B, -1, q_h * q_w, k_h * k_w)
return attn
class MLP(BaseModule):
"""Two-layer multilayer perceptron.
    Compared with :class:`mmcv.cnn.bricks.transformer.FFN`, this class allows
    different input and output channel numbers.
Args:
in_channels (int): The number of input channels.
hidden_channels (int, optional): The number of hidden layer channels.
If None, same as the ``in_channels``. Defaults to None.
out_channels (int, optional): The number of output channels. If None,
same as the ``in_channels``. Defaults to None.
act_cfg (dict): The config of activation function.
Defaults to ``dict(type='GELU')``.
init_cfg (dict, optional): The config of weight initialization.
Defaults to None.
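
    Examples:
        >>> # A minimal sketch (illustrative sizes, not from the original
        >>> # file): project 96-dim tokens through a 384-dim hidden layer.
        >>> mlp = MLP(96, hidden_channels=384)
        >>> mlp(torch.randn(1, 49, 96)).shape
        torch.Size([1, 49, 96])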
"""
def __init__(self,
in_channels,
hidden_channels=None,
out_channels=None,
act_cfg=dict(type='GELU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_channels = out_channels or in_channels
hidden_channels = hidden_channels or in_channels
self.fc1 = nn.Linear(in_channels, hidden_channels)
self.act = build_activation_layer(act_cfg)
self.fc2 = nn.Linear(hidden_channels, out_channels)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def attention_pool(x: torch.Tensor,
pool: nn.Module,
in_size: tuple,
norm: Optional[nn.Module] = None):
"""Pooling the feature tokens.
Args:
x (torch.Tensor): The input tensor, should be with shape
``(B, num_heads, L, C)`` or ``(B, L, C)``.
pool (nn.Module): The pooling module.
in_size (Tuple[int]): The shape of the input feature map.
        norm (nn.Module, optional): The normalization module.
            Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Size]: The pooled feature tokens and the
        output spatial shape.
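
    Examples:
        >>> # A minimal sketch (illustrative sizes, not from the original
        >>> # file): a stride-2 pooling halves the 14x14 token grid.
        >>> x = torch.randn(2, 4, 14 * 14, 16)  # (B, num_heads, L, C)
        >>> pool = nn.MaxPool2d(3, stride=2, padding=1)
        >>> out, out_size = attention_pool(x, pool, (14, 14))
        >>> out.shape, out_size
        (torch.Size([2, 4, 49, 16]), torch.Size([7, 7]))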
"""
ndim = x.ndim
if ndim == 4:
B, num_heads, L, C = x.shape
elif ndim == 3:
num_heads = 1
B, L, C = x.shape
else:
raise RuntimeError(f'Unsupported input dimension {x.shape}')
H, W = in_size
assert L == H * W
# (B, num_heads, H*W, C) -> (B*num_heads, C, H, W)
x = x.reshape(B * num_heads, H, W, C).permute(0, 3, 1, 2).contiguous()
x = pool(x)
out_size = x.shape[-2:]
# (B*num_heads, C, H', W') -> (B, num_heads, H'*W', C)
x = x.reshape(B, num_heads, C, -1).transpose(2, 3)
if norm is not None:
x = norm(x)
if ndim == 3:
x = x.squeeze(1)
return x, out_size
class MultiScaleAttention(BaseModule):
"""Multiscale Multi-head Attention block.
Args:
in_dims (int): Number of input channels.
out_dims (int): Number of output channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key and
value. Defaults to True.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='LN')``.
        pool_kernel (tuple): Kernel size of the qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): Stride of the q pooling layer. Defaults to 1.
        stride_kv (int): Stride of the k and v pooling layers.
            Defaults to 1.
        rel_pos_spatial (bool): Whether to enable the spatial relative
            position embedding. Defaults to False.
residual_pooling (bool): Whether to enable the residual connection
after attention pooling. Defaults to True.
        input_size (Tuple[int], optional): The input resolution, required
            when ``rel_pos_spatial`` is enabled. Defaults to None.
rel_pos_zero_init (bool): If True, zero initialize relative
positional parameters. Defaults to False.
init_cfg (dict, optional): The config of weight initialization.
Defaults to None.
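
    Examples:
        >>> # A minimal sketch (illustrative sizes, not from the original
        >>> # file): pooling q with stride 2 quarters the output tokens.
        >>> attn = MultiScaleAttention(96, 192, num_heads=2, stride_q=2)
        >>> x = torch.randn(2, 14 * 14, 96)
        >>> out, out_size = attn(x, (14, 14))
        >>> out.shape, out_size
        (torch.Size([2, 49, 192]), torch.Size([7, 7]))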
"""
def __init__(self,
in_dims,
out_dims,
num_heads,
qkv_bias=True,
norm_cfg=dict(type='LN'),
pool_kernel=(3, 3),
stride_q=1,
stride_kv=1,
rel_pos_spatial=False,
residual_pooling=True,
input_size=None,
rel_pos_zero_init=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_heads = num_heads
self.in_dims = in_dims
self.out_dims = out_dims
head_dim = out_dims // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias)
self.proj = nn.Linear(out_dims, out_dims)
# qkv pooling
pool_padding = [k // 2 for k in pool_kernel]
pool_dims = out_dims // num_heads
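        # Each q/k/v pooling op is a depthwise conv over one head's channels,
        # followed by a norm layer (the "conv" pooling mode of MViTv2).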
def build_pooling(stride):
pool = nn.Conv2d(
pool_dims,
pool_dims,
pool_kernel,
stride=stride,
padding=pool_padding,
groups=pool_dims,
bias=False,
)
norm = build_norm_layer(norm_cfg, pool_dims)[1]
return pool, norm
self.pool_q, self.norm_q = build_pooling(stride_q)
self.pool_k, self.norm_k = build_pooling(stride_kv)
self.pool_v, self.norm_v = build_pooling(stride_kv)
self.residual_pooling = residual_pooling
self.rel_pos_spatial = rel_pos_spatial
self.rel_pos_zero_init = rel_pos_zero_init
if self.rel_pos_spatial:
# initialize relative positional embeddings
assert input_size[0] == input_size[1]
size = input_size[0]
rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))
def init_weights(self):
"""Weight initialization."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init when using a pretrained model.
return
        # The parameters only exist when rel_pos_spatial is enabled.
        if self.rel_pos_spatial and not self.rel_pos_zero_init:
trunc_normal_(self.rel_pos_h, std=0.02)
trunc_normal_(self.rel_pos_w, std=0.02)
def forward(self, x, in_size):
"""Forward the MultiScaleAttention."""
B, N, _ = x.shape # (B, H*W, C)
# qkv: (B, H*W, 3, num_heads, C)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1)
# q, k, v: (B, num_heads, H*W, C)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q)
k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k)
v, v_shape = attention_pool(v, self.pool_v, in_size, norm=self.norm_v)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_spatial:
attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape,
self.rel_pos_h, self.rel_pos_w)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
# (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C')
x = x.transpose(1, 2).reshape(B, -1, self.out_dims)
x = self.proj(x)
return x, q_shape
class MultiScaleBlock(BaseModule):
"""Multiscale Transformer blocks.
Args:
in_dims (int): Number of input channels.
out_dims (int): Number of output channels.
num_heads (int): Number of attention heads.
mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
Defaults to 4.0.
qkv_bias (bool): If True, add a learnable bias to query, key and
value. Defaults to True.
drop_path (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='LN')``.
act_cfg (dict): The config of activation function.
Defaults to ``dict(type='GELU')``.
        qkv_pool_kernel (tuple): Kernel size of the qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): Stride of the q pooling layer. Defaults to 1.
        stride_kv (int): Stride of the k and v pooling layers.
            Defaults to 1.
rel_pos_spatial (bool): Whether to enable the spatial relative
position embedding. Defaults to True.
residual_pooling (bool): Whether to enable the residual connection
after attention pooling. Defaults to True.
dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
attention layers. If False, multiply it in MLP layers.
Defaults to True.
        input_size (Tuple[int], optional): The input resolution, required
            when ``rel_pos_spatial`` is enabled. Defaults to None.
rel_pos_zero_init (bool): If True, zero initialize relative
positional parameters. Defaults to False.
init_cfg (dict, optional): The config of weight initialization.
Defaults to None.
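
    Examples:
        >>> # A minimal sketch (illustrative sizes, not from the original
        >>> # file): a downscaling block that doubles the channels and
        >>> # halves the token grid.
        >>> block = MultiScaleBlock(
        ...     96, 192, num_heads=2, stride_q=2, rel_pos_spatial=False)
        >>> x = torch.randn(2, 14 * 14, 96)
        >>> out, out_size = block(x, (14, 14))
        >>> out.shape, out_size
        (torch.Size([2, 49, 192]), torch.Size([7, 7]))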
"""
def __init__(
self,
in_dims,
out_dims,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
qkv_pool_kernel=(3, 3),
stride_q=1,
stride_kv=1,
rel_pos_spatial=True,
residual_pooling=True,
dim_mul_in_attention=True,
input_size=None,
rel_pos_zero_init=False,
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
self.in_dims = in_dims
self.out_dims = out_dims
self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
self.dim_mul_in_attention = dim_mul_in_attention
attn_dims = out_dims if dim_mul_in_attention else in_dims
self.attn = MultiScaleAttention(
in_dims,
attn_dims,
num_heads=num_heads,
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
pool_kernel=qkv_pool_kernel,
stride_q=stride_q,
stride_kv=stride_kv,
rel_pos_spatial=rel_pos_spatial,
residual_pooling=residual_pooling,
input_size=input_size,
rel_pos_zero_init=rel_pos_zero_init)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1]
self.mlp = MLP(
in_channels=attn_dims,
hidden_channels=int(attn_dims * mlp_ratio),
out_channels=out_dims,
act_cfg=act_cfg)
if in_dims != out_dims:
self.proj = nn.Linear(in_dims, out_dims)
else:
self.proj = None
if stride_q > 1:
kernel_skip = stride_q + 1
padding_skip = int(kernel_skip // 2)
self.pool_skip = nn.MaxPool2d(
kernel_skip, stride_q, padding_skip, ceil_mode=False)
if input_size is not None:
input_size = to_2tuple(input_size)
out_size = [size // stride_q for size in input_size]
self.init_out_size = out_size
else:
self.init_out_size = None
else:
self.pool_skip = None
self.init_out_size = input_size
def forward(self, x, in_size):
x_norm = self.norm1(x)
x_attn, out_size = self.attn(x_norm, in_size)
if self.dim_mul_in_attention and self.proj is not None:
skip = self.proj(x_norm)
else:
skip = x
if self.pool_skip is not None:
skip, _ = attention_pool(skip, self.pool_skip, in_size)
x = skip + self.drop_path(x_attn)
x_norm = self.norm2(x)
x_mlp = self.mlp(x_norm)
if not self.dim_mul_in_attention and self.proj is not None:
skip = self.proj(x_norm)
else:
skip = x
x = skip + self.drop_path(x_mlp)
return x, out_size
@BACKBONES.register_module()
class MViT(BaseBackbone):
"""Multi-scale ViT v2.
    A PyTorch implementation of `MViTv2: Improved Multiscale Vision
    Transformers for Classification and Detection
    <https://arxiv.org/abs/2112.01526>`_

    Inspired by `the official implementation
    <https://github.com/facebookresearch/mvit>`_ and `the detectron2
    implementation <https://github.com/facebookresearch/detectron2>`_.
Args:
        arch (str | dict): MViT architecture. If a string, choose from
            'tiny', 'small', 'base' and 'large'. If a dict, it should
            have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of layers.
- **num_heads** (int): The number of heads in attention
modules of the initial layer.
- **downscale_indices** (List[int]): The layer indices to downscale
the feature map.
Defaults to 'base'.
        img_size (int): The expected input image size. Defaults to 224.
in_channels (int): The num of input channels. Defaults to 3.
out_scales (int | Sequence[int]): The output scale indices.
They should not exceed the length of ``downscale_indices``.
Defaults to -1, which means the last scale.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embedding vector resize. Defaults to "bicubic".
        pool_kernel (tuple): Kernel size of the qkv pooling layers.
            Defaults to (3, 3).
dim_mul (int): The magnification for ``embed_dims`` in the downscale
layers. Defaults to 2.
head_mul (int): The magnification for ``num_heads`` in the downscale
layers. Defaults to 2.
adaptive_kv_stride (int): The stride size for kv pooling in the initial
layer. Defaults to 4.
rel_pos_spatial (bool): Whether to enable the spatial relative position
embedding. Defaults to True.
residual_pooling (bool): Whether to enable the residual connection
after attention pooling. Defaults to True.
dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
attention layers. If False, multiply it in MLP layers.
Defaults to True.
rel_pos_zero_init (bool): If True, zero initialize relative
positional parameters. Defaults to False.
mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
Defaults to 4.0.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN', eps=1e-6)``.
patch_cfg (dict): Config dict for the patch embedding layer.
Defaults to ``dict(kernel_size=7, stride=4, padding=3)``.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> import torch
>>> from mmcls.models import build_backbone
>>>
>>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
>>> model = build_backbone(cfg)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outputs = model(inputs)
        >>> for i, output in enumerate(outputs):
        ...     print(f'scale{i}: {output.shape}')
scale0: torch.Size([1, 96, 56, 56])
scale1: torch.Size([1, 192, 28, 28])
scale2: torch.Size([1, 384, 14, 14])
scale3: torch.Size([1, 768, 7, 7])
"""
arch_zoo = {
'tiny': {
'embed_dims': 96,
'num_layers': 10,
'num_heads': 1,
'downscale_indices': [1, 3, 8]
},
'small': {
'embed_dims': 96,
'num_layers': 16,
'num_heads': 1,
'downscale_indices': [1, 3, 14]
},
'base': {
'embed_dims': 96,
'num_layers': 24,
'num_heads': 1,
'downscale_indices': [2, 5, 21]
},
'large': {
'embed_dims': 144,
'num_layers': 48,
'num_heads': 2,
'downscale_indices': [2, 8, 44]
},
}
num_extra_tokens = 0
def __init__(self,
arch='base',
img_size=224,
in_channels=3,
out_scales=-1,
drop_path_rate=0.,
use_abs_pos_embed=False,
interpolate_mode='bicubic',
pool_kernel=(3, 3),
dim_mul=2,
head_mul=2,
adaptive_kv_stride=4,
rel_pos_spatial=True,
residual_pooling=True,
dim_mul_in_attention=True,
rel_pos_zero_init=False,
mlp_ratio=4.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
patch_cfg=dict(kernel_size=7, stride=4, padding=3),
init_cfg=None):
super().__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'downscale_indices'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.num_heads = self.arch_settings['num_heads']
self.downscale_indices = self.arch_settings['downscale_indices']
self.num_scales = len(self.downscale_indices) + 1
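        # Map the last layer index of each scale to the scale index, e.g.
        # with downscale_indices=[1, 3, 8] and num_layers=10, layers 0, 2, 7
        # and 9 end scales 0, 1, 2 and 3 respectively.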
self.stage_indices = {
index - 1: i
for i, index in enumerate(self.downscale_indices)
}
self.stage_indices[self.num_layers - 1] = self.num_scales - 1
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
if isinstance(out_scales, int):
out_scales = [out_scales]
        assert isinstance(out_scales, Sequence), \
            f'"out_scales" must be a sequence or int, ' \
            f'got {type(out_scales)} instead.'
        out_scales = list(out_scales)
        for i, index in enumerate(out_scales):
            if index < 0:
                out_scales[i] = self.num_scales + index
            assert 0 <= out_scales[i] < self.num_scales, \
                f'Invalid out_scales {index}'
        self.out_scales = sorted(out_scales)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
# Set absolute position embedding
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.blocks = ModuleList()
out_dims_list = [self.embed_dims]
num_heads = self.num_heads
stride_kv = adaptive_kv_stride
input_size = self.patch_resolution
for i in range(self.num_layers):
if i in self.downscale_indices:
num_heads *= head_mul
stride_q = 2
stride_kv = max(stride_kv // 2, 1)
else:
stride_q = 1
# Set output embed_dims
if dim_mul_in_attention and i in self.downscale_indices:
# multiply embed_dims in downscale layers.
out_dims = out_dims_list[-1] * dim_mul
elif not dim_mul_in_attention and i + 1 in self.downscale_indices:
# multiply embed_dims before downscale layers.
out_dims = out_dims_list[-1] * dim_mul
else:
out_dims = out_dims_list[-1]
attention_block = MultiScaleBlock(
in_dims=out_dims_list[-1],
out_dims=out_dims,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_cfg=norm_cfg,
qkv_pool_kernel=pool_kernel,
stride_q=stride_q,
stride_kv=stride_kv,
rel_pos_spatial=rel_pos_spatial,
residual_pooling=residual_pooling,
dim_mul_in_attention=dim_mul_in_attention,
input_size=input_size,
rel_pos_zero_init=rel_pos_zero_init)
self.blocks.append(attention_block)
input_size = attention_block.init_out_size
out_dims_list.append(out_dims)
if i in self.stage_indices:
stage_index = self.stage_indices[i]
if stage_index in self.out_scales:
norm_layer = build_norm_layer(norm_cfg, out_dims)[1]
self.add_module(f'norm{stage_index}', norm_layer)
def init_weights(self):
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init when using a pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x):
"""Forward the MViT."""
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
outs = []
for i, block in enumerate(self.blocks):
x, patch_resolution = block(x, patch_resolution)
if i in self.stage_indices:
stage_index = self.stage_indices[i]
if stage_index in self.out_scales:
B, _, C = x.shape
x = getattr(self, f'norm{stage_index}')(x)
out = x.transpose(1, 2).reshape(B, C, *patch_resolution)
outs.append(out.contiguous())
return tuple(outs)