English
anas-awadalla's picture
removed lm weights from checkpoint
28a1d28
raw
history blame
377 Bytes
import torch
def remove_keys_containing(checkpoint, substring="decoder_layer"):
    """Delete, in place, every entry of *checkpoint* whose key contains *substring*.

    Only top-level keys are inspected. Non-string keys are skipped rather than
    raising ``TypeError``. The mapping is mutated instead of rebuilt, so its
    concrete type (e.g. an ``OrderedDict``) is preserved.

    Args:
        checkpoint: Mutable mapping loaded from a checkpoint file.
        substring: Text marking a key for removal. The default targets keys
            of the form ``lang_encoder.gpt_neox.layers.<i>.decoder_layer...``.

    Returns:
        The list of removed keys, in their original iteration order.
    """
    # Snapshot the matching keys first: deleting from a dict while iterating
    # over it raises RuntimeError.
    doomed = [key for key in checkpoint if isinstance(key, str) and substring in key]
    for key in doomed:
        del checkpoint[key]
    return doomed


def main(
    src="checkpoint_cleaned.pt",
    dst="checkpoint_cleaned.pt",
    substring="decoder_layer",
):
    """Load *src*, strip keys containing *substring*, and save the result to *dst*.

    The defaults reproduce the original script, which rewrites
    ``checkpoint_cleaned.pt`` in place. The output is first written to a
    sibling ``.tmp`` file and then moved over *dst* with ``os.replace`` (atomic
    on POSIX). An interrupted save therefore cannot truncate the only copy.

    NOTE: ``torch.load`` unpickles the file; only run this on trusted checkpoints.
    """
    import os

    # map_location forces every tensor onto the CPU, so no GPU is required.
    checkpoint = torch.load(src, map_location=torch.device("cpu"))
    print(checkpoint.keys())
    removed = remove_keys_containing(checkpoint, substring)
    print(f"removed {len(removed)} key(s) containing {substring!r}")
    tmp_path = f"{dst}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, dst)


if __name__ == "__main__":
    main()