File size: 380 Bytes
4c4f051 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import argparse
import os

import torch
def prune_checkpoint(checkpoint, drop_keys=("optimizer_states",)):
    """Return a shallow copy of *checkpoint* without the keys in *drop_keys*.

    The input mapping is not modified. The remaining entries keep their
    original order, and their values are shared with the input, not copied.

    Args:
        checkpoint: Mapping loaded from a checkpoint file.
        drop_keys: Iterable of keys to remove. It must be a collection of
            keys, not a bare string.

    Returns:
        A new ``dict`` containing every entry whose key is not in *drop_keys*.
    """
    # frozenset gives O(1) membership tests. It also avoids the substring
    # semantics that `in` would have on a plain string.
    dropped = frozenset(drop_keys)
    return {key: value for key, value in checkpoint.items() if key not in dropped}


def pruned_path(path):
    """Return the default output path: ``pruned-<file name>`` next to *path*.

    Only the file name gets the prefix, so ``models/a.ckpt`` maps to
    ``models/pruned-a.ckpt`` rather than ``pruned-models/a.ckpt``. A bare
    file name such as ``a.ckpt`` still maps to ``pruned-a.ckpt``.
    """
    head, tail = os.path.split(path)
    return os.path.join(head, f"pruned-{tail}")


def main(argv=None):
    """Load a checkpoint, drop its optimizer state, and save the result.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]`` (argparse
            behavior) when ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Remove 'optimizer_states' from a torch checkpoint."
    )
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required = True)
    parser.add_argument(
        '--output', '-O', type=str, default=None,
        help='Output file (default: pruned-<input name> in the same directory)',
    )
    args = parser.parse_args(argv)

    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only run this on checkpoints from a trusted source.
    # map_location="cpu" lets GPU-saved checkpoints be pruned on CPU-only hosts.
    checkpoint = torch.load(args.input, map_location="cpu")
    new_sd = prune_checkpoint(checkpoint)

    output = args.output if args.output is not None else pruned_path(args.input)
    torch.save(new_sd, output)


if __name__ == "__main__":
    main()