Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
import warnings | |
from collections import OrderedDict | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer | |
from mmcv.cnn.bricks.drop import build_dropout | |
from mmcv.cnn.bricks.transformer import MultiheadAttention | |
from mmengine.logging import MMLogger | |
from mmengine.model import (BaseModule, ModuleList, Sequential, constant_init, | |
normal_init, trunc_normal_init) | |
from mmengine.model.weight_init import trunc_normal_ | |
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict | |
from torch.nn.modules.utils import _pair as to_2tuple | |
from mmdet.registry import MODELS | |
from ..layers import PatchEmbed, nchw_to_nlc, nlc_to_nchw | |
class MixFFN(BaseModule):
    """MixFFN used by the PVT encoder layers.

    Differences from a plain transformer FFN:

    1. Both projections are 1x1 convolutions applied to the NCHW view of
       the token sequence instead of ``nn.Linear`` layers.
    2. With ``use_conv=True`` a 3x3 depth-wise convolution is inserted
       right after the first projection to encode positional information.

    The resulting ``self.layers`` order is
    ``[fc1, (dw_conv), activation, dropout, fc2, dropout]``; checkpoint
    conversion relies on these indices.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. A falsy value means identity.
            Default: None.
        use_conv (bool): If True, add a 3x3 depth-wise conv between the
            two projections. Defaults: False.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 use_conv=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        activation = build_activation_layer(act_cfg)
        # Construction order (fc1 -> dw_conv -> fc2) is significant: default
        # parameter initialization draws from the global RNG in this order.
        expand = Conv2d(
            in_channels=embed_dims,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        head = [expand]
        if use_conv:
            # 3x3 depth-wise conv (groups == channels) with 'same' padding.
            head.append(
                Conv2d(
                    in_channels=feedforward_channels,
                    out_channels=feedforward_channels,
                    kernel_size=3,
                    stride=1,
                    padding=(3 - 1) // 2,
                    bias=True,
                    groups=feedforward_channels))
        project = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1,
            stride=1,
            bias=True)
        # One parameter-free Dropout instance is reused at both positions.
        dropout = nn.Dropout(ffn_drop)
        self.layers = Sequential(*head, activation, dropout, project, dropout)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Apply the FFN to an NLC token sequence and add the residual.

        Args:
            x (Tensor): Tokens in NLC layout; reshaped to NCHW with
                ``hw_shape`` for the convolutions.
            hw_shape (tuple[int]): Spatial (H, W) of the token grid.
            identity (Tensor, optional): Residual added to the output.
                Defaults to ``x``.

        Returns:
            Tensor: ``identity + dropout_layer(ffn(x))`` in NLC layout.
        """
        shortcut = x if identity is None else identity
        feat_map = self.layers(nlc_to_nchw(x, hw_shape))
        return shortcut + self.dropout_layer(nchw_to_nlc(feat_map))
