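# Converts a (Video)MAE-style pre-trained checkpoint into an encoder-only state
# dict whose keys live under the 'backbone.net.' prefix, so the weights can be
# loaded into a downstream model's ViT backbone.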
import argparse
import sys

import torch

sys.path.append('../../../')
from model_utils import get_sinusoid_encoding_table

parser = argparse.ArgumentParser(description='Convert a pre-trained checkpoint into an encoder-only backbone state dict.')
parser.add_argument('--input', type=str, help='Path to the pre-trained checkpoint.')
parser.add_argument('--output', type=str, default=None,
                    help='Output path for the converted checkpoint; defaults to <input>-encoder.pth.')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']
# mae=True: the checkpoint keys have no 'encoder.' prefix (plain MAE-style encoder).
# Set to False for checkpoints that contain both 'encoder.*' and 'decoder.*' keys.
mae = True
# Encoder embedding dimension; could also be read from the checkpoint:
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768

# Fixed sinusoidal position embedding for a 14x14 patch grid (e.g. 224px input,
# 16px patches), with an all-zero slot prepended for the class token.
pos_embed = get_sinusoid_encoding_table(14*14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)


# Start from the fixed position embedding, then copy/rename the encoder weights.
new_state_dict = {'backbone.net.pos_embed': pos_embed}
for k, v in state_dict.items():
    # Keep every key for a plain MAE checkpoint; otherwise keep only the encoder
    # (not decoder) weights and the patch embedding.
    if mae or ('encoder' in k and not 'decoder' in k or 'patch_embed' in k):
        if 'patch_embed.proj.weight' in k:
            # Video checkpoints store a 3D patch embedding (out, in, T, H, W);
            # collapse the temporal dimension to obtain a 2D conv weight.
            if len(v.shape) == 5:
                if v.shape[2] == 1:
                    v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
                else:
                    v = v[:, :, 0]  # keep the first temporal slice

        old_k = k
        # Move the key into the 'backbone.net.' namespace expected downstream.
        k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.' + k

        # Merge the separate q/k/v biases into a single qkv.bias tensor; k_bias is
        # often absent (kept at zero during pre-training), so zeros are substituted.
        if 'attn' in k and '_bias' in k:
            old_attn = '.'.join(old_k.split('.')[:-1])
            attn = '.'.join(k.split('.')[:-1])
            k = attn + '.qkv.bias'
            if k in new_state_dict:
                continue  # qkv.bias already assembled from a sibling bias key

            v = torch.cat([
                state_dict[old_attn + '.q_bias'],
                state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict
                else torch.zeros_like(state_dict[old_attn + '.q_bias']),
                state_dict[old_attn + '.v_bias'],
            ], dim=0)

        print(k, v.shape)
        new_state_dict[k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Saved converted checkpoint to', output_path)
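
# Example invocation (the script name below is only illustrative):
#   python convert_checkpoint.py --input /path/to/pretrain.pth
# This writes /path/to/pretrain-encoder.pth unless --output is given.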