Spaces:
Runtime error
Runtime error
File size: 16,511 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 |
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
class T2TTransformerLayer(BaseModule):
    """Transformer Layer for T2T_ViT.

    Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs
        input_dims (int, optional): The input token dimension. If not
            given, ``embed_dims`` is used instead. Defaults to None.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        qk_scale (float, optional): Override default qk scale of
            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Notes:
        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, in the official
        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
        keep the same with the official implementation.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # An explicit ``input_dims`` switches the layer to "v_shortcut" mode:
        # the outer residual connection in ``forward`` is skipped and the
        # flag is handed to the attention module instead.
        self.v_shortcut = input_dims is not None
        input_dims = input_dims or embed_dims

        # The first norm works on the raw input tokens (``input_dims`` wide).
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, input_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            # Scale follows ``input_dims`` (not ``embed_dims``), see Notes.
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        # The second norm works on the attention output (``embed_dims`` wide).
        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        """The normalization layer applied before attention."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The normalization layer applied before the FFN."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Apply pre-norm attention followed by a pre-norm FFN.

        Args:
            x (torch.Tensor): The input tokens.

        Returns:
            torch.Tensor: The output of the FFN, which receives the
            attention result as its ``identity`` input.
        """
        if self.v_shortcut:
            # No outer residual here: ``x`` is not added to the attention
            # output in this mode.
            x = self.attn(self.norm1(x))
        else:
            x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int | tuple): Input image size. An int is treated as a
            square image, a tuple as ``(height, width)``. Defaults to 224.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 384.
        token_dims (int): Tokens dimension in T2TModuleAttention.
            Defaults to 64.
        use_performer (bool): If True, use Performer version self-attention
            instead of regular self-attention. Not implemented yet.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.

    Notes:
        Usually, ``token_dims`` is set as a small value (32 or 64) to reduce
        MACs
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        embed_dims=384,
        token_dims=64,
        use_performer=False,
        init_cfg=None,
    ):
        super(T2TModule, self).__init__(init_cfg)

        if use_performer:
            raise NotImplementedError("Performer hasn't been implemented.")

        self.embed_dims = embed_dims

        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        # Each attention layer consumes the flattened patches produced by
        # the preceding soft split (channels * kernel_h * kernel_w).
        self.attention1 = T2TTransformerLayer(
            input_dims=in_channels * 7 * 7,
            embed_dims=token_dims,
            num_heads=1,
            feedforward_channels=token_dims)

        self.attention2 = T2TTransformerLayer(
            input_dims=token_dims * 3 * 3,
            embed_dims=token_dims,
            num_heads=1,
            feedforward_channels=token_dims)

        self.project = nn.Linear(token_dims * 3 * 3, embed_dims)

        # there are 3 soft split, stride are 4,2,2 separately
        total_stride = 4 * 2 * 2
        # Accept both an int (square image) and a (height, width) pair.
        img_h, img_w = to_2tuple(img_size)
        out_h, out_w = img_h // total_stride, img_w // total_stride
        self.init_out_size = [out_h, out_w]
        self.num_patches = out_h * out_w

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        """Compute the spatial grid of patches produced by ``unfold``.

        Args:
            unfold (nn.Unfold): The unfold layer whose kernel size, stride,
                padding and dilation are used.
            input_size (Sequence[int]): The ``(height, width)`` of the input.

        Returns:
            tuple[int, int]: The ``(height, width)`` number of sliding
            windows along each spatial axis.
        """
        h, w = input_size
        kernel_size = to_2tuple(unfold.kernel_size)
        stride = to_2tuple(unfold.stride)
        padding = to_2tuple(unfold.padding)
        dilation = to_2tuple(unfold.dilation)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):
        """Turn an image batch into a sequence of tokens.

        Args:
            x (torch.Tensor): The input images; the last two dimensions
                are read as height and width.

        Returns:
            tuple: The projected tokens and the ``(height, width)`` grid
            of the tokens after the last soft split.
        """
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
        x = self.soft_split0(x).transpose(1, 2)

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)
            B, C, _ = x.shape
            # Fold the token sequence back into a 2-D map so that it can
            # be unfolded again.
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])

            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)
            x = soft_split(x).transpose(1, 2)

        # final tokens
        x = self.project(x)
        return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
    """Generate sinusoid encoding table.

    Sinusoid encoding is a kind of fixed (non-learned) position encoding
    method came from
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Entry ``(pos, i)`` of the table is
    ``sin(pos / 10000^(2 * (i // 2) / embed_dims))`` for even ``i`` and the
    ``cos`` of the same angle for odd ``i``.

    Args:
        n_position (int): The number of positions (rows) in the table.
            ``0`` yields an empty table.
        embed_dims (int): The position embedding dimension.

    Returns:
        :obj:`torch.FloatTensor`: The sinusoid encoding table with shape
        ``(1, n_position, embed_dims)``.
    """
    # Channels 2i and 2i+1 share the same wavelength term.
    channel = np.arange(embed_dims)
    wavelength = np.power(10000.0, 2 * (channel // 2) / embed_dims)

    # Broadcast (n_position, 1) / (1, embed_dims) -> (n_position, embed_dims)
    # instead of looping over every position and channel in Python.
    position = np.arange(n_position, dtype=np.float64)[:, None]
    sinusoid_table = position / wavelength[None, :]

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)
@MODELS.register_module()
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. It is forwarded unchanged to
            :class:`T2TModule`. Defaults to 224.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 384.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last layer. The given sequence is not
            modified. Defaults to -1, means the last stage.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(T2T_ViT, self).__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, \
                'with_cls_token must be True if ' \
                f'set output_cls_token to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        # Resize a checkpoint's pos_embed on load if its shape differs.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalize into a fresh list: the caller's sequence must not be
        # mutated, and immutable sequences (e.g. tuples) holding negative
        # indices must be accepted as well.
        normalized_indices = []
        for index in out_indices:
            resolved = num_layers + index if index < 0 else index
            # NOTE(review): the upper bound is inclusive although encoder
            # layers are numbered 0..num_layers-1 in ``forward``, so an
            # index equal to ``num_layers`` never produces an output.
            assert 0 <= resolved <= num_layers, \
                f'Invalid out_indices {index}'
            normalized_indices.append(resolved)
        self.out_indices = normalized_indices

        # stochastic depth decay rule
        dpr = list(np.linspace(0, drop_path_rate, num_layers))

        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            # User-provided entries override the defaults listed first.
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Initialize weights; ``cls_token`` gets a truncated normal init
        unless a ``Pretrained`` init config is used."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Pre-hook for ``load_state_dict``: resize the checkpoint's
        ``pos_embed`` in ``state_dict`` if its shape differs from the
        model's buffer.

        Args:
            state_dict (dict): The state dict about to be loaded; it is
                modified in place.
            prefix (str): The key prefix of this module in ``state_dict``.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict:
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # The checkpoint grid is taken to be square: its side is the
            # square root of the token count without the extra tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        """Extract features from a batch of images.

        Args:
            x (torch.Tensor): The input images.

        Returns:
            tuple: One entry per index in ``out_indices``. Each entry is
            ``[patch_token, cls_token]`` if ``output_cls_token`` is True,
            otherwise only ``patch_token``. ``patch_token`` is reshaped to
            ``(B, C, H, W)`` where ``(H, W)`` is the token grid returned by
            the Tokens-to-Token module.
        """
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # The stored pos_embed matches ``self.patch_resolution``; resize it
        # to the resolution of the current input.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        last_layer_index = len(self.encoder) - 1
        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            if i == last_layer_index and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
|