import os
import pickle

import jax
# Rewrite the checkpoint below in place without its "optim_state_dict" entry.
# NOTE(review): presumably this is done to shrink the checkpoint for
# inference-only use — confirm nothing downstream still needs the optimizer
# state before running.
CKPT_PATH = "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"

# NOTE: pickle.load can execute arbitrary code embedded in the file; only run
# this on checkpoints you produced yourself or otherwise trust.
with open(CKPT_PATH, "rb") as f:
    dic = pickle.load(f)

# Fetch any JAX device arrays in the (nested) structure back to the host so
# the re-pickled file holds plain host arrays rather than device buffers.
dic = jax.device_get(dic)

# pop() with a default instead of `del`: re-running the script on an
# already-stripped checkpoint is a no-op instead of a KeyError.
dic.pop("optim_state_dict", None)

# Write to a sibling temp file first, then atomically swap it into place.
# Opening CKPT_PATH with "wb" directly would truncate the only copy of the
# checkpoint before serialization has succeeded.
tmp_path = CKPT_PATH + ".tmp"
with open(tmp_path, "wb") as f:
    pickle.dump(dic, f)
os.replace(tmp_path, CKPT_PATH)