# counterfactual-world-models/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py
"""Convert a CWM (counterfactual world models) checkpoint into the
Detectron-format ViT backbone state dict used by the segmentation eval."""
import argparse
import sys

import torch

sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table  # noqa: F401 (only get_sinusoid_encoding_table is used here)
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the input checkpoint.')
parser.add_argument('--output', type=str, default=None,
                    help='The output path of the converted checkpoint.')
args = parser.parse_args()
# The training checkpoint stores the model weights under the 'model' key.
state_dict = torch.load(args.input, map_location='cpu')['model']

# Treat the checkpoint as MAE-style (keys without an 'encoder.' prefix).
mae = True
# Embedding dimension of the ViT-Base encoder.
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768

# Rebuild the position embedding as a fixed sinusoidal table for a 14x14 patch
# grid (e.g. a 224px input with 16px patches), plus an all-zeros cls-token slot.
pos_embed = get_sinusoid_encoding_table(14 * 14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)

new_state_dict = {'backbone.net.pos_embed': pos_embed}
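
# Quick shape check (a sketch; relies on get_sinusoid_encoding_table returning
# a (1, N, C) table, which the torch.cat above already assumes): 196 patch
# tokens plus one cls slot.
assert pos_embed.shape == (1, 14 * 14 + 1, C)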
for k, v in state_dict.items():
    # Keep encoder weights (plus patch_embed) and drop decoder weights; with
    # mae=True every key is kept, since MAE-style keys carry no 'encoder.' prefix.
    if mae or ('encoder' in k and 'decoder' not in k or 'patch_embed' in k):
        if 'patch_embed.proj.weight' in k:
            # Collapse the temporal dimension of a video patch-embed kernel.
            if len(v.shape) == 5:
                if v.shape[2] == 1:
                    v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
                else:
                    v = v[:, :, 0]  # keep only the first temporal slice
        old_k = k
        k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.' + k
        if 'attn' in k and '_bias' in k:
            # The source checkpoint stores separate q/k/v biases; the converted
            # backbone uses a single fused qkv projection, so merge them here.
            old_attn = '.'.join(old_k.split('.')[:-1])
            attn = '.'.join(k.split('.')[:-1])
            k = attn + '.qkv.bias'
            if k in new_state_dict:
                continue  # qkv bias already merged for this attention block
            v = torch.cat([
                state_dict[old_attn + '.q_bias'],
                # k_bias may be absent (commonly fixed at zero during training).
                state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict
                else torch.zeros_like(state_dict[old_attn + '.q_bias']),
                state_dict[old_attn + '.v_bias'],
            ], dim=0)
        print(k, v.shape)
        new_state_dict[k] = v
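
# Minimal post-conversion check (a sketch; holds when mae=True, since every
# kept key is rewritten with the 'backbone.net.' prefix above).
assert all(key.startswith('backbone.net.') for key in new_state_dict)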
# Default output: the input path with a '-encoder' suffix.
output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Saved model to', output_path)
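
# Example usage (paths are illustrative, not taken from the original repo):
#   python convert_cwm_checkpoint_detectron_format_v2.py \
#       --input /path/to/cwm_checkpoint.pth \
#       --output /path/to/cwm_checkpoint-encoder.pth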