"""Convert a GAU-alpha TensorFlow checkpoint into a PyTorch state dict.

Running this script reads every variable stored in the checkpoint at
``tf_checkpoint_path``, renames it to a dotted PyTorch-style key, and writes
the resulting ordered mapping to ``pytorch_model.bin`` in the current working
directory.
"""
import os
from collections import OrderedDict

import tensorflow as tf
import torch


def _to_pytorch_name(name):
    """Return the state-dict key used for the TensorFlow variable *name*.

    The mapping is purely textual: the model prefix is normalised to
    ``gau_alpha``, ``/`` separators become ``.``, ``layer_<n>`` becomes
    ``layer.<n>``, and ``kernel`` / ``gamma`` are both renamed to ``weight``.
    Any name containing ``embeddings`` additionally gets a ``.weight`` suffix.
    """
    new_name = (
        name.replace("GAU_alpha", "gau_alpha")
        .replace("bert", "gau_alpha")
        .replace("/", ".")
        .replace("layer_", "layer.")
        .replace("kernel", "weight")
        .replace("gamma", "weight")
    )
    if "embeddings" in new_name:
        new_name = new_name + ".weight"
    return new_name


tf_checkpoint_path = "chinese_GAU-alpha-char_L-24_H-768-tf/bert_model.ckpt"
tf_path = os.path.abspath(tf_checkpoint_path)

# list_variables yields (name, shape) pairs; only the name is needed because
# load_variable returns the full array for that name.
init_vars = tf.train.list_variables(tf_path)

pytorch_state_dict = OrderedDict()
for name, _shape in init_vars:
    array = tf.train.load_variable(tf_path, name)
    new_name = _to_pytorch_name(name)
    if "_dense" in new_name:
        # NOTE(review): presumably the TF dense kernels are stored transposed
        # relative to the layout the PyTorch model expects -- confirm against
        # the target model definition.
        array = array.T
    pytorch_state_dict[new_name] = torch.from_numpy(array)

torch.save(pytorch_state_dict, "pytorch_model.bin")