# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmcls.registry import MODELS
from ..utils import to_2tuple
from .base_backbone import BaseBackbone


class TransformerBlock(BaseModule):
    """Implement a transformer block in TnTLayer.

    The block normalizes its input before attention and before the FFN. The
    un-normalized input is handed to ``MultiheadAttention`` / ``FFN`` as
    ``identity`` for the residual connection.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim).
            (batch, n, embed_dim) is common case in CV. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 ffn_ratio=4,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 init_cfg=None):
        super(TransformerBlock, self).__init__(init_cfg=init_cfg)

        self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first)

        self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * ffn_ratio,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if not qkv_bias:
            # Strip the input (qkv) projection bias from the wrapped
            # ``nn.MultiheadAttention`` to get a bias-free qkv projection.
            self.attn.attn.in_proj_bias = None

    def forward(self, x):
        x = self.attn(self.norm_attn(x), identity=x)
        x = self.ffn(self.norm_ffn(x), identity=x)
        return x


class TnTLayer(BaseModule):
    """Implement one encoder layer in Transformer in Transformer.

    The inner block runs over the pixel tokens of every patch. Its result is
    flattened per patch, projected to the outer dimension and added to the
    patch tokens (the class token at index 0 is left untouched). The outer
    block then runs over the patch tokens.

    Args:
        num_pixel (int): The pixel number in target patch transformed with
            a linear projection in inner transformer.
        embed_dims_inner (int): Feature dimension in inner transformer block.
        embed_dims_outer (int): Feature dimension in outer transformer block.
        num_heads_inner (int): Parallel attention heads in inner transformer.
        num_heads_outer (int): Parallel attention heads in outer transformer.
        inner_block_cfg (dict): Extra config of inner transformer block.
            Defaults to empty dict.
        outer_block_cfg (dict): Extra config of outer transformer block.
            Defaults to empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_pixel,
                 embed_dims_inner,
                 embed_dims_outer,
                 num_heads_inner,
                 num_heads_outer,
                 inner_block_cfg=dict(),
                 outer_block_cfg=dict(),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TnTLayer, self).__init__(init_cfg=init_cfg)

        self.inner_block = TransformerBlock(
            embed_dims=embed_dims_inner,
            num_heads=num_heads_inner,
            **inner_block_cfg)

        self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
        self.projection = nn.Linear(
            embed_dims_inner * num_pixel, embed_dims_outer, bias=True)

        self.outer_block = TransformerBlock(
            embed_dims=embed_dims_outer,
            num_heads=num_heads_outer,
            **outer_block_cfg)

    def forward(self, pixel_embed, patch_embed):
        pixel_embed = self.inner_block(pixel_embed)

        B, N = patch_embed.shape[:2]
        # Flatten the pixel tokens of each of the N - 1 patches into a single
        # vector and project it to the outer dimension.
        pixel_to_patch = self.projection(
            self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
        # Build a new tensor instead of assigning into ``patch_embed[:, 1:]``
        # so the caller's tensor is not mutated in place. Index 0 (the class
        # token) is carried over unchanged.
        patch_embed = torch.cat(
            [patch_embed[:, :1], patch_embed[:, 1:] + pixel_to_patch], dim=1)

        patch_embed = self.outer_block(patch_embed)
        return pixel_embed, patch_embed


class PixelEmbed(BaseModule):
    """Image to Pixel Embedding.

    Args:
        img_size (int | tuple): The size of input image.
        patch_size (int): The size of one patch.
        in_channels (int): The num of input channels.
        embed_dims_inner (int): The num of channels of the target patch
            transformed with a linear projection in inner transformer.
        stride (int): The stride of the conv2d layer. We use a conv2d layer
            and an unfold layer to implement image to pixel embedding.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims_inner=48,
                 stride=4,
                 init_cfg=None):
        super(PixelEmbed, self).__init__(init_cfg=init_cfg)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # patches_resolution property necessary for resizing
        # positional embedding
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        num_patches = patches_resolution[0] * patches_resolution[1]

        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dims_inner = embed_dims_inner

        # After the strided conv, each patch covers ceil(patch / stride)
        # feature-map cells per side; this is the unfold window.
        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(
            in_channels,
            self.embed_dims_inner,
            kernel_size=7,
            padding=3,
            stride=stride)
        self.unfold = nn.Unfold(
            kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x, pixel_pos):
        """Embed an image batch into per-patch pixel tokens.

        Args:
            x (Tensor): Images of shape (B, C, H, W). H and W must equal
                ``img_size``.
            pixel_pos (Tensor): Pixel position embedding, added after the
                unfolded patches are reshaped to (B * num_patches,
                embed_dims_inner, *new_patch_size).

        Returns:
            Tensor: Shape (B * num_patches, num_pixel, embed_dims_inner).
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model " \
            f'({self.img_size[0]}*{self.img_size[1]}).'
        x = self.proj(x)
        x = self.unfold(x)
        x = x.transpose(1, 2).reshape(B * self.num_patches,
                                      self.embed_dims_inner,
                                      self.new_patch_size[0],
                                      self.new_patch_size[1])
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.embed_dims_inner,
                      -1).transpose(1, 2)
        return x


