# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmcls.registry import MODELS
from .vision_transformer import VisionTransformer

@MODELS.register_module()
class DistilledVisionTransformer(VisionTransformer):
    """Distilled Vision Transformer.

    A PyTorch implementation of: `Training data-efficient image transformers
    & distillation through attention <https://arxiv.org/abs/2012.12877>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If a dict, it should have the keys below:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'deit-base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shapes, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        with_cls_token (bool): Whether to concatenate the class token with
            the image tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set to
            True, ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolation mode for the
            position embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            the encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
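
    Example:
        An illustrative forward pass (a sketch; it assumes the default
        ``with_cls_token``/``output_cls_token`` settings, and the shapes
        shown correspond to the 'deit-base' architecture at 224x224 input):

        >>> import torch
        >>> model = DistilledVisionTransformer(arch='deit-base')
        >>> inputs = torch.randn(1, 3, 224, 224)
        >>> patch_token, cls_token, dist_token = model(inputs)[-1]
        >>> patch_token.shape
        torch.Size([1, 768, 14, 14])
        >>> cls_token.shape
        torch.Size([1, 768])
        >>> dist_token.shape
        torch.Size([1, 768])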
""" | |
    num_extra_tokens = 2  # cls_token, dist_token

    def __init__(self, arch='deit-base', *args, **kwargs):
        super(DistilledVisionTransformer, self).__init__(
            arch=arch, *args, **kwargs)
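        # The distillation token is a second learnable token alongside
        # ``cls_token``; its (1, 1, embed_dims) shape lets it be expanded
        # across the batch dimension in ``forward``.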
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)
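
        # Prepend both extra tokens so the sequence layout is
        # [cls_token, dist_token, patch tokens ...]; positions 0 and 1 are
        # reserved for the two extra tokens throughout the encoder.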
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
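
        # Resize the stored position embedding to the actual patch grid of
        # this input; the first ``num_extra_tokens`` entries (cls and dist)
        # are left untouched by the interpolation.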
        x = x + self.resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove cls_token and dist_token for transformer encoder input
            x = x[:, 2:]
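
        # Collect outputs from the layers listed in ``out_indices``; each
        # output is either the patch-token feature map alone, or a
        # [patch_token, cls_token, dist_token] triple.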
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
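                # Reshape the flattened patch tokens back into a
                # (B, C, H, W) feature map; tokens 0 and 1 are the cls and
                # dist tokens when they are kept in the sequence.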
                if self.with_cls_token:
                    patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                    dist_token = x[:, 1]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                    dist_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token, dist_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)

    def init_weights(self):
        super(DistilledVisionTransformer, self).init_weights()
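
        # dist_token only needs explicit initialization when training from
        # scratch; with a 'Pretrained' init_cfg it is loaded from the
        # checkpoint instead.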
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.dist_token, std=0.02)