# Copyright (c) OpenMMLab. All rights reserved. from typing import Sequence, Tuple import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer from mmengine.model import BaseModule, ModuleList, Sequential from mmengine.registry import MODELS from torch.nn import functional as F from ..utils import LeAttention from .base_backbone import BaseBackbone class ConvBN2d(Sequential): """An implementation of Conv2d + BatchNorm2d with support of fusion. Modified from https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. kernel_size (int): The size of the convolution kernel. Default: 1. stride (int): The stride of the convolution. Default: 1. padding (int): The padding of the convolution. Default: 0. dilation (int): The dilation of the convolution. Default: 1. groups (int): The number of groups in the convolution. Default: 1. bn_weight_init (float): The initial value of the weight of the nn.BatchNorm2d layer. Default: 1.0. init_cfg (dict): The initialization config of the module. Default: None. """ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1.0, init_cfg=None): super().__init__(init_cfg=init_cfg) self.add_module( 'conv2d', nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)) bn2d = nn.BatchNorm2d(num_features=out_channels) # bn initialization torch.nn.init.constant_(bn2d.weight, bn_weight_init) torch.nn.init.constant_(bn2d.bias, 0) self.add_module('bn2d', bn2d) @torch.no_grad() def fuse(self): conv2d, bn2d = self._modules.values() w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5 w = conv2d.weight * w[:, None, None, None] b = bn2d.bias - bn2d.running_mean * bn2d.weight / \ (bn2d.running_var + bn2d.eps)**0.5 m = nn.Conv2d( in_channels=w.size(1) * self.c.groups, out_channels=w.size(0), kernel_size=w.shape[2:], stride=self.conv2d.stride, padding=self.conv2d.padding, dilation=self.conv2d.dilation, groups=self.conv2d.groups) m.weight.data.copy_(w) m.bias.data.copy_(b) return m class PatchEmbed(BaseModule): """Patch Embedding for Vision Transformer. Adapted from https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module use Conv2d and BatchNorm2d to implement PatchEmbedding, and output shape is (N, C, H, W). Args: in_channels (int): The number of input channels. embed_dim (int): The embedding dimension. resolution (Tuple[int, int]): The resolution of the input feature. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). """ def __init__(self, in_channels, embed_dim, resolution, act_cfg=dict(type='GELU')): super().__init__() img_size: Tuple[int, int] = resolution self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) self.num_patches = self.patches_resolution[0] * \ self.patches_resolution[1] self.in_channels = in_channels self.embed_dim = embed_dim self.seq = nn.Sequential( ConvBN2d( in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1), build_activation_layer(act_cfg), ConvBN2d( embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), ) def forward(self, x): return self.seq(x) class PatchMerging(nn.Module): """Patch Merging for TinyViT. Adapted from https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py Different from `mmcls.models.utils.PatchMerging`, this module use Conv2d and BatchNorm2d to implement PatchMerging. Args: in_channels (int): The number of input channels. resolution (Tuple[int, int]): The resolution of the input feature. out_channels (int): The number of output channels. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). """ def __init__(self, resolution, in_channels, out_channels, act_cfg=dict(type='GELU')): super().__init__() self.img_size = resolution self.act = build_activation_layer(act_cfg) self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1) self.conv2 = ConvBN2d( out_channels, out_channels, kernel_size=3, stride=2, padding=1, groups=out_channels) self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1) self.out_resolution = (resolution[0] // 2, resolution[1] // 2) def forward(self, x): if len(x.shape) == 3: H, W = self.img_size B = x.shape[0] x = x.view(B, H, W, -1).permute(0, 3, 1, 2) x = self.conv1(x) x = self.act(x) x = self.conv2(x) x = self.act(x) x = self.conv3(x) x = x.flatten(2).transpose(1, 2) return x class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. expand_ratio (int): The expand ratio of the hidden channels. drop_rate (float): The drop rate of the block. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). """ def __init__(self, in_channels, out_channels, expand_ratio, drop_path, act_cfg=dict(type='GELU')): super().__init__() self.in_channels = in_channels hidden_channels = int(in_channels * expand_ratio) # linear self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1) self.act = build_activation_layer(act_cfg) # depthwise conv self.conv2 = ConvBN2d( in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1, groups=hidden_channels) # linear self.conv3 = ConvBN2d( hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0) self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): shortcut = x x = self.conv1(x) x = self.act(x) x = self.conv2(x) x = self.act(x) x = self.conv3(x) x = self.drop_path(x) x += shortcut x = self.act(x) return x class ConvStage(BaseModule): """Convolution Stage for TinyViT. Adapted from https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py Args: in_channels (int): The number of input channels. resolution (Tuple[int, int]): The resolution of the input feature. depth (int): The number of blocks in the stage. act_cfg (dict): The activation config of the module. drop_path (float): The drop path of the block. downsample (None | nn.Module): The downsample operation. Default: None. use_checkpoint (bool): Whether to use checkpointing to save memory. out_channels (int): The number of output channels. conv_expand_ratio (int): The expand ratio of the hidden channels. Default: 4. init_cfg (dict | list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels, resolution, depth, act_cfg, drop_path=0., downsample=None, use_checkpoint=False, out_channels=None, conv_expand_ratio=4., init_cfg=None): super().__init__(init_cfg=init_cfg) self.use_checkpoint = use_checkpoint # build blocks self.blocks = ModuleList([ MBConvBlock( in_channels=in_channels, out_channels=in_channels, expand_ratio=conv_expand_ratio, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path) for i in range(depth) ]) # patch merging layer if downsample is not None: self.downsample = downsample( resolution=resolution, in_channels=in_channels, out_channels=out_channels, act_cfg=act_cfg) self.resolution = self.downsample.out_resolution else: self.downsample = None self.resolution = resolution def forward(self, x): for block in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(block, x) else: x = block(x) if self.downsample is not None: x = self.downsample(x) return x class MLP(BaseModule): """MLP module for TinyViT. Args: in_channels (int): The number of input channels. hidden_channels (int, optional): The number of hidden channels. Default: None. out_channels (int, optional): The number of output channels. Default: None. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). drop (float): Probability of an element to be zeroed. Default: 0. init_cfg (dict | list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels, hidden_channels=None, out_channels=None, act_cfg=dict(type='GELU'), drop=0., init_cfg=None): super().__init__(init_cfg=init_cfg) out_channels = out_channels or in_channels hidden_channels = hidden_channels or in_channels self.norm = nn.LayerNorm(in_channels) self.fc1 = nn.Linear(in_channels, hidden_channels) self.fc2 = nn.Linear(hidden_channels, out_channels) self.act = build_activation_layer(act_cfg) self.drop = nn.Dropout(drop) def forward(self, x): x = self.norm(x) x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class TinyViTBlock(BaseModule): """TinViT Block. Args: in_channels (int): The number of input channels. resolution (Tuple[int, int]): The resolution of the input feature. num_heads (int): The number of heads in the multi-head attention. window_size (int): The size of the window. Default: 7. mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. Default: 4. drop (float): Probability of an element to be zeroed. Default: 0. drop_path (float): The drop path of the block. Default: 0. local_conv_size (int): The size of the local convolution. Default: 3. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). """ def __init__(self, in_channels, resolution, num_heads, window_size=7, mlp_ratio=4., drop=0., drop_path=0., local_conv_size=3, act_cfg=dict(type='GELU')): super().__init__() self.in_channels = in_channels self.img_size = resolution self.num_heads = num_heads assert window_size > 0, 'window_size must be greater than 0' self.window_size = window_size self.mlp_ratio = mlp_ratio self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() assert in_channels % num_heads == 0, \ 'dim must be divisible by num_heads' head_dim = in_channels // num_heads window_resolution = (window_size, window_size) self.attn = LeAttention( in_channels, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) mlp_hidden_dim = int(in_channels * mlp_ratio) self.mlp = MLP( in_channels=in_channels, hidden_channels=mlp_hidden_dim, act_cfg=act_cfg, drop=drop) self.local_conv = ConvBN2d( in_channels=in_channels, out_channels=in_channels, kernel_size=local_conv_size, stride=1, padding=local_conv_size // 2, groups=in_channels) def forward(self, x): H, W = self.img_size B, L, C = x.shape assert L == H * W, 'input feature has wrong size' res_x = x if H == self.window_size and W == self.window_size: x = self.attn(x) else: x = x.view(B, H, W, C) pad_b = (self.window_size - H % self.window_size) % self.window_size pad_r = (self.window_size - W % self.window_size) % self.window_size padding = pad_b > 0 or pad_r > 0 if padding: x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) pH, pW = H + pad_b, W + pad_r nH = pH // self.window_size nW = pW // self.window_size # window partition x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( B * nH * nW, self.window_size * self.window_size, C) x = self.attn(x) # window reverse x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) if padding: x = x[:, :H, :W].contiguous() x = x.view(B, L, C) x = res_x + self.drop_path(x) x = x.transpose(1, 2).reshape(B, C, H, W) x = self.local_conv(x) x = x.view(B, C, L).transpose(1, 2) x = x + self.drop_path(self.mlp(x)) return x class BasicStage(BaseModule): """Basic Stage for TinyViT. Args: in_channels (int): The number of input channels. resolution (Tuple[int, int]): The resolution of the input feature. depth (int): The number of blocks in the stage. num_heads (int): The number of heads in the multi-head attention. window_size (int): The size of the window. mlp_ratio (float): The ratio of mlp hidden dim to embedding dim. Default: 4. drop (float): Probability of an element to be zeroed. Default: 0. drop_path (float): The drop path of the block. Default: 0. downsample (None | nn.Module): The downsample operation. Default: None. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). init_cfg (dict | list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels, resolution, depth, num_heads, window_size, mlp_ratio=4., drop=0., drop_path=0., downsample=None, use_checkpoint=False, local_conv_size=3, out_channels=None, act_cfg=dict(type='GELU'), init_cfg=None): super().__init__(init_cfg=init_cfg) self.use_checkpoint = use_checkpoint # build blocks self.blocks = ModuleList([ TinyViTBlock( in_channels=in_channels, resolution=resolution, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, drop=drop, local_conv_size=local_conv_size, act_cfg=act_cfg, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path) for i in range(depth) ]) # build patch merging layer if downsample is not None: self.downsample = downsample( resolution=resolution, in_channels=in_channels, out_channels=out_channels, act_cfg=act_cfg) self.resolution = self.downsample.out_resolution else: self.downsample = None self.resolution = resolution def forward(self, x): for block in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(block, x) else: x = block(x) if self.downsample is not None: x = self.downsample(x) return x @MODELS.register_module() class TinyViT(BaseBackbone): """TinyViT. A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation for Small Vision Transformers`_ Inspiration from https://github.com/microsoft/Cream/blob/main/TinyViT Args: arch (str | dict): The architecture of TinyViT. Default: '5m'. img_size (tuple | int): The resolution of the input image. Default: (224, 224) window_size (list): The size of the window. Default: [7, 7, 14, 7] in_channels (int): The number of input channels. Default: 3. depths (list[int]): The depth of each stage. Default: [2, 2, 6, 2]. mlp_ratio (list[int]): The ratio of mlp hidden dim to embedding dim. Default: 4. drop_rate (float): Probability of an element to be zeroed. Default: 0. drop_path_rate (float): The drop path of the block. Default: 0.1. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. mbconv_expand_ratio (int): The expand ratio of the mbconv. Default: 4.0 local_conv_size (int): The size of the local conv. Default: 3. layer_lr_decay (float): The layer lr decay. Default: 1.0 out_indices (int | list[int]): Output from which stages. Default: -1 frozen_stages (int | list[int]): Stages to be frozen (all param fixed). Default: -0 gap_before_final_nrom (bool): Whether to add a gap before the final norm. Default: True. act_cfg (dict): The activation config of the module. Default: dict(type='GELU'). norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN'). init_cfg (dict | list[dict], optional): Initialization config dict. Default: None. """ arch_settings = { '5m': { 'channels': [64, 128, 160, 320], 'num_heads': [2, 4, 5, 10], 'depths': [2, 2, 6, 2], }, '11m': { 'channels': [64, 128, 256, 448], 'num_heads': [2, 4, 8, 14], 'depths': [2, 2, 6, 2], }, '21m': { 'channels': [96, 192, 384, 576], 'num_heads': [3, 6, 12, 18], 'depths': [2, 2, 6, 2], }, } def __init__(self, arch='5m', img_size=(224, 224), window_size=[7, 7, 14, 7], in_channels=3, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1, use_checkpoint=False, mbconv_expand_ratio=4.0, local_conv_size=3, layer_lr_decay=1.0, out_indices=-1, frozen_stages=0, gap_before_final_norm=True, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), init_cfg=None): super().__init__(init_cfg=init_cfg) if isinstance(arch, str): assert arch in self.arch_settings, \ f'Unavaiable arch, please choose from ' \ f'({set(self.arch_settings)} or pass a dict.' arch = self.arch_settings[arch] elif isinstance(arch, dict): assert 'channels' in arch and 'num_heads' in arch and \ 'depths' in arch, 'The arch dict must have' \ f'"channels", "num_heads", "window_sizes" ' \ f'keys, but got {arch.keys()}' self.channels = arch['channels'] self.num_heads = arch['num_heads'] self.widow_sizes = window_size self.img_size = img_size self.depths = arch['depths'] self.num_stages = len(self.channels) if isinstance(out_indices, int): out_indices = [out_indices] assert isinstance(out_indices, Sequence), \ f'"out_indices" must by a sequence or int, ' \ f'get {type(out_indices)} instead.' for i, index in enumerate(out_indices): if index < 0: out_indices[i] = 4 + index assert out_indices[i] >= 0, f'Invalid out_indices {index}' self.out_indices = out_indices self.frozen_stages = frozen_stages self.gap_before_final_norm = gap_before_final_norm self.layer_lr_decay = layer_lr_decay self.patch_embed = PatchEmbed( in_channels=in_channels, embed_dim=self.channels[0], resolution=self.img_size, act_cfg=dict(type='GELU')) patches_resolution = self.patch_embed.patches_resolution # stochastic depth decay rule dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths)) ] # build stages self.stages = ModuleList() for i in range(self.num_stages): depth = self.depths[i] channel = self.channels[i] curr_resolution = (patches_resolution[0] // (2**i), patches_resolution[1] // (2**i)) drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])] downsample = PatchMerging if (i < self.num_stages - 1) else None out_channels = self.channels[min(i + 1, self.num_stages - 1)] if i >= 1: stage = BasicStage( in_channels=channel, resolution=curr_resolution, depth=depth, num_heads=self.num_heads[i], window_size=self.widow_sizes[i], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=drop_path, downsample=downsample, use_checkpoint=use_checkpoint, local_conv_size=local_conv_size, out_channels=out_channels, act_cfg=act_cfg) else: stage = ConvStage( in_channels=channel, resolution=curr_resolution, depth=depth, act_cfg=act_cfg, drop_path=drop_path, downsample=downsample, use_checkpoint=use_checkpoint, out_channels=out_channels, conv_expand_ratio=mbconv_expand_ratio) self.stages.append(stage) # add output norm if i in self.out_indices: norm_layer = build_norm_layer(norm_cfg, out_channels)[1] self.add_module(f'norm{i}', norm_layer) def set_layer_lr_decay(self, layer_lr_decay): # TODO: add layer_lr_decay pass def forward(self, x): outs = [] x = self.patch_embed(x) for i, stage in enumerate(self.stages): x = stage(x) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') if self.gap_before_final_norm: gap = x.mean(1) outs.append(norm_layer(gap)) else: out = norm_layer(x) # convert the (B,L,C) format into (B,C,H,W) format # which would be better for the downstream tasks. B, L, C = out.shape out = out.view(B, *stage.resolution, C) outs.append(out.permute(0, 3, 1, 2)) return tuple(outs) def _freeze_stages(self): for i in range(self.frozen_stages): stage = self.stages[i] stage.eval() for param in stage.parameters(): param.requires_grad = False def train(self, mode=True): super(TinyViT, self).train(mode) self._freeze_stages()