"""Prefix every key of a saved checkpoint with ``backbone.net.``.

The script loads the mapping stored at ``--input``, prepends
``backbone.net.`` to each key, and saves the renamed mapping to ``--output``
(or, when ``--output`` is omitted, next to the input with an ``-encoder.pth``
suffix).
"""
import torch
import argparse
import sys

sys.path.append('../../../')
# NOTE(review): none of these names are used directly in this script.
# Presumably the import is kept so checkpoints that pickle these objects can
# be loaded -- TODO confirm before removing.
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table  # noqa: E402,F401

KEY_PREFIX = 'backbone.net.'
CHECKPOINT_SUFFIX = '.pth'


def _default_output_path(input_path):
    """Return ``input_path`` with its trailing ``.pth`` turned into ``-encoder.pth``.

    Only a trailing ``.pth`` is touched, so directory names that happen to
    contain ``.pth`` are left alone.  When the input does not end in ``.pth``
    the suffix is appended instead, which guarantees the result never equals
    the input path (the input checkpoint is never overwritten by default).
    """
    if input_path.endswith(CHECKPOINT_SUFFIX):
        stem = input_path[:-len(CHECKPOINT_SUFFIX)]
    else:
        stem = input_path
    return stem + '-encoder' + CHECKPOINT_SUFFIX


def _add_prefix(state_dict, prefix=KEY_PREFIX):
    """Return a new dict with ``prefix`` prepended to every key of ``state_dict``."""
    return {prefix + key: value for key, value in state_dict.items()}


def main():
    """Parse the command line, rename the checkpoint keys and save the result."""
    parser = argparse.ArgumentParser()
    # The input path is mandatory: without it there is nothing to convert.
    parser.add_argument('--input', type=str, required=True,
                        help='The path to the checkpoint.')
    parser.add_argument('--output', type=str, default=None,
                        help='the output path of the checkpoint')
    args = parser.parse_args()

    # SECURITY NOTE: torch.load unpickles the file; only run this script on
    # checkpoints from a trusted source.
    state_dict = torch.load(args.input, map_location='cpu')
    new_state_dict = _add_prefix(state_dict)

    if args.output is None:
        output_path = _default_output_path(args.input)
    else:
        output_path = args.output

    torch.save(new_state_dict, output_path)
    print('Save model to', output_path)


if __name__ == '__main__':
    main()