@MODELS.register_module()
class TNT(BaseBackbone):
    """Transformer in Transformer.

    A PyTorch implement of: `Transformer in Transformer
    <https://arxiv.org/abs/2103.00112>`_

    Inspiration from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py

    Args:
        arch (str | dict): Transformer in Transformer architecture. Either a
            key of ``arch_zoo`` ('s', 'small', 'b', 'base') or a dict with
            exactly the keys ``embed_dims_outer``, ``embed_dims_inner``,
            ``num_layers``, ``num_heads_inner`` and ``num_heads_outer``.
            Default: 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        in_channels (int): Number of input channels. Defaults to 3.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        first_stride (int): The stride of the conv2d layer. We use a conv2d
            layer and an unfold layer to implement image to pixel embedding.
            Default: 4.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        init_cfg (dict, optional): Initialization config dict.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims_outer': 384,
                'embed_dims_inner': 24,
                'num_layers': 12,
                'num_heads_outer': 6,
                'num_heads_inner': 4
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims_outer': 640,
                'embed_dims_inner': 40,
                'num_layers': 12,
                'num_heads_outer': 10,
                'num_heads_inner': 4
            })
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 ffn_ratio=4,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 first_stride=4,
                 num_fcs=2,
                 init_cfg=[
                     dict(type='TruncNormal', layer='Linear', std=.02),
                     dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
                 ]):
        super(TNT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims_outer', 'embed_dims_inner', 'num_layers',
                'num_heads_inner', 'num_heads_outer'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims_inner = self.arch_settings['embed_dims_inner']
        self.embed_dims_outer = self.arch_settings['embed_dims_outer']
        # embed_dims for consistency with other models
        self.embed_dims = self.embed_dims_outer
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads_inner = self.arch_settings['num_heads_inner']
        self.num_heads_outer = self.arch_settings['num_heads_outer']

        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims_inner=self.embed_dims_inner,
            stride=first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        # Stem projection: flattened pixel tokens of a patch -> patch token.
        self.norm1_proj = build_norm_layer(norm_cfg, num_pixel *
                                           self.embed_dims_inner)[1]
        self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
                                    self.embed_dims_outer)
        self.norm2_proj = build_norm_layer(norm_cfg,
                                           self.embed_dims_outer)[1]

        self.cls_token = nn.Parameter(torch.zeros(1, 1,
                                                  self.embed_dims_outer))
        self.patch_pos = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dims_outer))
        self.pixel_pos = nn.Parameter(
            torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
                        new_patch_size[1]))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = torch.linspace(0, drop_path_rate, self.num_layers).tolist()
        self.layers = ModuleList()
        for layer_drop_path in dpr:
            block_cfg = dict(
                ffn_ratio=ffn_ratio,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=layer_drop_path,
                num_fcs=num_fcs,
                qkv_bias=qkv_bias,
                # Previously omitted, so a user-supplied act_cfg was ignored.
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                batch_first=True)
            self.layers.append(
                TnTLayer(
                    num_pixel=num_pixel,
                    embed_dims_inner=self.embed_dims_inner,
                    embed_dims_outer=self.embed_dims_outer,
                    num_heads_inner=self.num_heads_inner,
                    num_heads_outer=self.num_heads_outer,
                    inner_block_cfg=block_cfg,
                    outer_block_cfg=block_cfg,
                    norm_cfg=norm_cfg))

        self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.patch_pos, std=.02)
        trunc_normal_(self.pixel_pos, std=.02)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (Tensor): Images of shape (B, C, H, W).

        Returns:
            tuple[Tensor]: A one-element tuple holding the normalized class
            token, shape (B, embed_dims_outer).
        """
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(
            self.projection(
                self.norm1_proj(pixel_embed.reshape(B, self.num_patches,
                                                    -1))))
        patch_embed = torch.cat(
            (self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.drop_after_pos(patch_embed)

        for layer in self.layers:
            pixel_embed, patch_embed = layer(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return (patch_embed[:, 0], )