class SpatialReductionAttention(MultiheadAttention):
    """An implementation of Spatial Reduction Attention of PVT.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer. When ``sr_ratio > 1`` the keys and values
    are computed from a spatially downsampled copy of the input: a conv
    with ``kernel_size == stride == sr_ratio`` followed by a norm layer.
    Queries always use the full-resolution tokens.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            # Adjacent literals are concatenated verbatim, so each piece
            # needs its own trailing space.
            warnings.warn('The legacy version of forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def _reduce_kv(self, x, hw_shape):
        """Return the key/value tokens, spatially reduced when sr_ratio > 1.

        Args:
            x (Tensor): Tokens in NLC layout (reshaped with ``nlc_to_nchw``).
            hw_shape (tuple[int]): Spatial (H, W) of the token grid.
        """
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            return self.norm(x_kv)
        return x

    def forward(self, x, hw_shape, identity=None):
        """Spatial-reduction attention with a residual connection.

        Args:
            x (Tensor): Query tokens in NLC layout.
            hw_shape (tuple[int]): Spatial (H, W) of the token grid.
            identity (Tensor, optional): Residual added to the output.
                Defaults to ``x``.

        Returns:
            Tensor: ``identity + dropout_layer(proj_drop(attn_out))``.
        """
        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)
        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_queries, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_queries, embed_dims) to num_queries_first
        # (num_queries ,batch, embed_dims), and recover ``attn_output``
        # from num_queries_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""
        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)
        if identity is None:
            identity = x_q

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
class PVTEncoderLayer(BaseModule):
    """One PVT encoder layer.

    Pre-norm spatial-reduction attention followed by a pre-norm MixFFN.
    Each sub-layer adds its own residual connection, with DropPath applied
    to the branch output.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the attention projection and inside the feed forward
            layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_conv_ffn=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # build_norm_layer returns (name, module); only the module is kept.
        _, self.norm1 = build_norm_layer(norm_cfg, embed_dims)
        self.attn = SpatialReductionAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        _, self.norm2 = build_norm_layer(norm_cfg, embed_dims)
        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            use_conv=use_conv_ffn,
            act_cfg=act_cfg)

    def forward(self, x, hw_shape):
        """Run attention then FFN, each as a pre-norm residual branch."""
        attended = self.attn(self.norm1(x), hw_shape, identity=x)
        return self.ffn(self.norm2(attended), hw_shape, identity=attended)
