aapot commited on
Commit
74efc84
1 Parent(s): 61cd9b4

Add configs

Browse files
.gitattributes CHANGED
@@ -26,3 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ checkpoint*/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
base_nl36.gin ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5.1.1 Efficient base nl36 model.
2
+
3
+ import seqio
4
+ include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model.
5
+
6
+ # ------------------- Network specification overrides --------------------------
7
+ network.Transformer.config = @network.T5Config()
8
+ network.T5Config:
9
+ emb_dim = 768
10
+ num_heads = 12
11
+ num_encoder_layers = 36
12
+ num_decoder_layers = 36
13
+ head_dim = 64
14
+ mlp_dim = 3072
15
+
16
+ # ------------------- Model specification overrides --------------------------
17
+ VOCABULARY = @seqio.SentencePieceVocabulary()
18
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
19
+
20
+ MODEL = @models.EncoderDecoderModel()
21
+ models.EncoderDecoderModel:
22
+ input_vocabulary = %VOCABULARY
23
+ output_vocabulary = %VOCABULARY
base_nl36_pretrain.gin ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Register necessary SeqIO Tasks/Mixtures.
2
+ from __gin__ import dynamic_registration
3
+ from t5x import utils
4
+ import tasks
5
+ import __main__ as train_script
6
+
7
+ include 'base_nl36.gin'
8
+ include 't5x/configs/runs/pretrain.gin'
9
+
10
+
11
+ # ------------------- Training specification overrides --------------------------
12
+ train_script.train:
13
+ eval_period = 10000
14
+
15
+ utils.SaveCheckpointConfig:
16
+ period = 10000
17
+ keep = 10
18
+
19
+ MIXTURE_OR_TASK_NAME = "pretrain_finnish"
20
+ USE_CACHED_TASKS = False
21
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
22
+ TRAIN_STEPS = 500000
23
+ DROPOUT_RATE = 0.0
24
+ BATCH_SIZE = 256
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 3072,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "layer_norm_epsilon": 1e-06,
16
+ "model_type": "t5",
17
+ "n_positions": 512,
18
+ "num_decoder_layers": 36,
19
+ "num_heads": 12,
20
+ "num_layers": 36,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_max_distance": 128,
24
+ "relative_attention_num_buckets": 32,
25
+ "tie_word_embeddings": false,
26
+ "transformers_version": "4.17.0",
27
+ "use_cache": true,
28
+ "vocab_size": 32128
29
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration
7
+
8
+
9
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
10
+ config = T5Config.from_pretrained(config_name)
11
+ flax_model = FlaxT5ForConditionalGeneration(config=config)
12
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
13
+
14
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
15
+
16
+ # Encoder
17
+ for layer_index in range(config.num_layers):
18
+ layer_name = f"layers_{str(layer_index)}"
19
+
20
+ # Self-Attention
21
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
22
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
23
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
24
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
25
+
26
+ ## Layer Normalization
27
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
28
+
29
+ if split_mlp_wi:
30
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
31
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
32
+ else:
33
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
34
+
35
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
36
+
37
+ ## Layer Normalization
38
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
39
+
40
+ # Assigning
41
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
42
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
43
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
44
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
45
+
46
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
47
+
48
+ if split_mlp_wi:
49
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
50
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
51
+ else:
52
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
53
+
54
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
55
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
56
+
57
+ # Only for layer 0:
58
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
59
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
60
+
61
+ # Assigning
62
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
63
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
64
+
65
+ # Decoder
66
+ for layer_index in range(config.num_layers):
67
+ layer_name = f"layers_{str(layer_index)}"
68
+
69
+ # Self-Attention
70
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
71
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
72
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
73
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
74
+
75
+ ## Layer Normalization
76
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
77
+
78
+ # Encoder-Decoder-Attention
79
+ t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
80
+ t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
81
+ t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
82
+ t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
83
+
84
+ ## Layer Normalization
85
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
86
+
87
+ # MLP
88
+ if split_mlp_wi:
89
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
90
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
91
+ else:
92
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
93
+
94
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
95
+
96
+ ## Layer Normalization
97
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
98
+
99
+ # Assigning
100
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
101
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
102
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
103
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
104
+
105
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
106
+
107
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
108
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
109
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
110
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
111
+
112
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
113
+
114
+ if split_mlp_wi:
115
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
116
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
117
+ else:
118
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
119
+
120
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
121
+
122
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
123
+
124
+ # Decoder Normalization
125
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
126
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
127
+
128
+ # Only for layer 0:
129
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
130
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
131
+
132
+ # Token Embeddings
133
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
134
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
135
+
136
+ # LM Head
137
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
138
+
139
+ flax_model.save_pretrained(flax_dump_folder_path)
140
+ print("T5X Model was sucessfully converted!")
141
+
142
+
143
+ if __name__ == "__main__":
144
+ parser = argparse.ArgumentParser()
145
+ # Required parameters
146
+ parser.add_argument(
147
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
148
+ )
149
+ parser.add_argument(
150
+ "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
151
+ )
152
+ parser.add_argument(
153
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
154
+ )
155
+ args = parser.parse_args()
156
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
157
+
flax_model_to_pytorch.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM, AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ def to_f32(t):
8
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
9
+
10
+ jax.config.update('jax_platform_name', 'cpu')
11
+ MODEL_PATH = "./"
12
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
13
+ model.params = to_f32(model.params)
14
+ model.save_pretrained(MODEL_PATH)
15
+
16
+ pt_model = AutoModelForSeq2SeqLM.from_pretrained(
17
+ MODEL_PATH, from_flax=True).to('cpu')
18
+
19
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
20
+ input_ids_pt = torch.tensor(input_ids)
21
+
22
+ logits_pt = pt_model(input_ids=input_ids_pt, decoder_input_ids=input_ids_pt).logits
23
+ print(logits_pt)
24
+ logits_fx = model(input_ids=input_ids, decoder_input_ids=input_ids).logits
25
+ print(logits_fx)
26
+
27
+ pt_model.save_pretrained(MODEL_PATH)
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55a3645122435e9773fac81fa3f94c1e14149e80311636dfa9245fba3e57a826
3
+ size 824186
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff
 
