File size: 803 Bytes
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import argparse
import sys
sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

import os


def convert_state_dict(state_dict, prefix='backbone.net.'):
    """Return a new dict with ``prefix`` prepended to every key of ``state_dict``.

    Values are passed through untouched and key order is preserved.
    ``pos_embed`` entries are copied unchanged like every other key.
    NOTE(review): the original script stopped in a debugger on 'pos_embed'
    keys, presumably to inspect or resize them; confirm that no resizing is
    needed for the target model.
    """
    return {prefix + key: value for key, value in state_dict.items()}


def default_output_path(input_path):
    """Derive the output path by inserting ``-encoder`` before the file extension.

    Only the final component's extension is considered, so dots in directory
    names are left alone. The result never equals ``input_path``, which stops
    the conversion from overwriting its own source checkpoint. An input with
    no extension gets ``.pth`` appended.
    """
    root, ext = os.path.splitext(input_path)
    return f"{root}-encoder{ext or '.pth'}"


def main(argv=None):
    """Load a checkpoint, prefix its keys, and save it to a new file.

    Args:
        argv: optional argument list. It defaults to ``sys.argv[1:]`` when
            None (standard argparse behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True, help='The path to the checkpoint.')
    parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
    parser.add_argument('--prefix', type=str, default='backbone.net.',
                        help='string prepended to every state-dict key')
    args = parser.parse_args(argv)

    # The loaded object is iterated with .items(), so it is assumed to be a
    # flat state dict (not wrapped, e.g. under a 'model' key). TODO confirm.
    # torch.load unpickles the file, so only run this on trusted checkpoints.
    state_dict = torch.load(args.input, map_location='cpu')

    new_state_dict = convert_state_dict(state_dict, args.prefix)

    output_path = default_output_path(args.input) if args.output is None else args.output
    torch.save(new_state_dict, output_path)
    print('Save model to', output_path)


if __name__ == '__main__':
    main()