# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmcls.registry import MODELS
from .vision_transformer import VisionTransformer


@MODELS.register_module()
class DistilledVisionTransformer(VisionTransformer):
    """Distilled Vision Transformer.

    A PyTorch implementation of: `Training data-efficient image transformers
    & distillation through attention <https://arxiv.org/abs/2012.12877>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If using a
            string, choose from 'small', 'base', 'large', 'deit-tiny',
            'deit-small' and 'deit-base'. If using a dict, it should have
            the below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'deit-base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shapes, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        with_cls_token (bool): Whether to concatenate the class token with
            the image tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set to
            True, ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolation mode for resizing
            the position embedding vector. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            the encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
""" num_extra_tokens = 2 # cls_token, dist_token def __init__(self, arch='deit-base', *args, **kwargs): super(DistilledVisionTransformer, self).__init__( arch=arch, *args, **kwargs) self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) def forward(self, x): B = x.shape[0] x, patch_resolution = self.patch_embed(x) # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(B, -1, -1) dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) x = x + self.resize_pos_embed( self.pos_embed, self.patch_resolution, patch_resolution, mode=self.interpolate_mode, num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) if not self.with_cls_token: # Remove class token for transformer encoder input x = x[:, 2:] outs = [] for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1 and self.final_norm: x = self.norm1(x) if i in self.out_indices: B, _, C = x.shape if self.with_cls_token: patch_token = x[:, 2:].reshape(B, *patch_resolution, C) patch_token = patch_token.permute(0, 3, 1, 2) cls_token = x[:, 0] dist_token = x[:, 1] else: patch_token = x.reshape(B, *patch_resolution, C) patch_token = patch_token.permute(0, 3, 1, 2) cls_token = None dist_token = None if self.output_cls_token: out = [patch_token, cls_token, dist_token] else: out = patch_token outs.append(out) return tuple(outs) def init_weights(self): super(DistilledVisionTransformer, self).init_weights() if not (isinstance(self.init_cfg, dict) and self.init_cfg['type'] == 'Pretrained'): trunc_normal_(self.dist_token, std=0.02)