# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
from torch.nn import functional as F

from ..utils import LeAttention
from .base_backbone import BaseBackbone


class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    def fuse(self):
        """Fuse the Conv2d and BatchNorm2d into a single Conv2d."""
        conv2d, bn2d = self._modules.values()
        # fold the BatchNorm2d statistics into the convolution weight and bias
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5
        m = nn.Conv2d(
            in_channels=w.size(1) * self.conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
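
# Illustrative sketch (not part of the original file): in eval mode the fused
# Conv2d returned by ``ConvBN2d.fuse()`` should reproduce the Conv2d +
# BatchNorm2d pair. The shapes below are arbitrary assumptions chosen for the
# example.
#
#   >>> layer = ConvBN2d(3, 16, kernel_size=3, padding=1).eval()
#   >>> x = torch.randn(1, 3, 32, 32)
#   >>> with torch.no_grad():
#   ...     fused = layer.fuse()
#   ...     assert torch.allclose(layer(x), fused(x), atol=1e-5)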


class PatchEmbed(BaseModule):
    """Patch Embedding for Vision Transformer.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
    Conv2d and BatchNorm2d to implement PatchEmbedding, and the output shape
    is (N, C, H, W).

    Args:
        in_channels (int): The number of input channels.
        embed_dim (int): The embedding dimension.
        resolution (Tuple[int, int]): The resolution of the input feature.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 embed_dim,
                 resolution,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        img_size: Tuple[int, int] = resolution
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * \
            self.patches_resolution[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.seq = nn.Sequential(
            ConvBN2d(
                in_channels,
                embed_dim // 2,
                kernel_size=3,
                stride=2,
                padding=1),
            build_activation_layer(act_cfg),
            ConvBN2d(
                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        return self.seq(x)
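
# Illustrative sketch (assumed shapes): PatchEmbed downsamples by 4x overall
# (two stride-2 convolutions) and keeps the (N, C, H, W) layout.
#
#   >>> embed = PatchEmbed(in_channels=3, embed_dim=64, resolution=(224, 224))
#   >>> embed(torch.randn(2, 3, 224, 224)).shape
#   torch.Size([2, 64, 56, 56])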


class PatchMerging(nn.Module):
    """Patch Merging for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcls.models.utils.PatchMerging`, this module uses Conv2d
    and BatchNorm2d to implement PatchMerging.

    Args:
        resolution (Tuple[int, int]): The resolution of the input feature.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 resolution,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.img_size = resolution
        self.act = build_activation_layer(act_cfg)
        self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = ConvBN2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels)
        self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
        self.out_resolution = (resolution[0] // 2, resolution[1] // 2)

    def forward(self, x):
        if len(x.shape) == 3:
            # (B, L, C) -> (B, C, H, W)
            H, W = self.img_size
            B = x.shape[0]
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        # (B, C, H/2, W/2) -> (B, L/4, C)
        x = x.flatten(2).transpose(1, 2)
        return x
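
# Illustrative sketch (assumed shapes): PatchMerging accepts either a
# (B, L, C) sequence or a (B, C, H, W) feature map and returns a
# (B, L/4, out_channels) sequence after the stride-2 depthwise convolution.
#
#   >>> merge = PatchMerging(resolution=(56, 56), in_channels=64,
#   ...                      out_channels=128)
#   >>> merge(torch.randn(2, 56 * 56, 64)).shape
#   torch.Size([2, 784, 128])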


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        expand_ratio (int): The expand ratio of the hidden channels.
        drop_path (float): The drop path rate of the block.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 drop_path,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        hidden_channels = int(in_channels * expand_ratio)

        # linear
        self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1)
        self.act = build_activation_layer(act_cfg)

        # depthwise conv
        self.conv2 = ConvBN2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_channels)

        # linear
        self.conv3 = ConvBN2d(
            hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        x = self.drop_path(x)
        x += shortcut
        x = self.act(x)
        return x
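
# Illustrative sketch (assumed shapes): MBConvBlock keeps the spatial layout
# and relies on an identity shortcut, so it is used with
# in_channels == out_channels.
#
#   >>> block = MBConvBlock(64, 64, expand_ratio=4, drop_path=0.)
#   >>> block(torch.randn(2, 64, 56, 56)).shape
#   torch.Size([2, 64, 56, 56])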


class ConvStage(BaseModule):
    """Convolution Stage for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        act_cfg (dict): The activation config of the module.
        drop_path (float | list[float]): The drop path rate(s) of the blocks.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        out_channels (int, optional): The number of output channels of the
            downsample layer. Default: None.
        conv_expand_ratio (int | float): The expand ratio of the hidden
            channels. Default: 4.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 act_cfg,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 out_channels=None,
                 conv_expand_ratio=4.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = ModuleList([
            MBConvBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=conv_expand_ratio,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x
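
# Illustrative sketch (assumed shapes): ConvStage operates on (B, C, H, W)
# feature maps; with a PatchMerging downsample the output becomes a
# (B, L/4, out_channels) token sequence consumed by the following BasicStage.
#
#   >>> stage = ConvStage(
#   ...     in_channels=64, resolution=(56, 56), depth=2,
#   ...     act_cfg=dict(type='GELU'), downsample=PatchMerging,
#   ...     out_channels=128)
#   >>> stage(torch.randn(2, 64, 56, 56)).shape
#   torch.Size([2, 784, 128])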


class MLP(BaseModule):
    """MLP module for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden channels.
            Default: None.
        out_channels (int, optional): The number of output channels.
            Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.norm = nn.LayerNorm(in_channels)
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
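
# Illustrative sketch (assumed shapes): the MLP normalizes and transforms the
# last dimension of a (B, L, C) sequence and preserves its shape.
#
#   >>> mlp = MLP(in_channels=64, hidden_channels=256)
#   >>> mlp(torch.randn(2, 49, 64)).shape
#   torch.Size([2, 49, 64])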


class TinyViTBlock(BaseModule):
    """TinyViT Block.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
            Default: 7.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path rate of the block.
            Default: 0.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 local_conv_size=3,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.img_size = resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert in_channels % num_heads == 0, \
            'dim must be divisible by num_heads'
        head_dim = in_channels // num_heads

        window_resolution = (window_size, window_size)
        self.attn = LeAttention(
            in_channels,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution)

        mlp_hidden_dim = int(in_channels * mlp_ratio)
        self.mlp = MLP(
            in_channels=in_channels,
            hidden_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.local_conv = ConvBN2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=local_conv_size,
            stride=1,
            padding=local_conv_size // 2,
            groups=in_channels)

    def forward(self, x):
        H, W = self.img_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            pad_b = (self.window_size -
                     H % self.window_size) % self.window_size
            pad_r = (self.window_size -
                     W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size

            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(
                           B * nH * nW, self.window_size * self.window_size,
                           C)
            x = self.attn(x)

            # window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)
        x = x + self.drop_path(self.mlp(x))
        return x
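
# Illustrative sketch (assumed shapes): TinyViTBlock consumes a (B, L, C)
# token sequence with L == resolution[0] * resolution[1] and returns the same
# shape; windows are padded internally when the resolution is not a multiple
# of window_size.
#
#   >>> block = TinyViTBlock(
#   ...     in_channels=160, resolution=(14, 14), num_heads=5, window_size=7)
#   >>> block(torch.randn(2, 14 * 14, 160)).shape
#   torch.Size([2, 196, 160])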


class BasicStage(BaseModule):
    """Basic Stage for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float | list[float]): The drop path rate(s) of the blocks.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        out_channels (int, optional): The number of output channels of the
            downsample layer. Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 local_conv_size=3,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = ModuleList([
            TinyViTBlock(
                in_channels=in_channels,
                resolution=resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                local_conv_size=local_conv_size,
                act_cfg=act_cfg,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path)
            for i in range(depth)
        ])

        # build patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution
        else:
            self.downsample = None
            self.resolution = resolution

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is not None:
            x = self.downsample(x)
        return x
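
# Illustrative sketch (assumed shapes): a BasicStage with a PatchMerging
# downsample halves the spatial resolution and switches to `out_channels`.
#
#   >>> stage = BasicStage(
#   ...     in_channels=64, resolution=(56, 56), depth=2, num_heads=2,
#   ...     window_size=7, downsample=PatchMerging, out_channels=128)
#   >>> stage(torch.randn(2, 56 * 56, 64)).shape
#   torch.Size([2, 784, 128])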


@MODELS.register_module()
class TinyViT(BaseBackbone):
    """TinyViT.

    A PyTorch implementation of: `TinyViT: Fast Pretraining Distillation
    for Small Vision Transformers <https://arxiv.org/abs/2207.10666>`_

    Inspiration from
    https://github.com/microsoft/Cream/blob/main/TinyViT

    Args:
        arch (str | dict): The architecture of TinyViT. If a dict is passed,
            it should contain the keys ``'channels'``, ``'num_heads'`` and
            ``'depths'``. Default: '5m'.
        img_size (tuple | int): The resolution of the input image.
            Default: (224, 224).
        window_size (list): The size of the window.
            Default: [7, 7, 14, 7].
        in_channels (int): The number of input channels.
            Default: 3.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): The drop path rate of the blocks.
            Default: 0.1.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        mbconv_expand_ratio (float): The expand ratio of the mbconv.
            Default: 4.0.
        local_conv_size (int): The size of the local conv.
            Default: 3.
        layer_lr_decay (float): The layer lr decay.
            Default: 1.0.
        out_indices (int | list[int]): Output from which stages.
            Default: -1.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Default: 0.
        gap_before_final_norm (bool): Whether to apply global average pooling
            before the final norm. Default: True.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """
    arch_settings = {
        '5m': {
            'channels': [64, 128, 160, 320],
            'num_heads': [2, 4, 5, 10],
            'depths': [2, 2, 6, 2],
        },
        '11m': {
            'channels': [64, 128, 256, 448],
            'num_heads': [2, 4, 8, 14],
            'depths': [2, 2, 6, 2],
        },
        '21m': {
            'channels': [96, 192, 384, 576],
            'num_heads': [3, 6, 12, 18],
            'depths': [2, 2, 6, 2],
        },
    }

    def __init__(self,
                 arch='5m',
                 img_size=(224, 224),
                 window_size=[7, 7, 14, 7],
                 in_channels=3,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'{set(self.arch_settings)} or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, \
                'The arch dict must have "channels", "num_heads" and ' \
                f'"depths" keys, but got {arch.keys()}'

        self.channels = arch['channels']
        self.num_heads = arch['num_heads']
        self.window_sizes = window_size
        self.img_size = img_size
        self.depths = arch['depths']
        self.num_stages = len(self.channels)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm
        self.layer_lr_decay = layer_lr_decay

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dim=self.channels[0],
            resolution=self.img_size,
            act_cfg=dict(type='GELU'))

        patches_resolution = self.patch_embed.patches_resolution

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        # build stages
        self.stages = ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channel = self.channels[i]
            curr_resolution = (patches_resolution[0] // (2**i),
                               patches_resolution[1] // (2**i))
            drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
            downsample = PatchMerging if (i < self.num_stages - 1) else None
            out_channels = self.channels[min(i + 1, self.num_stages - 1)]
            if i >= 1:
                stage = BasicStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    num_heads=self.num_heads[i],
                    window_size=self.window_sizes[i],
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    local_conv_size=local_conv_size,
                    out_channels=out_channels,
                    act_cfg=act_cfg)
            else:
                stage = ConvStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    act_cfg=act_cfg,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    out_channels=out_channels,
                    conv_expand_ratio=mbconv_expand_ratio)
            self.stages.append(stage)

            # add output norm
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def set_layer_lr_decay(self, layer_lr_decay):
        # TODO: add layer_lr_decay
        pass

    def forward(self, x):
        outs = []
        x = self.patch_embed(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean(1)
                    outs.append(norm_layer(gap))
                else:
                    out = norm_layer(x)
                    # convert the (B, L, C) format into (B, C, H, W) format,
                    # which would be better for the downstream tasks.
                    B, L, C = out.shape
                    out = out.view(B, *stage.resolution, C)
                    outs.append(out.permute(0, 3, 1, 2))

        return tuple(outs)

    def _freeze_stages(self):
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(TinyViT, self).train(mode)
        self._freeze_stages()
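
# Illustrative usage sketch (assumed input size): with the default
# `gap_before_final_norm=True` and `out_indices=-1`, the backbone returns a
# single globally pooled feature from the last stage.
#
#   >>> model = TinyViT(arch='5m', img_size=(224, 224))
#   >>> feats = model(torch.randn(1, 3, 224, 224))
#   >>> feats[0].shape
#   torch.Size([1, 320])
#
# With `gap_before_final_norm=False`, the same output is instead reshaped to a
# (B, C, H, W) feature map, e.g. torch.Size([1, 320, 7, 7]) here.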