"""Strip the optimizer state from a pickled training checkpoint, in place.

The checkpoint is loaded, its ``"optim_state_dict"`` entry is dropped, any JAX
device arrays are transferred to host (NumPy) arrays with
:func:`jax.device_get`, and the result is written back over the original file.

Usage::

    python strip_ckpt.py [CHECKPOINT_PATH]
"""
import contextlib
import os
import pickle
import shutil
import sys
import tempfile

import jax

CHECKPOINT_PATH = "./mono_tts_cbhg_small_0700000.ckpt"


def strip_optimizer_state(path=CHECKPOINT_PATH):
    """Remove ``"optim_state_dict"`` from the pickled checkpoint at *path*.

    Safe to run repeatedly: on an already-stripped checkpoint the remaining
    contents are rewritten unchanged instead of raising ``KeyError``.

    The new contents are written to a temporary file in the same directory and
    then moved over *path* with :func:`os.replace`, so a failure while writing
    leaves the original checkpoint intact.

    Args:
        path: Checkpoint file to rewrite. Its unpickled object is assumed to be
            a dict-like mapping (the original script used ``del obj[key]``).

    Returns:
        The stripped checkpoint object that was written back.

    Note:
        ``pickle.load`` can execute arbitrary code; only run this on
        checkpoints you produced or otherwise trust.
    """
    with open(path, "rb") as f:
        ckpt = pickle.load(f)

    # pop(..., None) rather than del: a second run is a no-op, not a KeyError.
    ckpt.pop("optim_state_dict", None)
    # Transfer any JAX device arrays to host (NumPy) arrays before pickling.
    ckpt = jax.device_get(ckpt)

    # Write to a sibling temp file, then atomically replace the original.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(ckpt, f)
        # mkstemp creates the file as 0600; keep the original's permissions.
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
        raise
    return ckpt


if __name__ == "__main__":
    strip_optimizer_state(sys.argv[1] if len(sys.argv) > 1 else CHECKPOINT_PATH)