KyanChen's picture
init
f549064
raw
history blame contribute delete
No virus
26.5 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
from torch.nn import functional as F
from ..utils import LeAttention
from .base_backbone import BaseBackbone
class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    The two sub-modules are registered as ``conv2d`` (a bias-free
    ``nn.Conv2d``) and ``bn2d`` (an ``nn.BatchNorm2d``) and run in that
    order.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization: weight = bn_weight_init, bias = 0
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm2d into the convolution.

        Builds a single biased ``nn.Conv2d`` whose output matches this
        module's output when the BatchNorm uses its running statistics
        (i.e. in eval mode).

        Returns:
            nn.Conv2d: The fused convolution, on the same device as the
            original convolution weight.
        """
        conv2d, bn2d = self.conv2d, self.bn2d
        # Per-output-channel factor gamma / sqrt(running_var + eps).
        scale = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * scale[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * scale
        # ``w.size(1)`` is in_channels // groups for a grouped conv, so
        # multiply back by ``groups`` to recover the real input channels.
        fused = nn.Conv2d(
            in_channels=w.size(1) * conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=conv2d.stride,
            padding=conv2d.padding,
            dilation=conv2d.dilation,
            groups=conv2d.groups).to(w.device)
        fused.weight.data.copy_(w)
        fused.bias.data.copy_(b)
        return fused
class PatchEmbed(BaseModule):
    """Patch Embedding for Vision Transformer.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Unlike `mmcv.cnn.bricks.transformer.PatchEmbed`, this module is built
    from two stride-2 Conv2d + BatchNorm2d pairs with an activation in
    between, and its output keeps the (N, C, H, W) layout.

    Args:
        in_channels (int): The number of input channels.
        embed_dim (int): The embedding dimension. The intermediate
            convolution uses ``embed_dim // 2`` channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
            ``patches_resolution`` is this divided (floor) by 4.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 embed_dim,
                 resolution,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        # Two stride-2 convolutions -> overall spatial reduction of 4.
        patch_h = resolution[0] // 4
        patch_w = resolution[1] // 4
        self.patches_resolution = (patch_h, patch_w)
        self.num_patches = patch_h * patch_w
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        mid_dim = embed_dim // 2
        stem_conv = ConvBN2d(
            in_channels, mid_dim, kernel_size=3, stride=2, padding=1)
        stem_act = build_activation_layer(act_cfg)
        proj_conv = ConvBN2d(
            mid_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.seq = nn.Sequential(stem_conv, stem_act, proj_conv)

    def forward(self, x):
        """Embed an image tensor into a downsampled feature map."""
        embedded = self.seq(x)
        return embedded
class PatchMerging(nn.Module):
    """Patch Merging for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Unlike `mmcls.models.utils.PatchMerging`, this module uses
    Conv2d + BatchNorm2d: a 1x1 projection, a stride-2 3x3 depthwise
    convolution and another 1x1 projection. The result is returned as a
    token sequence of shape (B, L, out_channels).

    Args:
        resolution (Tuple[int, int]): The resolution (H, W) of the input
            feature. Used to reshape 3-D token inputs.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 resolution,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.img_size = resolution
        self.act = build_activation_layer(act_cfg)
        self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
        # Depthwise (groups == channels) stride-2 conv does the downsampling.
        self.conv2 = ConvBN2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels)
        self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
        # NOTE(review): a k=3, s=2, p=1 conv produces ceil(H / 2) rows, so
        # this floor-based value only matches for even H and W.
        half_h = resolution[0] // 2
        half_w = resolution[1] // 2
        self.out_resolution = (half_h, half_w)

    def forward(self, x):
        """Downsample ``x``; 3-D token inputs are first reshaped to a map."""
        if x.dim() == 3:
            # (B, H*W, C) tokens -> (B, C, H, W) feature map.
            height, width = self.img_size
            batch = x.shape[0]
            x = x.view(batch, height, width, -1).permute(0, 3, 1, 2)
        for conv in (self.conv1, self.conv2):
            x = self.act(conv(x))
        merged = self.conv3(x)
        # (B, C, H', W') -> (B, H'*W', C)
        return merged.flatten(2).transpose(1, 2)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.

    Layout: 1x1 expand -> act -> 3x3 depthwise -> act -> 1x1 project
    -> drop path -> add identity -> act.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels. The identity
            shortcut is added unprojected, so this must equal
            ``in_channels``.
        expand_ratio (float): The expand ratio of the hidden channels.
        drop_path (float): The drop path rate of the residual branch.
            ``0`` disables it.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 drop_path,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        mid_channels = int(in_channels * expand_ratio)

        # pointwise expansion
        self.conv1 = ConvBN2d(in_channels, mid_channels, kernel_size=1)
        self.act = build_activation_layer(act_cfg)
        # depthwise 3x3 (groups == channels)
        self.conv2 = ConvBN2d(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=mid_channels)
        # pointwise projection; its BN weight starts at 0 so the residual
        # branch initially outputs zeros
        self.conv3 = ConvBN2d(
            mid_channels, out_channels, kernel_size=1, bn_weight_init=0.0)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the inverted bottleneck with an identity shortcut."""
        identity = x
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = self.drop_path(self.conv3(out))
        out = out + identity
        return self.act(out)
class ConvStage(BaseModule):
    """Convolution Stage for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    A stack of ``depth`` :class:`MBConvBlock` layers that all keep
    ``in_channels`` channels, optionally followed by a downsample layer.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        act_cfg (dict): The activation config. It is used by every MBConv
            block and by the downsample layer.
        drop_path (float | list[float]): The drop path rate of the blocks.
            A list gives one rate per block. Default: 0.
        downsample (None | type): The downsample layer class. It is called
            with ``resolution``, ``in_channels``, ``out_channels`` and
            ``act_cfg`` keyword arguments and must expose
            ``out_resolution``. Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        out_channels (int, optional): The number of output channels of the
            downsample layer. Default: None.
        conv_expand_ratio (float): The expand ratio of the hidden channels.
            Default: 4.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 act_cfg,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 out_channels=None,
                 conv_expand_ratio=4.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        # build blocks; forward act_cfg so the configured activation is used
        # inside the MBConv blocks too, not only in the downsample layer
        self.blocks = ModuleList([
            MBConvBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=conv_expand_ratio,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                act_cfg=act_cfg)
            for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        """Run all MBConv blocks, then the optional downsample layer."""
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class MLP(BaseModule):
    """Pre-norm MLP module for TinyViT.

    Computes ``drop(fc2(drop(act(fc1(norm(x))))))`` where ``norm`` is a
    LayerNorm over the last dimension.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden channels.
            Falls back to ``in_channels`` when not given. Default: None.
        out_channels (int, optional): The number of output channels.
            Falls back to ``in_channels`` when not given. Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        drop (float): Probability of an element to be zeroed, applied after
            the activation and after the output projection. Default: 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if not out_channels:
            out_channels = in_channels
        if not hidden_channels:
            hidden_channels = in_channels

        self.norm = nn.LayerNorm(in_channels)
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Normalize, expand, activate and project ``x``."""
        hidden = self.drop(self.act(self.fc1(self.norm(x))))
        return self.drop(self.fc2(hidden))
class TinyViTBlock(BaseModule):
    """TinyViT Block.

    Applies windowed ``LeAttention`` to a token sequence (residual), then a
    depthwise ``ConvBN2d`` ("local conv", not residual) on the tokens
    reshaped to a (B, C, H, W) map, then an ``MLP`` (residual). Both
    residual branches share the same ``drop_path`` module.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution (H, W) of the input
            feature. ``forward`` expects exactly ``H * W`` tokens.
        num_heads (int): The number of heads in the multi-head attention.
            Must divide ``in_channels``.
        window_size (int): The side length of the square attention window.
            Must be > 0. Default: 7.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Dropout probability used inside the MLP.
            Default: 0.
        drop_path (float): The drop path rate of the block.
            Default: 0.
        local_conv_size (int): The kernel size of the depthwise local
            convolution. Default: 3.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 local_conv_size=3,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.img_size = resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert in_channels % num_heads == 0, \
            'dim must be divisible by num_heads'
        head_dim = in_channels // num_heads

        # Attention always operates on one window_size x window_size window.
        window_resolution = (window_size, window_size)
        self.attn = LeAttention(
            in_channels,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution)

        mlp_hidden_dim = int(in_channels * mlp_ratio)
        self.mlp = MLP(
            in_channels=in_channels,
            hidden_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        # Depthwise (groups == channels) conv; "same" padding for odd sizes.
        self.local_conv = ConvBN2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=local_conv_size,
            stride=1,
            padding=local_conv_size // 2,
            groups=in_channels)

    def forward(self, x):
        """Forward the block.

        Args:
            x (torch.Tensor): Tokens of shape (B, L, C) with L == H * W,
                where (H, W) is ``self.img_size``.

        Returns:
            torch.Tensor: Tokens of shape (B, L, C).
        """
        H, W = self.img_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            # The whole map is exactly one window: attend over it directly.
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            # Pad bottom (H) and right (W) up to multiples of window_size.
            pad_b = (self.window_size -
                     H % self.window_size) % self.window_size
            pad_r = (self.window_size -
                     W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0

            if padding:
                # F.pad pads from the last dim backwards: C by (0, 0),
                # W by (0, pad_r), H by (0, pad_b).
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            # window partition:
            # (B, pH, pW, C) -> (B * nH * nW, window_size ** 2, C)
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(
                           B * nH * nW, self.window_size * self.window_size, C)
            x = self.attn(x)
            # window reverse: back to (B, pH, pW, C)
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                # Drop the padded rows/columns before flattening.
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)

        # Local conv runs on the (B, C, H, W) map, then back to tokens.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        x = x + self.drop_path(self.mlp(x))
        return x
class BasicStage(BaseModule):
    """Basic (attention) Stage for TinyViT.

    A stack of ``depth`` :class:`TinyViTBlock` layers that keep
    ``in_channels`` channels, optionally followed by a downsample layer.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float | list[float]): The drop path rate of the blocks.
            A list gives one rate per block. Default: 0.
        downsample (None | type): The downsample layer class. It is called
            with ``resolution``, ``in_channels``, ``out_channels`` and
            ``act_cfg`` keyword arguments and must expose
            ``out_resolution``. Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The kernel size of each block's local
            convolution. Default: 3.
        out_channels (int, optional): The number of output channels of the
            downsample layer. Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 local_conv_size=3,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        per_block_dp_is_list = isinstance(drop_path, list)
        stage_blocks = []
        for idx in range(depth):
            rate = drop_path[idx] if per_block_dp_is_list else drop_path
            stage_blocks.append(
                TinyViTBlock(
                    in_channels=in_channels,
                    resolution=resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    local_conv_size=local_conv_size,
                    act_cfg=act_cfg,
                    drop_path=rate))
        self.blocks = ModuleList(stage_blocks)

        if downsample is None:
            self.downsample = None
            self.resolution = resolution
        else:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution

    def forward(self, x):
        """Run every TinyViT block, then the optional downsample layer."""
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint \
                else blk(x)
        if self.downsample is None:
            return x
        return self.downsample(x)
@MODELS.register_module()
class TinyViT(BaseBackbone):
    """TinyViT.

    A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation
    for Small Vision Transformers<https://arxiv.org/abs/2201.03545v1>`_

    Inspiration from
    https://github.com/microsoft/Cream/blob/main/TinyViT

    Stage 0 is a convolutional :class:`ConvStage`; the remaining stages are
    attention-based :class:`BasicStage` s. Every stage except the last ends
    with a :class:`PatchMerging` downsample.

    Args:
        arch (str | dict): The architecture of TinyViT. Either a key of
            ``arch_settings`` ('5m', '11m', '21m') or a dict with
            "channels", "num_heads" and "depths" keys (one entry per
            stage). Default: '5m'.
        img_size (tuple | int): The resolution of the input image. An int
            is treated as a square image. Default: (224, 224)
        window_size (list): The attention window size of each stage
            (the entry for stage 0 is unused). Default: [7, 7, 14, 7]
        in_channels (int): The number of input channels.
            Default: 3.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): The maximum drop path rate; per-block rates
            increase linearly from 0 to this value. Default: 0.1.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        mbconv_expand_ratio (float): The expand ratio of the mbconv.
            Default: 4.0
        local_conv_size (int): The size of the local conv.
            Default: 3.
        layer_lr_decay (float): The layer lr decay. Only stored; it is not
            applied yet. Default: 1.0
        out_indices (int | Sequence[int]): Output from which stages.
            Negative values count from the last stage. Default: -1
        frozen_stages (int): Number of leading stages to freeze (eval mode,
            no gradients). Default: 0
        gap_before_final_norm (bool): Whether to apply global average
            pooling over tokens before the output norm. Default: True.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """
    arch_settings = {
        '5m': {
            'channels': [64, 128, 160, 320],
            'num_heads': [2, 4, 5, 10],
            'depths': [2, 2, 6, 2],
        },
        '11m': {
            'channels': [64, 128, 256, 448],
            'num_heads': [2, 4, 8, 14],
            'depths': [2, 2, 6, 2],
        },
        '21m': {
            'channels': [96, 192, 384, 576],
            'num_heads': [3, 6, 12, 18],
            'depths': [2, 2, 6, 2],
        },
    }

    def __init__(self,
                 arch='5m',
                 img_size=(224, 224),
                 window_size=[7, 7, 14, 7],
                 in_channels=3,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavaiable arch, please choose from ' \
                f'({set(self.arch_settings)} or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, 'The arch dict must have ' \
                '"channels", "num_heads", "depths" ' \
                f'keys, but got {arch.keys()}'

        # An int image size means a square input; PatchEmbed indexes it.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)

        self.channels = arch['channels']
        self.num_heads = arch['num_heads']
        # NOTE: the misspelled attribute name is kept for compatibility.
        self.widow_sizes = window_size
        self.img_size = img_size
        self.depths = arch['depths']
        self.num_stages = len(self.channels)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Copy into a new list so a caller's list is not mutated and tuples
        # work; negative indices count back from the real number of stages.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm
        self.layer_lr_decay = layer_lr_decay

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dim=self.channels[0],
            resolution=self.img_size,
            act_cfg=act_cfg)
        patches_resolution = self.patch_embed.patches_resolution

        # stochastic depth decay rule: rates rise linearly over all blocks
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        # build stages
        self.stages = ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channel = self.channels[i]
            curr_resolution = (patches_resolution[0] // (2**i),
                               patches_resolution[1] // (2**i))
            drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
            downsample = PatchMerging if (i < self.num_stages - 1) else None
            # channels after this stage (downsample maps to the next width)
            out_channels = self.channels[min(i + 1, self.num_stages - 1)]
            if i >= 1:
                stage = BasicStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    num_heads=self.num_heads[i],
                    window_size=self.widow_sizes[i],
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    local_conv_size=local_conv_size,
                    out_channels=out_channels,
                    act_cfg=act_cfg)
            else:
                stage = ConvStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    act_cfg=act_cfg,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    out_channels=out_channels,
                    conv_expand_ratio=mbconv_expand_ratio)
            self.stages.append(stage)

            # add output norm
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def set_layer_lr_decay(self, layer_lr_decay):
        # TODO: add layer_lr_decay
        pass

    def forward(self, x):
        """Return a tuple with one normalized output per ``out_indices``.

        Each output is (B, C) when ``gap_before_final_norm`` is True,
        otherwise a (B, C, H, W) feature map.
        """
        outs = []
        x = self.patch_embed(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    # stage outputs are (B, L, C) tokens; average over L
                    gap = x.mean(1)
                    outs.append(norm_layer(gap))
                else:
                    out = norm_layer(x)
                    # convert the (B,L,C) format into (B,C,H,W) format
                    # which would be better for the downstream tasks.
                    B, L, C = out.shape
                    out = out.view(B, *stage.resolution, C)
                    outs.append(out.permute(0, 3, 1, 2))

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode, no grads."""
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train mode, keep frozen stages frozen, and return ``self``.

        Returning ``self`` matches ``nn.Module.train`` so chained calls like
        ``model = model.train()`` keep working.
        """
        super(TinyViT, self).train(mode)
        self._freeze_stages()
        return self