# chinese_GAU-alpha-char_L-24_H-768 / convert_bert4keras_tf_to_pytorch.py
# Uploaded by junnyu (commit b28258a, 866 bytes).
import os
import tensorflow as tf
import torch
from collections import OrderedDict
# Convert a TensorFlow checkpoint into a PyTorch state dict saved as
# "pytorch_model.bin" in the current working directory.

# Checkpoint prefix to convert; resolved to an absolute path before use.
tf_checkpoint_path = "chinese_GAU-alpha-char_L-24_H-768-tf/bert_model.ckpt"
tf_path = os.path.abspath(tf_checkpoint_path)

# Every variable stored in the checkpoint, as (name, shape) pairs.
init_vars = tf.train.list_variables(tf_path)

# Converted tensors keyed by their renamed identifiers, in listing order.
pytorch_state_dict = OrderedDict()
for name, _shape in init_vars:
    array = tf.train.load_variable(tf_path, name)

    # Rename rules are applied in this exact order:
    #   - both "GAU_alpha" and "bert" become "gau_alpha"
    #   - "/" scope separators become "."
    #   - "layer_<n>" becomes "layer.<n>"
    #   - "kernel" and "gamma" both become "weight"
    new_name = (
        name.replace("GAU_alpha", "gau_alpha")
        .replace("bert", "gau_alpha")
        .replace("/", ".")
        .replace("layer_", "layer.")
        .replace("kernel", "weight")
        .replace("gamma", "weight")
    )

    # Any variable whose name mentions "embeddings" gets a ".weight" suffix.
    # NOTE(review): presumably these variables are bare embedding tables with
    # no trailing parameter name in the checkpoint -- confirm.
    if "embeddings" in new_name:
        new_name = new_name + ".weight"

    # Variables whose name contains "_dense" are transposed.
    # NOTE(review): presumably because the checkpoint stores dense kernels as
    # (in, out) while the target PyTorch layers expect (out, in) -- confirm.
    if "_dense" in new_name:
        array = array.T

    pytorch_state_dict[new_name] = torch.from_numpy(array)

# Written once, after every variable has been converted.
torch.save(pytorch_state_dict, "pytorch_model.bin")