# Source: counterfactual-world-models/cwm/eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format_v2.py
"""Convert a CWM (counterfactual world models) checkpoint to Detectron2's
ViT-backbone key format.

Loads the ``model`` state dict from ``--input``, moves every kept key into
the ``backbone.net.`` namespace, fuses the separate q/k/v attention biases
into a single ``qkv.bias`` tensor, collapses the temporal axis of the
patch-embedding kernel, prepends a fixed sinusoidal positional embedding
(with a zero cls-token slot), and saves the result to ``--output``
(default: ``<input>-encoder.pth``).
"""
import torch
import argparse
import sys

sys.path.append('../../../')
from model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='The path to the checkpoint.')
parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
args = parser.parse_args()

state_dict = torch.load(args.input, map_location='cpu')['model']

# MAE-style checkpoint: keys carry no 'encoder.' prefix, so the
# encoder/decoder filter below is bypassed and every key is processed.
mae = True

# Encoder embedding dimension (hard-coded to ViT-Base; the commented line
# shows how to recover it from the checkpoint when the prefixed key exists).
# C = state_dict['encoder.patch_embed.proj.weight'].shape[0]
C = 768

# Build a sinusoidal positional table for a 14x14 patch grid and prepend a
# zero cls-token slot. NOTE(review): assumes the downstream ViT expects
# 14*14 + 1 positions of width C — confirm against the Detectron2 config.
pos_embed = get_sinusoid_encoding_table(14*14, C)
cls_token = torch.zeros(1, 1, C)
pos_embed = torch.cat([cls_token, pos_embed], dim=1)

new_state_dict = {'backbone.net.pos_embed': pos_embed}
for k, v in state_dict.items():
    # Keep encoder weights only (decoder weights are not used downstream).
    if mae or ('encoder' in k and not 'decoder' in k or 'patch_embed' in k):
        if 'patch_embed.proj.weight' in k:
            # A 5-D kernel has a temporal axis; drop it so the weight
            # becomes a standard 2-D conv kernel.
            if len(v.shape) == 5:
                if v.shape[2] == 1:
                    v = v.squeeze(2)  # (768, 3, 1, 16, 16) -> (768, 3, 16, 16)
                else:
                    # Multiple temporal slices: keep only the first frame.
                    v = v[:, :, 0]
        old_k = k
        # Remap into Detectron2's 'backbone.net.' namespace.
        k = k.replace('encoder.', 'backbone.net.') if not mae else 'backbone.net.'+k
        if 'attn' in k and '_bias' in k:
            # The checkpoint stores q/k/v biases separately; the target
            # format uses one fused qkv.bias per attention module. Build it
            # once, then skip the duplicate visits for the other bias keys.
            old_attn = '.'.join(old_k.split('.')[:-1])
            attn = '.'.join(k.split('.')[:-1])
            k = attn + '.qkv.bias'
            if k in new_state_dict:
                continue
            v = torch.cat([
                state_dict[old_attn + '.q_bias'],
                # k_bias may be absent from the checkpoint; substitute zeros
                # shaped like q_bias.
                state_dict[old_attn + '.k_bias'] if (old_attn + '.k_bias') in state_dict else torch.zeros_like(state_dict[old_attn + '.q_bias']),
                state_dict[old_attn + '.v_bias'],
            ], dim=0)
        print(k, v.shape)
        new_state_dict[k] = v

output_path = args.input.replace('.pth', '-encoder.pth') if args.output is None else args.output
torch.save(new_state_dict, output_path)
print('Save model to', output_path)