import torch
import argparse
import sys

sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None,
                    help='The output path of the checkpoint.')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']

# True for MAE-style checkpoints (keys carry no 'encoder.' prefix);
# False for checkpoints whose encoder keys are prefixed with 'encoder.'.
mae = True

# The embedding dimension can be read from the checkpoint instead:
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768

# Build a fixed sinusoidal position embedding for a 14x14 patch grid, then
# prepend a zero embedding for the [CLS] token: (1, 196, C) -> (1, 197, C).
pos_embed = get_sinusoid_encoding_table(14 * 14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)

new_state_dict = {'backbone.net.pos_embed': pos_embed}
for k, v in state_dict.items():
    # Keep encoder weights only; decoder weights are dropped.
    if mae or ('encoder' in k and 'decoder' not in k) or 'patch_embed' in k:
        if 'patch_embed.proj.weight' in k:
            # Video checkpoints store a 5D conv kernel (C_out, C_in, T, H, W);
            # collapse the temporal dimension to get a 2D patch embedding.
            if len(v.shape) == 5:
                if v.shape[2] == 1:
                    v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
                else:
                    v = v[:, :, 0]  # keep the first temporal slice
        old_k = k
        # Remap checkpoint keys onto the target 'backbone.net.' namespace.
        k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.' + k
        if 'attn' in k and '_bias' in k:
            # Merge the separate q/k/v biases into a single qkv.bias tensor.
            # Some checkpoints omit k_bias (it is fixed at zero during
            # training), so substitute zeros of the matching shape.
            old_attn = '.'.join(old_k.split('.')[:-1])
            attn = '.'.join(k.split('.')[:-1])
            k = attn + '.qkv.bias'
            if k in new_state_dict:
                # Already merged when another *_bias key of this block was seen.
                continue
            v = torch.cat([
                state_dict[old_attn + '.q_bias'],
                state_dict[old_attn + '.k_bias']
                if (old_attn + '.k_bias') in state_dict
                else torch.zeros_like(state_dict[old_attn + '.q_bias']),
                state_dict[old_attn + '.v_bias'],
            ], dim=0)
        print(k, v.shape)
        new_state_dict[k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)
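
# Usage sketch (the script name and checkpoint paths below are illustrative,
# not taken from the original repo):
#   python convert_encoder.py --input /path/to/pretrain.pth
#   # -> saves the converted weights to /path/to/pretrain-encoder.pth
#   python convert_encoder.py --input /path/to/pretrain.pth --output encoder.pth
#   # -> saves them to the explicit path encoder.pth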