Spaces:
Runtime error
Runtime error
File size: 17,051 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional, Sequence
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.registry import MODELS
from torch import nn
from .base_backbone import BaseBackbone
from .mobilenet_v2 import InvertedResidual
from .vision_transformer import TransformerEncoderLayer
class MobileVitBlock(nn.Module):
    """MobileViT block.

    According to the paper, the MobileViT block consists of a local
    representation, a transformer-as-convolution layer (a global
    representation with unfolding and folding) and a final fusion layer.

    Args:
        in_channels (int): Number of input image channels.
        transformer_dim (int): Number of transformer channels.
        ffn_dim (int): Number of ffn channels in transformer block.
        out_channels (int): Number of channels in output.
        conv_ksize (int): Conv kernel size in local representation
            and fusion. Defaults to 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Defaults to dict(type='Swish').
        num_transformer_blocks (int): Number of transformer blocks in
            a MobileViT block. Defaults to 2.
        patch_size (int): Patch size for unfolding and folding.
            Defaults to 2.
        num_heads (int): Number of heads in global representation.
            Defaults to 4.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        no_fusion (bool): Whether to remove the fusion layer.
            Defaults to False.
        transformer_norm_cfg (dict): Config dict for normalization
            layer in transformer. Defaults to dict(type='LN').
    """

    def __init__(
        self,
        in_channels: int,
        transformer_dim: int,
        ffn_dim: int,
        out_channels: int,
        conv_ksize: int = 3,
        conv_cfg: Optional[dict] = None,
        norm_cfg: Optional[dict] = dict(type='BN'),
        act_cfg: Optional[dict] = dict(type='Swish'),
        num_transformer_blocks: int = 2,
        patch_size: int = 2,
        num_heads: int = 4,
        drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        no_fusion: bool = False,
        transformer_norm_cfg: dict = dict(type='LN'),
    ):
        super(MobileVitBlock, self).__init__()
        # "Same"-style padding for odd kernel sizes so the kxk convs keep
        # the spatial size (the fusion layer concatenates with the input).
        padding = int((conv_ksize - 1) / 2)

        # Local representation: kxk conv (in -> in) followed by a plain 1x1
        # projection (in -> transformer_dim) without norm or activation.
        self.local_rep = nn.Sequential(
            ConvModule(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                in_channels=in_channels,
                out_channels=transformer_dim,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=None),
        )

        # Global representation: a stack of transformer layers followed by
        # one extra normalization layer. The transformer layers always get
        # act_cfg=Swish, independent of the ``act_cfg`` argument above.
        global_rep = [
            TransformerEncoderLayer(
                embed_dims=transformer_dim,
                num_heads=num_heads,
                feedforward_channels=ffn_dim,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rate,
                qkv_bias=True,
                act_cfg=dict(type='Swish'),
                norm_cfg=transformer_norm_cfg)
            for _ in range(num_transformer_blocks)
        ]
        global_rep.append(
            build_norm_layer(transformer_norm_cfg, transformer_dim)[1])
        self.global_rep = nn.Sequential(*global_rep)

        # 1x1 projection back from transformer_dim to out_channels.
        self.conv_proj = ConvModule(
            in_channels=transformer_dim,
            out_channels=out_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if no_fusion:
            self.conv_fusion = None
        else:
            # Fusion consumes cat(block input, projected features), hence
            # in_channels + out_channels input channels.
            self.conv_fusion = ConvModule(
                in_channels=in_channels + out_channels,
                out_channels=out_channels,
                kernel_size=conv_ksize,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        self.patch_size = (patch_size, patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a 4D feature map ``x`` of shape [B, C_in, H, W].

        The feature map is unfolded into non-overlapping patches, processed
        by the transformer (attention runs across patches for each in-patch
        pixel position), folded back, and finally projected / fused.

        Returns:
            torch.Tensor: Output feature map with ``out_channels`` channels.
        """
        shortcut = x
        # Local representation
        x = self.local_rep(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape
        # Round H and W up to a multiple of the patch size.
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(
            W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w # noqa
        num_patches = num_patch_h * num_patch_w  # N
        interpolate = False
        if new_h != H or new_w != W:
            # Resize (instead of padding) so every patch is full; padding
            # would have to be masked out inside attention. The original
            # size is restored after folding.
            x = F.interpolate(
                x, size=(new_h, new_w), mode='bilinear', align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w,
                      patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa
        x = x.reshape(B, C, num_patches,
                      self.patch_area).transpose(1, 3).reshape(
                          B * self.patch_area, num_patches, -1)

        # Global representations
        x = self.global_rep(x)

        # Fold (patch -> feature map): exact inverse of the unfold above.
        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w,
                                      patch_h, patch_w)
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h,
                                      num_patch_w * patch_w)
        if interpolate:
            x = F.interpolate(
                x, size=(H, W), mode='bilinear', align_corners=False)

        x = self.conv_proj(x)
        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
        return x
@MODELS.register_module()
class MobileViT(BaseBackbone):
    """MobileViT backbone.

    A PyTorch implementation of : `MobileViT: Light-weight, General-purpose,
    and Mobile-friendly Vision Transformer <https://arxiv.org/pdf/2110.02178.pdf>`_

    Modified from the `official repo
    <https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mobilevit.py>`_.

    Args:
        arch (str | List[list]): Architecture of MobileViT.

            - If a string, choose from "small", "x_small" and "xx_small".

            - If a list, every item should be also a list, and the first item
              of the sub-list can be chosen from "mobilenetv2" and "mobilevit",
              which indicates the type of this layer sequence. If "mobilenetv2",
              the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer`
              (except ``in_channels``) and if "mobilevit", the other items are
              the arguments of :attr:`~MobileViT.make_mobilevit_layer`
              (except ``in_channels``).

            Defaults to "small".
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Channels of stem layer. Defaults to 16.
        last_exp_factor (int): Channels expand factor of last layer.
            Defaults to 4.
        out_indices (Sequence[int] | int): Output from which stages.
            Negative values count from the last stage. Every index must
            refer to an existing stage. Defaults to (4, ).
        frozen_stages (int): Number of leading stages (entries of
            ``self.layers``) whose parameters are frozen; the stem is not
            frozen. Must be in ``range(-1, num_stages)``.
            Defaults to -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Defaults to dict(type='Swish').
        init_cfg (dict, optional): Initialization config dict.

    Note:
        ``conv_cfg``, ``norm_cfg`` and ``act_cfg`` are applied to the stem
        and to the final 1x1 expansion conv only; the stage layers are built
        by the static ``make_*_layer`` helpers, which do not receive them.
    """ # noqa

    # Parameters to build layers. The first param is the type of layer.
    # For `mobilenetv2` layer, the rest params from left to right are:
    #     out channels, stride, num of blocks, expand_ratio.
    # For `mobilevit` layer, the rest params from left to right are:
    #     out channels, stride, transformer_channels, ffn channels,
    # num of transformer blocks, expand_ratio.
    arch_settings = {
        'small': [
            ['mobilenetv2', 32, 1, 1, 4],
            ['mobilenetv2', 64, 2, 3, 4],
            ['mobilevit', 96, 2, 144, 288, 2, 4],
            ['mobilevit', 128, 2, 192, 384, 4, 4],
            ['mobilevit', 160, 2, 240, 480, 3, 4],
        ],
        'x_small': [
            ['mobilenetv2', 32, 1, 1, 4],
            ['mobilenetv2', 48, 2, 3, 4],
            ['mobilevit', 64, 2, 96, 192, 2, 4],
            ['mobilevit', 80, 2, 120, 240, 4, 4],
            ['mobilevit', 96, 2, 144, 288, 3, 4],
        ],
        'xx_small': [
            ['mobilenetv2', 16, 1, 1, 2],
            ['mobilenetv2', 24, 2, 3, 2],
            ['mobilevit', 48, 2, 64, 128, 2, 2],
            ['mobilevit', 64, 2, 80, 160, 4, 2],
            ['mobilevit', 80, 2, 96, 192, 3, 2],
        ]
    }

    def __init__(self,
                 arch='small',
                 in_channels=3,
                 stem_channels=16,
                 last_exp_factor=4,
                 out_indices=(4, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='Swish'),
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(MobileViT, self).__init__(init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a list.'
            arch = self.arch_settings[arch]
        self.arch = arch
        self.num_stages = len(arch)

        # check out indices and frozen stages
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: the default (and typical user input) is a
        # tuple, which cannot be updated in place, and a caller's list must
        # not be mutated behind their back.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            # An index past the last stage would silently never be emitted
            # by ``forward``; reject it like a too-negative one.
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        if frozen_stages not in range(-1, self.num_stages):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{self.num_stages}). '
                             f'But received {frozen_stages}')
        self.frozen_stages = frozen_stages

        _make_layer_func = {
            'mobilenetv2': self.make_mobilenetv2_layer,
            'mobilevit': self.make_mobilevit_layer,
        }

        self.stem = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Build the stages in order, chaining each stage's output channels
        # into the next stage's input channels.
        in_channels = stem_channels
        layers = []
        for layer_type, *settings in arch:
            layer, out_channels = _make_layer_func[layer_type](in_channels,
                                                               *settings)
            layers.append(layer)
            in_channels = out_channels
        self.layers = nn.Sequential(*layers)

        # Final 1x1 expansion, applied to the output of the last stage.
        self.conv_1x1_exp = ConvModule(
            in_channels=in_channels,
            out_channels=last_exp_factor * in_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    @staticmethod
    def make_mobilevit_layer(in_channels,
                             out_channels,
                             stride,
                             transformer_dim,
                             ffn_dim,
                             num_transformer_blocks,
                             expand_ratio=4):
        """Build mobilevit layer, which consists of one InvertedResidual and
        one MobileVitBlock.

        Args:
            in_channels (int): The input channels.
            out_channels (int): The output channels.
            stride (int): The stride of the first 3x3 convolution in the
                ``InvertedResidual`` layers.
            transformer_dim (int): The channels of the transformer layers.
            ffn_dim (int): The mid-channels of the feedforward network in
                transformer layers.
            num_transformer_blocks (int): The number of transformer blocks.
            expand_ratio (int): adjusts number of channels of the hidden layer
                in ``InvertedResidual`` by this amount. Defaults to 4.

        Returns:
            tuple[nn.Sequential, int]: The built layer and its output
            channels.
        """
        layer = [
            InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                act_cfg=dict(type='Swish'),
            ),
            MobileVitBlock(
                in_channels=out_channels,
                transformer_dim=transformer_dim,
                ffn_dim=ffn_dim,
                out_channels=out_channels,
                num_transformer_blocks=num_transformer_blocks,
            ),
        ]
        return nn.Sequential(*layer), out_channels

    @staticmethod
    def make_mobilenetv2_layer(in_channels,
                               out_channels,
                               stride,
                               num_blocks,
                               expand_ratio=4):
        """Build mobilenetv2 layer, which consists of several InvertedResidual
        layers.

        Args:
            in_channels (int): The input channels.
            out_channels (int): The output channels.
            stride (int): The stride of the first 3x3 convolution in the
                ``InvertedResidual`` layers.
            num_blocks (int): The number of ``InvertedResidual`` blocks.
            expand_ratio (int): adjusts number of channels of the hidden layer
                in ``InvertedResidual`` by this amount. Defaults to 4.

        Returns:
            tuple[nn.Sequential, int]: The built layer and its output
            channels.
        """
        layer = []
        for i in range(num_blocks):
            # Only the first block downsamples; the rest keep resolution.
            block_stride = stride if i == 0 else 1
            layer.append(
                InvertedResidual(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=block_stride,
                    expand_ratio=expand_ratio,
                    act_cfg=dict(type='Swish'),
                ))
            in_channels = out_channels
        return nn.Sequential(*layer), out_channels

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode and stop their
        parameter updates (no-op when ``frozen_stages <= 0``)."""
        for i in range(self.frozen_stages):
            layer = self.layers[i]
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen stages frozen.

        Returns:
            MobileViT: ``self``, as required by ``nn.Module.train``.
        """
        super(MobileViT, self).train(mode)
        self._freeze_stages()
        return self

    def forward(self, x):
        """Return a tuple with the outputs of the stages in ``out_indices``.

        The last stage's output goes through ``conv_1x1_exp`` before it is
        (optionally) collected.
        """
        x = self.stem(x)
        last_index = len(self.layers) - 1
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == last_index:
                x = self.conv_1x1_exp(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
|