junnyu commited on
Commit
b28258a
•
1 Parent(s): 7800100

Upload convert_bert4keras_tf_to_pytorch.py

Browse files
Files changed (1) hide show
  1. convert_bert4keras_tf_to_pytorch.py +28 -0
convert_bert4keras_tf_to_pytorch.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+ import torch
4
+ from collections import OrderedDict
5
+
6
+ tf_checkpoint_path = "chinese_GAU-alpha-char_L-24_H-768-tf/bert_model.ckpt"
7
+ tf_path = os.path.abspath(tf_checkpoint_path)
8
+ init_vars = tf.train.list_variables(tf_path)
9
+ arrays = []
10
+
11
+ pytorch_state_dict = OrderedDict()
12
+ for name, shape in init_vars:
13
+ array = tf.train.load_variable(tf_path, name)
14
+ new_name = (
15
+ name.replace("GAU_alpha", "gau_alpha")
16
+ .replace("bert", "gau_alpha")
17
+ .replace("/", ".")
18
+ .replace("layer_", "layer.")
19
+ .replace("kernel", "weight")
20
+ .replace("gamma", "weight")
21
+ )
22
+ if "embeddings" in new_name:
23
+ new_name = new_name + ".weight"
24
+ if "_dense" in new_name:
25
+ array = array.T
26
+ pytorch_state_dict[new_name] = torch.from_numpy(array)
27
+
28
+ torch.save(pytorch_state_dict, "pytorch_model.bin")