# Hugging Face Space: counterfactual-world-models/cwm (status: Running on CPU Upgrade)
# File: eval/Segmentation/archive/common/convert_cwm_checkpoint_detectron_format.py
"""Extract the encoder weights from a CWM checkpoint for use as a detectron2 backbone.

The input checkpoint is expected to hold its weights under the ``'model'``
entry (configurable via ``--key``). Every weight whose key mentions
``encoder`` (but not ``decoder``) is kept and renamed by prepending
``backbone.net.model.`` (configurable via ``--prefix``). The result is saved
as a plain state dict.
"""
import argparse
import os

import torch


def default_output_path(input_path):
    """Return the default output path for the encoder-only checkpoint.

    ``model.pth`` becomes ``model-encoder.pth``. Only the file name is
    rewritten; directory components are left untouched, even if they contain
    ``.pth``. A file name without a ``.pth`` extension still gets an
    ``-encoder`` tag before its extension. This means the default can never
    equal ``input_path`` and overwrite the source checkpoint.
    """
    directory, filename = os.path.split(input_path)
    if filename.endswith('.pth'):
        stem, ext = filename[:-len('.pth')], '.pth'
    else:
        stem, ext = os.path.splitext(filename)
    return os.path.join(directory, stem + '-encoder' + ext)


def extract_encoder_state_dict(state_dict, prefix='backbone.net.model.'):
    """Select encoder weights and rename them with ``prefix``.

    Keeps every entry whose key contains ``'encoder'`` but not ``'decoder'``.
    The second test excludes keys that mention both; for example, a key like
    ``encoder_to_decoder.weight`` is dropped. ``prefix`` is prepended to each
    kept key. Values are passed through unchanged (not copied).
    """
    return {
        prefix + key: value
        for key, value in state_dict.items()
        if 'encoder' in key and 'decoder' not in key
    }


def main(argv=None):
    """Parse CLI arguments, convert the checkpoint, and save the result.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour).
    Exits with an error when the checkpoint lacks the requested entry or
    when no encoder weights are found. In the second case, writing an
    empty checkpoint would be useless.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=str, required=True, help='The path to the checkpoint.')
    parser.add_argument('--output', type=str, default=None, help='the output path of the checkpoint')
    parser.add_argument('--key', type=str, default='model',
                        help="checkpoint entry holding the state dict (default: 'model')")
    # Presumably the module path of the detectron2 backbone wrapper that
    # loads these weights -- confirm against the model config if changed.
    parser.add_argument('--prefix', type=str, default='backbone.net.model.',
                        help="prefix prepended to every kept key (default: 'backbone.net.model.')")
    args = parser.parse_args(argv)

    # NOTE(security): torch.load unpickles the file unless weights_only=True
    # (the default only in newer PyTorch releases); only convert trusted checkpoints.
    checkpoint = torch.load(args.input, map_location='cpu')
    if not isinstance(checkpoint, dict) or args.key not in checkpoint:
        available = list(checkpoint) if isinstance(checkpoint, dict) else type(checkpoint).__name__
        parser.error(f'checkpoint {args.input!r} has no {args.key!r} entry (found: {available})')

    new_state_dict = extract_encoder_state_dict(checkpoint[args.key], prefix=args.prefix)
    if not new_state_dict:
        raise SystemExit(f'error: no encoder weights found in {args.input!r}; nothing to save')

    output_path = default_output_path(args.input) if args.output is None else args.output
    torch.save(new_state_dict, output_path)
    print('Save model to', output_path)


if __name__ == '__main__':
    main()