start_train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set train hyperparams
2
+ unset LD_PRELOAD
3
+
4
+ PROJECT_DIR="/researchdisk/t5x-base-nl36-finnish"
5
+ T5X_DIR=${HOME}"/t5x" # directory where the t5x is cloned.
6
+ MODEL_DIR="/researchdisk/t5x-base-nl36-finnish"
7
+ export PYTHONPATH=${PROJECT_DIR}
8
+
9
+ python3 ${T5X_DIR}/t5x/train.py \
10
+ --gin_search_paths=${PROJECT_DIR} \
11
+ --gin_file="base_nl36_pretrain.gin" \
12
+ --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapted from https://huggingface.co/pere/pk-nb-t5x/blob/main/tasks.py
2
+
3
+ import functools
4
+
5
+ import seqio
6
+ import tensorflow as tf
7
+ import t5.data
8
+ from datasets import load_dataset, load_from_disk
9
+ from t5.data import postprocessors
10
+ from t5.data import preprocessors
11
+ from t5.evaluation import metrics
12
+ from seqio import FunctionDataSource, utils
13
+
14
+ TaskRegistry = seqio.TaskRegistry
15
+
16
+ vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
17
+
18
+ DEFAULT_OUTPUT_FEATURES = {
19
+ "inputs": seqio.Feature(
20
+ vocabulary=vocabulary, add_eos=True,
21
+ required=False),
22
+ "targets": seqio.Feature(
23
+ vocabulary=vocabulary, add_eos=True)
24
+ }
25
+
26
+
27
+ def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
28
+ if shuffle:
29
+ if seed:
30
+ dataset = dataset.shuffle(seed=seed)
31
+ else:
32
+ dataset = dataset.shuffle()
33
+ while True:
34
+ for item in dataset[str(split)]:
35
+ yield item[column]
36
+
37
+
38
+ def dataset_fn(split, shuffle_files, seed=None, dataset=None):
39
+ return tf.data.Dataset.from_generator(
40
+ functools.partial(gen_dataset, split, shuffle_files, seed, dataset=dataset),
41
+ output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
42
+ )
43
+
44
+
45
+ @utils.map_over_dataset
46
+ def target_to_key(x, key_map, target_key):
47
+ """Assign the value from the dataset to target_key in key_map"""
48
+ return {**key_map, target_key: x}
49
+
50
+
51
+ # Final pretraining task used in Raffel et al., 2019 adaptated to NCC
52
+ dataset_name = "/researchdisk/lm_training_dataset_full"
53
+ dataset_params = {"from_disk_path": dataset_name}
54
+
55
+ if "from_disk_path" in dataset_params:
56
+ dataset = load_from_disk(dataset_params.get("from_disk_path"))
57
+ else:
58
+ dataset = load_dataset(**dataset_params)
59
+
60
+ dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
61
+ TaskRegistry.add(
62
+ "pretrain_finnish",
63
+ source=seqio.FunctionDataSource(
64
+ dataset_fn=functools.partial(dataset_fn, dataset=dataset),
65
+ splits=("train", "validation"),
66
+ caching_permitted=False,
67
+ num_input_examples=dataset_shapes,
68
+ ),
69
+ preprocessors=[
70
+ functools.partial(
71
+ target_to_key, key_map={
72
+ "inputs": None,
73
+ "targets": None,
74
+ }, target_key="targets"),
75
+ seqio.preprocessors.tokenize,
76
+ # seqio.CacheDatasetPlaceholder(),
77
+ preprocessors.span_corruption,
78
+ seqio.preprocessors.append_eos_after_trim,
79
+ ],
80
+ output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
81
+ metric_fns=[metrics.accuracy]
82
+ )