"""Strip decoder-layer entries from a PyTorch checkpoint file.

Loads a checkpoint with ``torch.load`` onto the CPU, prints its top-level
keys, deletes every key whose name contains ``'decoder_layer'`` (keys of the
form ``lang_encoder.gpt_neox.layers.<x>.decoder_layer...``), and saves the
result.
"""
import contextlib
import os
import tempfile

import torch


def remove_keys_containing(checkpoint, substring="decoder_layer"):
    """Delete, in place, every key of *checkpoint* whose name contains *substring*.

    Args:
        checkpoint: A mapping supporting iteration and ``del`` (the loaded
            checkpoint dict).
        substring: Text to look for in each key. Non-string keys are left
            untouched.

    Returns:
        The list of removed keys, in their original iteration order.
    """
    # Collect the keys first: deleting from a dict while iterating it raises.
    removed = [
        key for key in checkpoint if isinstance(key, str) and substring in key
    ]
    for key in removed:
        del checkpoint[key]
    return removed


def _atomic_torch_save(obj, path):
    """Save *obj* to *path* via a temp file + ``os.replace``.

    If saving fails partway through, the existing file at *path* is not
    truncated or corrupted.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file in the same directory so os.replace stays atomic
    # (it is only atomic within one filesystem).
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def clean_checkpoint(
    src="checkpoint_cleaned.pt", dst=None, substring="decoder_layer"
):
    """Load *src*, drop keys containing *substring*, and save to *dst*.

    Args:
        src: Path of the checkpoint to load.
        dst: Path to write the cleaned checkpoint to. Defaults to *src*,
            i.e. the file is overwritten in place, as the original script did.
        substring: Keys containing this text are removed.

    Returns:
        The list of removed keys.
    """
    if dst is None:
        dst = src
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(src, map_location=torch.device("cpu"))
    print(checkpoint.keys())

    removed = remove_keys_containing(checkpoint, substring)

    _atomic_torch_save(checkpoint, dst)
    return removed


if __name__ == "__main__":
    # NOTE(review): the original script both reads and writes
    # 'checkpoint_cleaned.pt'. Confirm the input was not meant to be the
    # uncleaned checkpoint.
    clean_checkpoint("checkpoint_cleaned.pt")