"""Strip ``decoder_layer`` entries from a PyTorch checkpoint.

Loads a checkpoint onto the CPU and prints its top-level keys. It then deletes
every top-level key whose name contains ``decoder_layer`` (for example
``lang_encoder.gpt_neox.layers.<x>.decoder_layer...``) and saves the result.
"""
import os
import tempfile

import torch

# NOTE(review): the input and output paths are identical, so the checkpoint is
# cleaned in place and the original contents are not kept. Confirm this is
# intended; the input may have been meant to be the uncleaned checkpoint.
SRC_PATH = "checkpoint_cleaned.pt"
DST_PATH = "checkpoint_cleaned.pt"


def remove_matching_keys(checkpoint, pattern="decoder_layer"):
    """Delete, in place, every top-level key of *checkpoint* containing *pattern*.

    Only string keys are tested; any other key type is left untouched rather
    than raising ``TypeError`` on the substring test.

    Returns:
        The list of removed keys, in the checkpoint's iteration order.
    """
    # Collect first: deleting from a dict while iterating it raises RuntimeError.
    doomed = [key for key in checkpoint if isinstance(key, str) and pattern in key]
    for key in doomed:
        del checkpoint[key]
    return doomed


def save_atomic(obj, path):
    """``torch.save`` *obj* to *path* via a sibling temp file and a rename.

    ``os.replace`` only runs after the save has fully succeeded, so an
    interrupted save never leaves a truncated file at *path*. This matters
    here because *path* may be the very file that was just loaded.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)  # torch.save reopens the file by name
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave the partial temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def clean_checkpoint(src=SRC_PATH, dst=DST_PATH, pattern="decoder_layer"):
    """Load *src* on the CPU, drop keys containing *pattern*, and save to *dst*.

    Prints the checkpoint's top-level keys (before removal) to stdout.

    Returns:
        The list of removed keys.
    """
    checkpoint = torch.load(src, map_location=torch.device("cpu"))
    print(checkpoint.keys())
    removed = remove_matching_keys(checkpoint, pattern)
    save_atomic(checkpoint, dst)
    return removed


if __name__ == "__main__":
    clean_checkpoint()