class AbsolutePositionEmbedding(BaseModule):
    """An implementation of the absolute position embedding in PVT.

    A learnable ``(1, H*W, C)`` parameter is stored for a reference grid of
    ``pos_shape``. It is bilinearly resized (by default) to the actual
    token grid in ``forward``.

    Args:
        pos_shape (int | tuple[int]): The (H, W) shape of the absolute
            position embedding grid. An int or 1-tuple means a square grid.
        pos_dim (int): The dimension of the absolute position embedding.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(pos_shape, int):
            pos_shape = to_2tuple(pos_shape)
        elif isinstance(pos_shape, tuple):
            if len(pos_shape) == 1:
                pos_shape = to_2tuple(pos_shape[0])
            assert len(pos_shape) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pos_shape)}'

        self.pos_shape = pos_shape
        self.pos_dim = pos_dim
        num_positions = pos_shape[0] * pos_shape[1]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, pos_dim))
        self.drop = nn.Dropout(p=drop_rate)

    def init_weights(self):
        """Fill the embedding from a truncated normal with std 0.02."""
        trunc_normal_(self.pos_embed, std=0.02)

    def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
        """Resize pos_embed weights to a new token grid.

        Only the trailing ``pos_h * pos_w`` tokens are used; any leading
        extra tokens are discarded.

        Args:
            pos_embed (torch.Tensor): Position embedding weights of shape
                [B, L, C]. The reshape below assumes B == 1.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'bilinear'``.

        Return:
            torch.Tensor: The resized pos_embed of shape [1, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = self.pos_shape
        num_tokens = pos_h * pos_w
        grid = pos_embed[:, -num_tokens:]
        grid = grid.reshape(1, pos_h, pos_w, self.pos_dim)
        grid = grid.permute(0, 3, 1, 2).contiguous()
        grid = F.interpolate(grid, size=input_shape, mode=mode)
        return grid.flatten(2).transpose(1, 2).contiguous()

    def forward(self, x, hw_shape, mode='bilinear'):
        """Add the grid-resized embedding to ``x``, then apply dropout."""
        resized = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
        return self.drop(x + resized)
class PyramidVisionTransformer(BaseModule):
    """Pyramid Vision Transformer (PVT)

    Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
    Dense Prediction without Convolutions
    <https://arxiv.org/pdf/2102.12122.pdf>`_.

    Stage ``i`` is stored in ``self.layers[i]`` as
    ``ModuleList([patch_embed, blocks, norm])``. ``blocks`` optionally starts
    with an :class:`AbsolutePositionEmbedding`, followed by ``num_layers[i]``
    :class:`PVTEncoderLayer` modules. :func:`pvt_convert` targets this
    layout when converting official checkpoints.

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Base embedding dimension; stage ``i`` uses
            ``embed_dims * num_heads[i]`` channels. Default: 64.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 5, 8].
        patch_sizes (Sequence[int]): The patch_size of each patch embedding.
            Default: [4, 2, 2, 2].
        strides (Sequence[int]): The stride of each patch embedding.
            Default: [4, 2, 2, 2].
        paddings (Sequence[int]): The padding of each patch embedding.
            Default: [0, 0, 0, 0].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages. A single
            int selects one stage. Default: (0, 1, 2, 3).
        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
            embedding dim of each transformer encode layer.
            Default: [8, 8, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. The per-layer rate
            grows linearly from 0 to this value over all encoder layers.
            Default 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: True.
        norm_after_stage (bool): If True, apply a norm layer built from
            ``norm_cfg`` to the output of every stage; otherwise identity.
            Default: False.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        pretrained (str, optional): model pretrained path. Default: None.
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 5, 8],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 paddings=[0, 0, 0, 0],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratios=[8, 8, 4, 4],
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=True,
                 norm_after_stage=False,
                 use_conv_ffn=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 convert_weights=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.convert_weights = convert_weights
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        # The docstring allows a single int; normalize it so that ``max``
        # below and the ``in`` membership test in ``forward`` both work.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages
        self.pretrained = pretrained

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=paddings[i],
                bias=True,
                norm_cfg=norm_cfg)

            layers = ModuleList()
            if use_abs_pos_embed:
                # Reference grid size at this stage for the pretrain image.
                pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
                pos_embed = AbsolutePositionEmbedding(
                    pos_shape=pos_shape,
                    pos_dim=embed_dims_i,
                    drop_rate=drop_rate)
                layers.append(pos_embed)
            layers.extend([
                PVTEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratios[i] * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i],
                    use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
            ])
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            if norm_after_stage:
                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            else:
                norm = nn.Identity()
            self.layers.append(ModuleList([patch_embed, layers, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize weights from scratch or from a pre-trained checkpoint.

        If ``init_cfg`` is None, Linear/LayerNorm/Conv2d layers and absolute
        position embeddings are initialized in place. Otherwise ``init_cfg``
        must contain a ``checkpoint`` entry, which is loaded (converted by
        :func:`pvt_convert` when ``convert_weights`` is set) non-strictly.
        """
        logger = MMLogger.get_current_instance()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
            self._init_weights_from_scratch()
        else:
            self._load_pretrained_weights(logger)

    def _init_weights_from_scratch(self):
        """Randomly initialize every supported submodule in place."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0.)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, 1.0)
            elif isinstance(m, nn.Conv2d):
                # Fan-out based normal init, divided by groups so that
                # depth-wise convs get the correct scale.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                normal_init(m, 0, math.sqrt(2.0 / fan_out))
            elif isinstance(m, AbsolutePositionEmbedding):
                m.init_weights()

    def _load_pretrained_weights(self, logger):
        """Load ``init_cfg['checkpoint']`` into this module (strict=False)."""
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        # ``init_cfg`` is a plain dict when it was built from the deprecated
        # ``pretrained`` argument, so use item access (attribute access only
        # works on ConfigDict).
        checkpoint = CheckpointLoader.load_checkpoint(
            self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
        logger.warning(f'Load pre-trained model for '
                       f'{self.__class__.__name__} from original repo')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        if self.convert_weights:
            # Because pvt backbones are not supported by mmcls,
            # so we need to convert pre-trained weights to match this
            # implementation.
            state_dict = pvt_convert(state_dict)
        load_state_dict(self, state_dict, strict=False, logger=logger)

    def forward(self, x):
        """Run all stages and collect the selected stage outputs.

        Returns:
            list[Tensor]: NCHW feature maps of the stages in ``out_indices``,
            in stage order.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            x, hw_shape = layer[0](x)

            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
class PyramidVisionTransformerV2(PyramidVisionTransformer):
    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
    Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_.

    PVTv2 is :class:`PyramidVisionTransformer` with these settings fixed:

    * patch-embedding kernel sizes 7/3/3/3 with paddings 3/1/1/1 (larger
      than the default strides, so neighbouring patches overlap);
    * no absolute position embedding;
    * a norm layer after every stage;
    * a depth-wise conv inside each FFN.

    These five arguments cannot be supplied through ``kwargs``; doing so
    raises ``TypeError``. All other arguments are forwarded unchanged.
    """

    def __init__(self, **kwargs):
        # Built per call so that no list object is shared between models.
        v2_settings = dict(
            patch_sizes=[7, 3, 3, 3],
            paddings=[3, 1, 1, 1],
            use_abs_pos_embed=False,
            norm_after_stage=True,
            use_conv_ffn=True)
        super().__init__(**v2_settings, **kwargs)
def pvt_convert(ckpt):
    """Convert an official PVT / PVTv2 state dict to this implementation.

    Key mapping (source stage ``s`` is 1-based, target index is ``s - 1``):

    * ``pos_embed{s}`` -> ``layers.{s-1}.1.0.pos_embed``. For stage 4, a
      50-token embedding (1 cls token + 7 * 7) has its first token removed.
    * ``patch_embed{s}.*`` -> ``layers.{s-1}.0.*``, with ``proj.`` renamed
      to ``projection.``.
    * ``block{s}.{j}.*`` -> ``layers.{s-1}.1.{j'}.*``. ``j'`` is ``j + 1``
      when the checkpoint has absolute position embeddings, because the
      embedding occupies index 0 of the stage's layer list.
      - ``attn.q`` and ``attn.kv`` are concatenated along dim 0 into
        ``attn.attn.in_proj_*``.
      - ``attn.proj`` becomes ``attn.attn.out_proj``.
      - MLP ``fc1``/``dwconv``/``fc2`` map to ``ffn.layers`` indices, with
        ``fc*.weight`` reshaped to 1x1 conv kernels.
      - Other keys, e.g. ``attn.sr``/``attn.norm``, only get the prefix
        rename.
    * ``norm{s}.*`` -> ``layers.{s-1}.2.*``.
    * ``head*``, ``cls_token*`` and the final ``norm.*`` are dropped.

    Args:
        ckpt (dict): Source state dict mapping parameter names to tensors.

    Returns:
        OrderedDict: Converted state dict.
    """
    new_ckpt = OrderedDict()
    # Detect the architectural variant first: both flags shift indices in
    # the target layout for every block key.
    use_abs_pos_embed = any(k.startswith('pos_embed') for k in ckpt)
    use_conv_ffn = any('dwconv' in k for k in ckpt)
    # In MixFFN, fc2 follows (fc1, [dw_conv], act, drop).
    fc2_prefix = '4.' if use_conv_ffn else '3.'

    for k, v in ckpt.items():
        if k.startswith(('head', 'norm.', 'cls_token')):
            continue
        if k.startswith('pos_embed'):
            stage_i = int(k.replace('pos_embed', ''))
            new_k = k.replace(f'pos_embed{stage_i}',
                              f'layers.{stage_i - 1}.1.0.pos_embed')
            if stage_i == 4 and v.size(1) == 50:  # 1 (cls token) + 7 * 7
                new_v = v[:, 1:, :]  # remove cls token
            else:
                new_v = v
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}',
                              f'layers.{stage_i - 1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            layer_i = int(k.split('.')[1])
            new_layer_i = layer_i + int(use_abs_pos_embed)
            new_k = k.replace(f'block{stage_i}.{layer_i}',
                              f'layers.{stage_i - 1}.1.{new_layer_i}')
            new_v = v
            if 'attn.q.' in new_k:
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                # Already merged into in_proj together with the q entry.
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'mlp.' in new_k:
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    # Linear weight (out, in) -> 1x1 conv kernel.
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                new_k = new_k.replace('fc2.', fc2_prefix)
        elif k.startswith('norm'):
            stage_i = int(k[4])
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v
    return new_ckpt