aliabd committed
Commit · c6e7238
1 Parent(s): c528e7b
full working demo
Browse files
- .idea/gpt-neo.iml +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- CODEOWNERS +1 -0
- Dockerfile +15 -0
- GPTNeo_example_notebook.ipynb +0 -0
- LICENSE +21 -0
- app.py +12 -0
- configs.py +47 -0
- configs/dataset_configs/example.json +8 -0
- configs/dataset_configs/openwebtext2_new_inputs.json +9 -0
- configs/dataset_configs/pile.json +9 -0
- configs/gpt2_small.json +36 -0
- configs/gpt3_13B_256.json +40 -0
- configs/gpt3_13B_256_Pile.json +38 -0
- configs/gpt3_2-7B_256.json +38 -0
- configs/gpt3_6-7B_256.json +36 -0
- configs/gpt3_PAR_small_256.json +36 -0
- configs/gpt3_XL_256_Pile.json +37 -0
- configs/gpt3_large_256.json +39 -0
- configs/gpt3_medium_256.json +36 -0
- configs/gpt3_small_256.json +36 -0
- data/create_tfrecords.py +263 -0
- data/encoders.py +28 -0
- data/train_tokenizer.py +73 -0
- docker-compose.yml +67 -0
- encoders.py +28 -0
- export.py +14 -0
- gradio/demo.py +12 -0
- inputs.py +384 -0
- main.py +257 -0
- model_fns.py +305 -0
- models/activations.py +95 -0
- models/gpt2/gpt2.py +217 -0
- models/layers.py +357 -0
- models/utils.py +124 -0
- optimizers.py +176 -0
- requirements.txt +18 -0
- run_experiment.py +265 -0
- sample.py +218 -0
- tasks.py +116 -0
- test_models.py +180 -0
- utils.py +291 -0
.idea/gpt-neo.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/gpt-neo.iml" filepath="$PROJECT_DIR$/.idea/gpt-neo.iml" />
+    </modules>
+  </component>
+</project>
CODEOWNERS
ADDED
@@ -0,0 +1 @@
+* EleutherAI/pm-gptneo
Dockerfile
ADDED
@@ -0,0 +1,15 @@
+FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15
+
+WORKDIR /neogpt
+
+# Make RUN commands use `bash --login`:
+SHELL ["/bin/bash", "--login", "-c"]
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update -y && apt-get install tmux -y
+RUN conda install gcc_linux-64 gxx_linux-64 -y
+ADD requirements.txt .
+RUN pip install -r requirements.txt
+RUN apt-get install screen htop -y
+RUN python -m pip install tensorboard==1.15 cloud_tpu_profiler==1.15
+
+CMD tmux
GPTNeo_example_notebook.ipynb
ADDED
The diff for this file is too large to render.
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 EleutherAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py
ADDED
@@ -0,0 +1,12 @@
+import gradio as gr
+
+title = "GPT-Neo Demo"
+description = "demo for GPT-Neo by EleutherAI for text generation. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='http://github.com/eleutherai/gpt-neo'>GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow</a></p>"
+examples = [
+    ['The tower is 324 metres (1,063 ft) tall,'],
+    ["The Moon's orbit around Earth has"],
+    ["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
+]
+
+gr.Interface.load("huggingface/EleutherAI/gpt-neo-2.7B", inputs=gr.inputs.Textbox(lines=5, label="Input Text"), title=title, description=description, article=article, examples=examples).launch()
configs.py
ADDED
@@ -0,0 +1,47 @@
+import json
+from pathlib import Path
+from collections import defaultdict
+
+DATASETS = {}
+
+for path in Path("configs/dataset_configs").glob("*.json"):
+    dataset_id = path.stem
+    DATASETS[dataset_id] = json.loads(path.read_text())
+
+
+def fetch_model_params(model):
+    model_path = model if model.endswith(".json") else f"configs/{model}.json"
+    with open(model_path) as f:
+        params = json.load(f)
+
+    dataset_ids = []
+    for d in params.get("datasets"):
+        if isinstance(d, list):
+            dataset_ids.append(d[0])
+        else:
+            dataset_ids.append(d)
+    no_datasets = params.get("no_dataset", False)
+    assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"
+
+    datasets = {}
+    last_dataset = None
+    for dataset_id in dataset_ids:
+        assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
+        dataset = DATASETS[dataset_id]
+        assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
+        datasets[dataset_id] = dataset
+        last_dataset = dataset
+
+    if last_dataset is not None:
+        params["padding_id"] = last_dataset.get("padding_id", 0)
+        params["eos_id"] = last_dataset.get("eos_id", 1)
+
+    params["dataset_configs"] = datasets
+
+    # Set some other parameter defaults
+    params["mlm_training"] = params.get("mlm_training") == True
+    params["causal"] = not params["mlm_training"]
+
+    # Set all other parameter values to default to None
+    params = defaultdict(lambda: None, params)
+    return params
configs/dataset_configs/example.json
ADDED
@@ -0,0 +1,8 @@
+{
+    "n_vocab": 32768,
+    "path": "./tfrecords/openwebtext_*.tfrecords",
+    "eval_path": "",
+    "tokenizer_path": "./datasets/openwebtext/byte-level-bpe.tokenizer.json",
+    "eos_id": 1,
+    "padding_id": 0
+}
configs/dataset_configs/openwebtext2_new_inputs.json
ADDED
@@ -0,0 +1,9 @@
+{
+    "n_vocab": 50257,
+    "path": "gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords",
+    "eval_path": "gs://neo-datasets/openwebtext2_new_inputs/eval/*.tfrecords",
+    "tokenizer_is_pretrained": true,
+    "tokenizer_path": "gpt2",
+    "eos_id": 50256,
+    "padding_id": 50257
+}
configs/dataset_configs/pile.json
ADDED
@@ -0,0 +1,9 @@
+{
+    "n_vocab": 50257,
+    "path": "gs://neo-datasets/pile/pile_*.tfrecords",
+    "eval_path": "gs://neo-datasets/pile_val.tfrecords",
+    "tokenizer_is_pretrained": true,
+    "tokenizer_path": "gpt2",
+    "eos_id": 50256,
+    "padding_id": 50257
+}
configs/gpt2_small.json
ADDED
@@ -0,0 +1,36 @@
+{
+    "n_head": 6,
+    "n_vocab": 50257,
+    "embed_dropout": 0.1,
+    "lr": 0.0006,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0,
+    "train_batch_size": 512,
+    "attn_dropout": 0.1,
+    "train_steps": 1000000,
+    "lr_decay_end": 300000,
+    "eval_steps": 30,
+    "predict_steps": 0,
+    "res_dropout": 0.1,
+    "eval_batch_size": 128,
+    "predict_batch_size": 8,
+    "iterations": 2500,
+    "n_embd": 768,
+    "datasets": ["openwebtext2_new_inputs"],
+    "model_path": "gs://neo-models/GPT2_SMALL",
+    "n_ctx": 1024,
+    "n_layer": 12,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 12]],
+    "activation_function": "gelu",
+    "mesh_shape": "all:64",
+    "layout": "batch:all",
+    "recompute_grad": false,
+    "gradient_clipping": 1.0
+}
configs/gpt3_13B_256.json
ADDED
@@ -0,0 +1,40 @@
+{
+    "n_head": 40,
+    "n_vocab": 50257,
+    "embed_dropout": 0,
+    "lr": 0.0001,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "ada_epsilon1": 1e-30,
+    "ada_epsilon2": 1e-3,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 1024,
+    "attn_dropout": 0,
+    "train_steps": 143075,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 128,
+    "predict_batch_size": 1,
+    "iterations": 500,
+    "n_embd": 5120,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_13B",
+    "n_ctx": 2048,
+    "n_layer": 40,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global", "local"], 20]],
+    "mesh_shape": "x:16,y:16",
+    "layout": "batch:x,embd:y,memory_length:y",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0,
+    "tokens_per_mb_per_replica": 2048,
+    "precision": "bfloat16"
+}
+
configs/gpt3_13B_256_Pile.json
ADDED
@@ -0,0 +1,38 @@
+
+{
+    "n_head": 40,
+    "n_vocab": 50257,
+    "embed_dropout": 0,
+    "lr": 0.0001,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.1,
+    "train_batch_size": 1024,
+    "attn_dropout": 0,
+    "train_steps": 286150,
+    "eval_steps": 10,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 512,
+    "predict_batch_size": 1,
+    "iterations": 500,
+    "n_embd": 5120,
+    "datasets": [["pile", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_13B_Pile",
+    "n_ctx": 2048,
+    "n_layer": 40,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 40]],
+    "mesh_shape": "x:16,y:16",
+    "layout": "batch:x,memory_length:y,embd:y",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0,
+    "tokens_per_mb_per_replica": 2048,
+    "precision": "bfloat16"
+}
configs/gpt3_2-7B_256.json
ADDED
@@ -0,0 +1,38 @@
+{
+    "n_head": 32,
+    "n_vocab": 50257,
+    "embed_dropout": 0,
+    "lr": 0.00016,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "ada_epsilon1": 1e-30,
+    "ada_epsilon2": 1e-3,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 512,
+    "attn_dropout": 0,
+    "train_steps": 286150,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 128,
+    "predict_batch_size": 1,
+    "iterations": 500,
+    "n_embd": 2560,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_2-7B",
+    "n_ctx": 2048,
+    "n_layer": 32,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 32]],
+    "mesh_shape": "x:128,y:2",
+    "layout": "embd:y,batch:x",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0
+}
+
configs/gpt3_6-7B_256.json
ADDED
@@ -0,0 +1,36 @@
+{
+    "n_head": 32,
+    "n_vocab": 50257,
+    "embed_dropout": 0,
+    "lr": 0.00012,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 1024,
+    "attn_dropout": 0,
+    "train_steps": 143075,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 128,
+    "predict_batch_size": 1,
+    "iterations": 500,
+    "n_embd": 4096,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_6-7B",
+    "n_ctx": 2048,
+    "n_layer": 32,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 32]],
+    "mesh_shape": "x:128,y:2",
+    "layout": "embd:y,batch:x",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0
+}
+
configs/gpt3_PAR_small_256.json
ADDED
@@ -0,0 +1,36 @@
+{
+    "n_head": 12,
+    "n_vocab": 50304,
+    "embed_dropout": 0,
+    "lr": 0.0006,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 256,
+    "attn_dropout": 0,
+    "train_steps": 572300,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 64,
+    "predict_batch_size": 1,
+    "iterations": 1000,
+    "n_embd": 768,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_PAR_SMALL",
+    "n_ctx": 2048,
+    "n_layer": 19,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global", "none", "none"], 5], [["none"], 4]],
+    "mesh_shape": "x:64,y:4",
+    "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
+    "activation_function": "gelu",
+    "recompute_grad": false,
+    "gradient_clipping": 1.0
+}
+
configs/gpt3_XL_256_Pile.json
ADDED
@@ -0,0 +1,37 @@
+{
+    "n_head": 32,
+    "n_vocab": 50257,
+    "embed_dropout": 0,
+    "lr": 0.0002,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.1,
+    "train_batch_size": 512,
+    "attn_dropout": 0,
+    "train_steps": 286150,
+    "eval_steps": 10,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 512,
+    "predict_batch_size": 1,
+    "iterations": 500,
+    "n_embd": 2048,
+    "datasets": [["pile", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_XL_Pile",
+    "n_ctx": 2048,
+    "n_layer": 24,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 24]],
+    "mesh_shape": "x:128,y:2",
+    "layout": "batch:x,memory_length:y,embd:y",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0,
+    "tokens_per_mb_per_replica": 2048,
+    "precision": "bfloat16"
+}
configs/gpt3_large_256.json
ADDED
@@ -0,0 +1,39 @@
+{
+    "n_head": 16,
+    "n_vocab": 50304,
+    "embed_dropout": 0,
+    "lr": 0.00025,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "ada_epsilon1": 1e-30,
+    "ada_epsilon2": 1e-3,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 256,
+    "attn_dropout": 0,
+    "train_steps": 572300,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 64,
+    "predict_batch_size": 1,
+    "iterations": 2500,
+    "n_embd": 1536,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_LARGE",
+    "n_ctx": 2048,
+    "n_layer": 24,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 24]],
+    "mesh_shape": "x:64,y:4",
+    "layout": "batch:x,vocab:y,heads:y",
+    "activation_function": "gelu",
+    "recompute_grad": true,
+    "gradient_clipping": 1.0,
+    "tokens_per_mb_per_replica": 2048
+}
+
configs/gpt3_medium_256.json
ADDED
@@ -0,0 +1,36 @@
+{
+    "n_head": 16,
+    "n_vocab": 50304,
+    "embed_dropout": 0,
+    "lr": 0.0003,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 256,
+    "attn_dropout": 0,
+    "train_steps": 572300,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 64,
+    "predict_batch_size": 1,
+    "iterations": 2500,
+    "n_embd": 1024,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_MEDIUM",
+    "n_ctx": 2048,
+    "n_layer": 24,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 24]],
+    "mesh_shape": "x:64,y:4",
+    "layout": "batch:x,heads:y,vocab:y",
+    "activation_function": "gelu",
+    "recompute_grad": false,
+    "gradient_clipping": 1.0
+}
+
configs/gpt3_small_256.json
ADDED
@@ -0,0 +1,36 @@
+{
+    "n_head": 12,
+    "n_vocab": 50304,
+    "embed_dropout": 0,
+    "lr": 0.0006,
+    "lr_decay": "cosine",
+    "warmup_steps": 3000,
+    "beta1": 0.9,
+    "beta2": 0.95,
+    "epsilon": 1e-8,
+    "opt_name": "adam",
+    "weight_decay": 0.10,
+    "train_batch_size": 256,
+    "attn_dropout": 0,
+    "train_steps": 572300,
+    "eval_steps": 0,
+    "predict_steps": 1,
+    "res_dropout": 0,
+    "eval_batch_size": 64,
+    "predict_batch_size": 1,
+    "iterations": 2500,
+    "n_embd": 768,
+    "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]],
+    "model_path": "gs://neo-models/GPT3_SMALL",
+    "n_ctx": 2048,
+    "n_layer": 12,
+    "scale_by_depth": true,
+    "scale_by_in": false,
+    "attention_types": [[["global"], 12]],
+    "mesh_shape": "x:64,y:4",
+    "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y",
+    "activation_function": "gelu",
+    "recompute_grad": false,
+    "gradient_clipping": 1.0
+}
+
data/create_tfrecords.py
ADDED
@@ -0,0 +1,263 @@
+import argparse
+import os
+from pathlib import Path
+
+import ftfy
+import tensorflow as tf
+from lm_dataformat import Reader
+from tokenizers import Tokenizer
+from transformers import GPT2TokenizerFast
+from tqdm import tqdm
+import logging
+from multiprocessing import Pool, cpu_count
+from itertools import repeat
+import re
+
+logging.getLogger("transformers").setLevel(logging.ERROR)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are "
+                                                  "treated as archives, all others as raw text.")
+parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord")
+parser.add_argument("--name", type=str, default="openwebtext",
+                    help="Name of output files will be name_i.tfrecords where i is the number of the file")
+parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords")
+parser.add_argument("--encoder_path", type=str,
+                    help="Path to encoder files, or leave unspecified to use GPT2 tokenizer")
+parser.add_argument("--minimum_size", type=int, default=100, help="Minimum size a document has to be to be included")
+parser.add_argument("--ftfy", action="store_false", help="normalize with ftfy")
+parser.add_argument("--wikitext-detokenize", action="store_false", help="use wikitext detokenizer")
+parser.add_argument("--separator", nargs="+", type=int, default=[50256],
+                    help="separator to place between files in chunk mode")
+parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. "
+                                                                 "Should equal your model's context size")
+parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion")
+parser.add_argument("--processes", type=int, default=0, help="Number of processes to use. Defaults to cpu count.")
+
+args = parser.parse_args()
+if not args.output_dir.endswith("/"):
+    args.output_dir = args.output_dir + "/"
+if not args.input_dir.endswith("/"):
+    args.input_dir = args.input_dir + "/"
+assert len(args.separator) == 1
+
+
+def wikitext_detokenizer(string):
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+
+    return string
+
+
+def _int64_feature(value):
+    """
+    Returns an int64_list from a bool / enum / int / uint.
+    """
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
+
+
+def write_to_file(writer, data):
+    """
+    writes data to tfrecord file
+    """
+    feature = {
+        "text": _int64_feature(data)
+    }
+    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
+    writer.write(tf_example.SerializeToString())
+
+
+def get_tokenizer(args):
+    if args.encoder_path is None:
+        return GPT2TokenizerFast.from_pretrained('gpt2')
+    else:
+        return Tokenizer.from_file(args.encoder_path)
+
+
+def split_list(l, n):
+    # splits list/string into n size chunks
+    return [l[i:i + n] for i in range(0, len(l), n)]
+
+
+def archive_to_tokens(f, encoder, args):
+    # Generator that yields the contents of the files in an archive
+    # if data_to_prepend is not None, prepend data_to_prepend + a EOS separator to the encoded data
+    reader = Reader(f)
+    for doc in reader.stream_data(threaded=False):
+        if args.ftfy:  # fix text with ftfy if specified
+            doc = ftfy.fix_text(doc, normalization='NFKC')
+        if args.wikitext_detokenize:
+            doc = wikitext_detokenizer(doc)
+        doc = encoder.encode(doc) + args.separator  # read document from lmd and append separator token
+        yield split_list(doc, args.chunk_size)  # split into n_ctx + 1 size chunks
+
+
+def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
+    # writes a list of files to .tfrecords
+    if files == None:
+        return
+    chunks = split_list(files, files_per)
+
+    if len(chunks[-1]) != files_per and not write_remainder:  # pop the last file if it's length != files per
+        remainder = chunks.pop(-1)
+    else:
+        remainder = None  # assuming files = remainder from an old chunk here
+        files_per = len(chunks[-1])
+
+    for files in chunks:
+        fp = f"{output_dir}/{out_name}_{start_no}"
+        if process_no is not None:
+            fp += f"_{process_no}"
+        fp += f"_{files_per}"  # add number of files in tfrecord to end of fp
+        fp += ".tfrecords"
+        with tf.io.TFRecordWriter(fp) as writer:
+            for f in files:
+                write_to_file(writer, f)
+        start_no += 1
+    return start_no, remainder
+
+
+def get_files(input_dir, filetypes=None):
+    # gets all files of <filetypes> in input_dir
+    if filetypes == None:
+        filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
+    files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
+    return [str(item) for sublist in files for item in sublist]  # flatten list of list -> list and stringify Paths
+
+
+def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
+    # init checkpointing
+    if resume_from_checkpoint and os.path.isfile(checkpoint_path):
+        try:
+            resume_files_processed, tfrecord_count = [int(i) for i in open(checkpoint_path, "r").read().split(", ")]
+            print(f"\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}")
+            return resume_files_processed, tfrecord_count
+        except:
+            pass
+    return 0, 0
+
+
+def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,
+                     resume_from_checkpoint=False, display_pbar=False):
+    # iterates through files in input_dir, splitting into <args.chunk_size> chunks and saving a tfrecords file every <args.files_per> chunks.
+    files, args, process_no = params
+    enc = get_tokenizer(args)  # get tokenizer
+
+    # init metadata
+    discarded_files = 0
+    files_processed = 0
+    pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ",
+                disable=not display_pbar)
+    checkpoint_path = f"{args.output_dir}/checkpoint.txt"
+    resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)
+
+    data_to_prepend = []
+    tokenized_files_array = []
+
+    for f in files:
+        for tokenized_files in archive_to_tokens(f, enc, args):
+            files_processed += 1
+            if files_processed < resume_files_processed:
+                continue  # resume from checkpoint
+
+            # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
+            n_tokens = len(tokenized_files[-1])
+            if n_tokens < args.chunk_size:
+                data = tokenized_files.pop(-1)
+                if n_tokens >= args.minimum_size:
+                    data_to_prepend.extend(data)
+                else:
+                    discarded_files += 1
+
+            if len(data_to_prepend) >= args.chunk_size:
+                # if length of data_to_prepend becomes greater than chunk size, add concatted files to tokenized files
+                tokenized_files_array.append(data_to_prepend[:args.chunk_size])
+                data_to_prepend = data_to_prepend[args.chunk_size:]
+            # add tokenized files > chunk size to main array
+            tokenized_files_array.extend(tokenized_files)
+
+            if len(tokenized_files_array) >= args.files_per * write_every_n_files:  # write every n files
+                _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
+                                                         output_dir=args.output_dir, out_name=args.name,
+                                                         start_no=tfrecord_count, process_no=process_no)
+                pbar.update(_tfrecord_count - tfrecord_count)  # update progress bar
+                pbar.set_description(
+                    f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
+                tfrecord_count = _tfrecord_count
+                tokenized_files_array = remainder if remainder is not None else []  # add remaining files to next chunk
+                with open(checkpoint_path, "w") as checkpoint_file:
+                    checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
+
+    if len(tokenized_files_array) >= args.files_per:  # also write at end
+        _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
+                                                 output_dir=args.output_dir, out_name=args.name,
+                                                 start_no=tfrecord_count, process_no=process_no)
+        pbar.update(_tfrecord_count - tfrecord_count)
+        pbar.set_description(
+            f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
+        tfrecord_count = _tfrecord_count
+        with open(checkpoint_path, "w") as checkpoint_file:
+            checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
+    else:
+        remainder = tokenized_files_array  # add remaining to remainder
+
+    if write_remainder:
+        # write out the remaining files even if there's less than files_per
+        write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,
+                    start_no=tfrecord_count, write_remainder=True)
+
+    successful_files = files_processed - discarded_files
+    return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files}
+
+
+def create_tfrecords_mp(files, args):
+    files = split_list(files, len(files) // args.processes)
+    with Pool(processes=args.processes) as pool:
+        pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))))
+        meta = {"discarded": 0, "processed": 0, "successful": 0}
+        for results in pbar:
+            pbar.update()
+            for k, v in results.items():
+                meta[k] += v  # update metadata
+        return meta
+
+
+if __name__ == "__main__":
+    os.makedirs(args.output_dir, exist_ok=True)  # make output dir if it doesn't exist
+    files = get_files(args.input_dir)
+    args.chunk_size += 1  # we shift the data by 1 to the right for targets, so increment the chunk size here
+
+    if args.processes == 0:
+        args.processes = cpu_count()
+    if args.processes > 1:
+        results = create_tfrecords_mp(files, args)
+    else:
+        results = create_tfrecords((files, args, 0), display_pbar=True)
+    print(results)
data/encoders.py
ADDED
@@ -0,0 +1,28 @@
+from tokenizers import Tokenizer
+from transformers import GPT2Tokenizer, GPT2TokenizerFast
+
+def fetch_encoder(params):
+    no_dataset = params.get('no_dataset', False)
+    if no_dataset:
+        return None
+
+    dataset = next(iter(params['dataset_configs'].values()))  # Get the first value from the dict
+    path = dataset["tokenizer_path"]
+    is_pretrained = dataset.get("tokenizer_is_pretrained", False)
+
+    if is_pretrained:
+        tok = GPT2TokenizerFast.from_pretrained(path)
+
+        # Will add a padding token id of 50257 at run-time
+        tok.add_special_tokens({'pad_token': '<|padding|>'})
+        return tok
+
+    return Tokenizer.from_file(path)
+
+
+# GPT2Tokenizer and Tokenizer have different ways of fetching token ids
+def encode(encoder, text):
+    result = encoder.encode(text)
+    if isinstance(result, list):
+        return result
+    return result.ids
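Not part of the commit: a minimal usage sketch of fetch_encoder/encode, assuming a params dict shaped like the output of configs.fetch_model_params with a pretrained-GPT2 dataset config.

    from data.encoders import fetch_encoder, encode

    params = {
        "no_dataset": False,
        "dataset_configs": {
            "openwebtext2_new_inputs": {
                "tokenizer_is_pretrained": True,
                "tokenizer_path": "gpt2",
            }
        },
    }
    enc = fetch_encoder(params)          # GPT2TokenizerFast with an added <|padding|> token
    ids = encode(enc, "Hello GPT-Neo")   # plain list of token ids regardless of tokenizer type
    print(len(ids))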
data/train_tokenizer.py
ADDED
@@ -0,0 +1,73 @@
+import os
+import random
+import argparse
+import shutil
+from glob import glob
+from pathlib import Path
+
+from lm_dataformat import Reader
+from tokenizers import (Tokenizer, decoders, models, pre_tokenizers,
+                        processors, trainers)
+from tokenizers.normalizers import NFKC
+from tqdm import tqdm
+
+# parser
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--base_dir", type=str, help="Path to where your files are located. Files ending in .zst are treated as \
+                    archives, all others as raw text.")
+parser.add_argument("--output_dir", type=str, default="tokenizers", help="Where to put the tokenizer")
+parser.add_argument("--file_type", type=str, choices=["xz", "txt"], default="xz", help="Extension of file to parse")
+parser.add_argument("--vocab_size", type=int, help="Size of vocabulary", required = True)
+args = parser.parse_args()
+
+# main script
+
+data_path = Path(args.base_dir)
+archives = glob(str(data_path / f"*.{args.file_type}"))
+
+out_path = Path(args.output_dir)
+
+if os.path.exists(out_path):
+    shutil.rmtree(out_path)
+
+if not out_path.is_dir():
+    out_path.mkdir()
+
+for arch in tqdm(archives):
+    name = os.path.basename(arch).split(".")[0] + ".txt"
+    fp = out_path / name
+
+    if args.file_type == 'xz':
+        g = Reader(arch).stream_data()
+
+        with open(fp, "w") as f:
+            for s in g:
+                f.write(s)
+                f.write("\n\n")
+    elif args.file_type == 'txt':
+        shutil.copyfile(str(arch), str(fp))
+
+data_files = glob(str(out_path / "*.txt"))
+data_files = random.sample(data_files, int(0.2 * len(data_files)))
+
+assert len(data_files) > 0, 'No data files found'
+
+# Initialize a tokenizer
+tokenizer = Tokenizer(models.BPE())
+
+# Customize pre-tokenization and decoding
+tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
+tokenizer.decoder = decoders.ByteLevel()
+tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
+tokenizer.normalizer = NFKC()
+
+# And then train
+trainer = trainers.BpeTrainer(vocab_size=args.vocab_size, min_frequency=2, special_tokens=["<|endoftext|>", "<|padding|>"])
+tokenizer.train(trainer, data_files)
+
+# And Save it
+tokenizer_path = out_path / "byte-level-bpe.tokenizer.json"
+tokenizer.save(str(tokenizer_path), pretty=True)
+
+print(f'tokenizer saved at {str(tokenizer_path)}')
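Not part of the commit: a hypothetical follow-up that loads the tokenizer the script saved and round-trips a string. The path assumes the default --output_dir of "tokenizers".

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("tokenizers/byte-level-bpe.tokenizer.json")
    ids = tok.encode("Hello world").ids   # Encoding object -> list of token ids
    print(tok.decode(ids))                # should give back (roughly) "Hello world"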
docker-compose.yml
ADDED
@@ -0,0 +1,67 @@
+version: '3'
+services:
+
+  mongo:
+    image: mongo
+    ports:
+      - 127.0.0.1:27017:27017
+    environment:
+      MONGO_INITDB_ROOT_USERNAME: user
+      MONGO_INITDB_ROOT_PASSWORD: password
+      MONGO_INITDB_DATABASE: db
+    expose:
+      - 27017
+    networks:
+      - omniboard
+    volumes:
+      - ./data:/data/db
+
+  mongoClientTemp:
+    image: mongo:latest
+    container_name: mongoClientTemp
+    links:
+      - mongo:mongo
+    command: mongo --host mongo -u user -p password --eval "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});"
+    depends_on:
+      - mongo
+    networks:
+      - omniboard
+
+  omniboard_readonly:
+    #image: vivekratnavel/omniboard:latest
+    build: https://github.com/lucidrains/omniboard.git
+    command: ["--mu", "mongodb://readonly:password@mongo:27017/db"]
+    ports:
+      - 0.0.0.0:8081:9000
+    networks:
+      - omniboard
+    depends_on:
+      - mongo
+
+  omniboard:
+    #image: vivekratnavel/omniboard:latest
+    build: https://github.com/lucidrains/omniboard.git
+    command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"]
+    expose:
+      - 9000
+    networks:
+      - omniboard
+    depends_on:
+      - mongo
+
+  nginx:
+    image: dhswt/nginx-basic-auth:1.3
+    environment:
+      - HTPASSWD=isaac: #put passwd here
+      - FORWARD_HOST=omniboard
+      - FORWARD_PORT=9000
+    networks:
+      - omniboard
+    depends_on:
+      - omniboard
+    ports:
+      - 0.0.0.0:8080:80
+    expose:
+      - 8080
+networks:
+  omniboard:
encoders.py
ADDED
@@ -0,0 +1,28 @@
+from tokenizers import Tokenizer
+from transformers import GPT2Tokenizer, GPT2TokenizerFast
+
+def fetch_encoder(params):
+    no_dataset = params.get('no_dataset', False)
+    if no_dataset:
+        return None
+
+    dataset = next(iter(params['dataset_configs'].values()))  # Get the first value from the dict
+    path = dataset["tokenizer_path"]
+    is_pretrained = dataset.get("tokenizer_is_pretrained", False)
+
+    if is_pretrained:
+        tok = GPT2TokenizerFast.from_pretrained(path)
+
+        # Will add a padding token id of 50257 at run-time
+        tok.add_special_tokens({'pad_token': '<|padding|>'})
+        return tok
+
+    return Tokenizer.from_file(path)
+
+
+# GPT2Tokenizer and Tokenizer have different ways of fetching token ids
+def encode(encoder, text, gpt=True):
+    result = encoder.encode(text)
+    if isinstance(result, list):
+        return result
+    return result.ids
export.py
ADDED
@@ -0,0 +1,14 @@
+import tensorflow.compat.v1 as tf
+
+def export_model(estimator, export_dir, params,
+                 checkpoint_path=None):
+
+
+    def serving_input_receiver_fn():
+        t = tf.placeholder(dtype=tf.int64,
+                           shape=[1, params["n_ctx"]],
+                           name='input_example_tensor')
+        return tf.estimator.export.ServingInputReceiver(t, t)
+
+    return estimator.export_saved_model(
+        export_dir, serving_input_receiver_fn, checkpoint_path=checkpoint_path)
gradio/demo.py
ADDED
@@ -0,0 +1,12 @@
+import gradio as gr
+
+title = "GPT-Neo Demo"
+description = "demo for GPT-Neo by EleutherAI for text generation. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='http://github.com/eleutherai/gpt-neo'>GPT-Neo: Large Scale Autoregressive Language Modeling with Mesh-Tensorflow</a></p>"
+examples = [
+    ['The tower is 324 metres (1,063 ft) tall,'],
+    ["The Moon's orbit around Earth has"],
+    ["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
+]
+
+gr.Interface.load("huggingface/EleutherAI/gpt-neo-2.7B", inputs=gr.inputs.Textbox(lines=5, label="Input Text"), title=title, description=description, article=article, examples=examples).launch()
inputs.py
ADDED
@@ -0,0 +1,384 @@
+import numpy as np
+import tensorflow.compat.v1 as tf
+from functools import partial
+from data.encoders import encode
+import random
+import re
+import logging
+from itertools import cycle
+from utils import natural_sort
+
+
+### IN USE ###
+
+def _get_number_of_documents(filename):
+    # extracts number of files from a filename formatted "<name>_<num_documents>.tfrecords."
+    # if no pattern is matched, returns None
+    match = re.search("_(\d{1,}).tfrecords$", filename)
+    return int(match.group(1)) if match is not None else match
+
+
+def _get_number_of_documents_by_iteration(filename):
+    # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename
+    # this could be very slow.
+    logging.warning(
+        "inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length")
+    count = 0
+    for item in tf.io.tf_record_iterator(filename):
+        count += 1
+    return count
+
+
+def _get_skip_index(all_files, n_batches):
+    prev_cumsum = 0
+    cumsum = 0
+    global_n_documents = None
+    for count, f in cycle(enumerate(all_files)):
+        prev_cumsum = cumsum
+        if _get_number_of_documents(f) is not None:
+            cumsum += _get_number_of_documents(f)
+        elif global_n_documents is None:
+            global_n_documents = _get_number_of_documents_by_iteration(f)
+            cumsum += global_n_documents
+        else:
+            cumsum += global_n_documents
+        if cumsum == n_batches:
+            remainder = 0
+            skip_idx = count + 1
+        elif cumsum > n_batches:
+            remainder = n_batches - prev_cumsum
+            skip_idx = count
+            break
+    return skip_idx, remainder
+
+
+def _parse_function(example_proto):
+    features = {
+        "text": tf.VarLenFeature(tf.int64)
+    }
+    parsed_features = tf.parse_single_example(example_proto, features)
+    return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])
+
+
+def autoregressive_sample_text(params, x):
+    vals1 = x[:params["n_ctx"]]
+    vals2 = x[1:params["n_ctx"] + 1]
+
+    vals1 = tf.reshape(vals1, [params["n_ctx"]])
+    vals2 = tf.reshape(vals2, [params["n_ctx"]])
+    vals1 = tf.cast(vals1, dtype=tf.int32)
+    vals2 = tf.cast(vals2, dtype=tf.int32)
+    return vals1, vals2
+
+
+def sequential_input(params, global_step=None, eval=False):
+    """
+    Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:
+
+    - has the number of documents for each tfrecord file encoded in the title in the format
+      <name>_<n_documents>.tfrecords.
+
+    OR
+
+    - has a fixed number of documents per tfrecord file.
+
+    If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.
+    If this isn't the case, it may result in errors, or some samples being missed.
+
+    This means we can calculate the number of samples we've seen so far using the global step,
+    and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.
+
+    If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
+    performance, as it results in less repeated data.
+    """
+    if not eval:
+        assert global_step is not None
+    logging.warning(
+        "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
+    batch_size = params['eval_batch_size' if eval else 'train_batch_size']
+
+    filenames = []
+    for dataset_config in params['dataset_configs'].values():  # iterate through each dataset and read params
+        path_key = 'path' if not eval else 'eval_path'
+        path = dataset_config[path_key]
+        filenames.extend(
+            tf.io.gfile.glob(path))  # then glob all files that fit the pattern specified in dataset_configs
+
+    filenames = natural_sort(filenames)
+    shuffle_filenames = params.get("shuffle_input_filenames", True)
+    if shuffle_filenames:
+        seed = params.get('seed', 1)  # shuffle deterministically
+        random.seed(seed)
+        random.shuffle(filenames)
+
+    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()  # repeat filenames to infinity
+
+    if not eval:
+        # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
+        skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
+            "train_batch_size"])  # TODO: fix for > 1 epoch
+        dataset = dataset.skip(skip_idx)  # skip to skip idx
+
+        # read tfrecord examples and skip remainder
+        dataset = dataset.apply(tf.data.TFRecordDataset)
+        dataset = dataset.skip(remainder)
+    else:
+        # shuffle filenames if in eval mode
+        dataset = dataset.shuffle(len(filenames))
+        dataset = dataset.apply(tf.data.TFRecordDataset)
+
+    # parse the tokenized data from the tfrecord files and shuffle
+    dataset = dataset.map(_parse_function, num_parallel_calls=1)
+    dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)
+
+    # batch data and repeat to infinity
+    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
+    return dataset.repeat()
+
+
+def pred_input(params, logger, enc=None,
+               path_to_prompt=""):
+    unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
+               "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
+               "researchers was the fact that the unicorns spoke perfect English."
+
+    text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read()
+    tokens = encode(enc, text)
+
+    if len(tokens) > params["n_ctx"]:
+        logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
+        tokens = tokens[len(tokens) - params["n_ctx"]:]
+    if len(tokens) < params["n_ctx"]:
+        tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
+    t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
+    dataset = tf.data.Dataset.from_tensors(t)
+
+    def _dummy_labels(x):
+        return x, x
+
+    dataset = dataset.map(_dummy_labels)
+    return dataset
+
+
+def handle_pred_output(predictions, logger, enc, params, out_name="test"):
+    with tf.gfile.Open(f"{out_name}.txt", "w") as f:
+        for i, p in enumerate(predictions):
+            p = p["outputs"]
+
+            # remove eos + padding ids from output
+            idx = np.argmax(p == params['eos_id'])
+            if idx > 0:
+                p = p[:idx]
+            idx = np.argmax(p == params['padding_id'])
+            if idx > 0:
+                p = p[:idx]
+
+            text = enc.decode(p)
+            f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
+            f.write(text)
+            f.write("\n" + "=" * 80 + "\n")
+
+            logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
+            logger.info(text)
+            logger.info("\n" + "=" * 80 + "\n")
+
+
+### DEPRECATED ###
+
+def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
+    logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
+    i = 0 if not eval else 1
+
+    weights = []
+    datasets = []
+
+    for dataset in params["datasets"]:
+        dataset_id, stitch, datatype, weight = dataset
+
+        assert dataset_id in params[
+            'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
+        dataset_config = params['dataset_configs'][dataset_id]
+
+        path_key = 'path' if not eval else 'eval_path'
+        path = dataset_config[path_key]
+
+        datasets.append(text_dataset(
+            tf.io.gfile.glob(path),
+            params,
+            stitch=stitch,
+            datatype=datatype,
+            batch=False,
+            sample_text_fn=sample_text_fn
+        ))
+
+        weights.append(weight)
+
+    batch_size = params['eval_batch_size' if eval else 'train_batch_size']
+
+    seed = params.get('seed', None)
+    dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
+    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
+    return dataset
+
+
+def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
+    seed = params.get('seed', None)
+    deterministic = seed is not None
+    num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE
+
+    dataset = tf.data.Dataset.from_tensor_slices(files)
+
+    if deterministic:
+        dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
+    else:
+        dataset = dataset.apply(
+            tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))
+
+    if "documents" in datatype:
+        def _parse_function(example_proto):
+            features = {
+                # "hash": tf.VarLenFeature(tf.string),
+                "text": tf.VarLenFeature(tf.int64)
+            }
+            parsed_features = tf.parse_single_example(example_proto, features)
+            return parsed_features["text"], parsed_features["text"].dense_shape[0]
+    else:
+        def _parse_function(example_proto):
+            features = {
+                "text": tf.VarLenFeature(tf.int64)
+            }
+            parsed_features = tf.parse_single_example(example_proto, features)
+            return parsed_features["text"]  # Assuming the text is not sparse
+
+    dataset = dataset.map(_parse_function, num_parallel_calls=1)
+
+    # Subsample method
+    if "documents" in datatype:
+        # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples
+        # to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that
+        # stitch * min(characters_in_text) >= amount
+        def _stitch_text(x, y):
+            x = tf.sparse.to_dense(x)
+
+            def _get_x(i):
+                return tf.gather(x[i], tf.range(y[i]))
+
+            out = _get_x(0)
+            eos_id = params['eos_id']
+
+            for i in range(1, stitch):
+                out = tf.concat([out, [eos_id], _get_x(i)], axis=0)  # text1<|endoftext|>text2
+
+            return out
+
+        # Hack-y way to stitch together multiple texts
+
+        dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,
+                                                                                                   num_parallel_calls=num_parallel_calls)
+
+        # Sample 1024(+1) tokens from the stitched together text
+        is_random_documents = datatype == "documents_random"
+        if sample_text_fn is not None:
+            _sample_text = partial(sample_text_fn, random_documents=is_random_documents)
+        else:
+            _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
+            _sample_text = partial(_sample_text, params)
+
+        dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)
+
+    if batch:
+        dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)
+
+    dataset = dataset.repeat()
+
+    return dataset
+
+
+def autoregressive_sample_text_random_documents(params, x):
+    seed = params.get('seed', None)
+    s = tf.size(x)
+    r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
+    r1 = tf.range(r, r + params["n_ctx"])
+    r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
+    r1 = tf.reshape(r1, [params["n_ctx"]])  # Somehow, this makes the compiler happy
+    r2 = tf.reshape(r2, [params[
+
"n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input
|
306 |
+
vals1 = tf.gather(x, r1)
|
307 |
+
vals2 = tf.gather(x, r2)
|
308 |
+
|
309 |
+
vals1 = tf.reshape(vals1, [params["n_ctx"]])
|
310 |
+
vals2 = tf.reshape(vals2, [params["n_ctx"]])
|
311 |
+
vals1 = tf.cast(vals1, dtype=tf.int32)
|
312 |
+
vals2 = tf.cast(vals2, dtype=tf.int32)
|
313 |
+
return vals1, vals2
|
314 |
+
|
315 |
+
|
316 |
+
def mlm_sample_text(params, x, random_documents=False):
|
317 |
+
seed = params.get('seed', None)
|
318 |
+
ctx_len = params["n_ctx"]
|
319 |
+
assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'
|
320 |
+
|
321 |
+
mask_id = params['mlm_mask_id']
|
322 |
+
cls_token_id = params.get('mlm_cls_token_id', None)
|
323 |
+
num_tokens = params.get('n_vocab', None)
|
324 |
+
|
325 |
+
mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
|
326 |
+
mask_ignore_ids.add(cls_token_id)
|
327 |
+
|
328 |
+
mask_prob = params.get('mlm_mask_prob', 0.15)
|
329 |
+
same_token_prob = params.get('mlm_same_token_prob', 0.10)
|
330 |
+
random_token_prob = params.get('mlm_random_token_prob', 0.)
|
331 |
+
|
332 |
+
seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)
|
333 |
+
|
334 |
+
if random_documents:
|
335 |
+
s = tf.size(x)
|
336 |
+
r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
|
337 |
+
r1 = tf.range(r, r + seq_len)
|
338 |
+
r1 = tf.reshape(r1, [seq_len])
|
339 |
+
features = tf.gather(x, r1)
|
340 |
+
else:
|
341 |
+
features = x[:seq_len]
|
342 |
+
|
343 |
+
# add cls token id if specified by `mlm_cls_token_id`
|
344 |
+
if cls_token_id is not None:
|
345 |
+
features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)
|
346 |
+
|
347 |
+
features = tf.cast(features, dtype=tf.int32)
|
348 |
+
shape = features.shape
|
349 |
+
|
350 |
+
# determine which tokens are mask-able
|
351 |
+
can_mask = tf.not_equal(features, 0)
|
352 |
+
for ignore_id in mask_ignore_ids:
|
353 |
+
can_mask &= tf.not_equal(features, ignore_id)
|
354 |
+
|
355 |
+
# generate boolean mask for masking ids
|
356 |
+
mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
|
357 |
+
mask_mask &= can_mask
|
358 |
+
|
359 |
+
# generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
|
360 |
+
replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
|
361 |
+
1 - same_token_prob)
|
362 |
+
|
363 |
+
# randomly replace some tokens with random tokens before masking
|
364 |
+
if random_token_prob > 0:
|
365 |
+
random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
|
366 |
+
random_token_prob)
|
367 |
+
random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)
|
368 |
+
|
369 |
+
# make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
|
370 |
+
random_can_mask = tf.not_equal(random_tokens, 0)
|
371 |
+
for ignore_id in mask_ignore_ids:
|
372 |
+
random_can_mask &= tf.not_equal(random_tokens, ignore_id)
|
373 |
+
|
374 |
+
features = tf.where(random_token_mask & random_can_mask, random_tokens, features)
|
375 |
+
|
376 |
+
# mask the tokens
|
377 |
+
mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
|
378 |
+
masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)
|
379 |
+
|
380 |
+
# labels will be set to 0 for all non-masked tokens
|
381 |
+
labels = tf.where(mask_mask, tf.zeros(shape, dtype=tf.int32), features)
|
382 |
+
|
383 |
+
masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
|
384 |
+
return masked_features, labels
|
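As a rough, framework-free illustration of the masking recipe implemented by mlm_sample_text above, the following NumPy sketch applies the same mask_prob / same_token_prob logic to a toy sequence (all ids and probabilities here are made-up illustrative values, not taken from any config in this repo):

import numpy as np

rng = np.random.default_rng(0)
features = np.array([5, 17, 902, 3, 44, 120, 0, 0], dtype=np.int32)  # toy token ids, 0 = padding
mask_id, mask_prob, same_token_prob = 50256, 0.15, 0.10              # hypothetical values

can_mask = features != 0                                             # never mask padding
mask_mask = (rng.random(features.shape) < mask_prob) & can_mask      # positions chosen for masking
replace_mask = rng.random(features.shape) < (1 - same_token_prob)    # of those, which get the mask token

masked_features = np.where(mask_mask & replace_mask, mask_id, features)
labels = np.where(mask_mask, 0, features)                            # mirrors the tf.where call above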
main.py
ADDED
@@ -0,0 +1,257 @@
"""GPT-like model in Mesh-Tensorflow"""

from functools import partial
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_config, tpu_estimator
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from utils import save_config, expand_attention_types_params, yes_or_no, remove_gs_or_filepath, setup_logging, \
    check_dataset
from inputs import sequential_input, pred_input, handle_pred_output, mlm_sample_text, generic_text
from export import export_model
from model_fns import model_fn
from data.encoders import fetch_encoder
from configs import fetch_model_params
from tasks import task_descriptors
import argparse
import json
import numpy


def parse_args():
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
    parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
                        help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'")
    parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
    parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
    parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
                                                                   "efficient layout according to MTF auto layout.")
    parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
                        help="If set, generates and prints the most memory efficient layout and mesh shape according to"
                             " MTF auto layout.")
    parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
                                                           "starts a new training run")
    parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
    parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
    parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
                                                   "defaults to unicorns.",
                        default="")
    parser.add_argument("--check_dataset", action="store_true",
                        help="If set, outputs sample from the dataset and quits.")
    parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
    parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
    parser.add_argument("--export", action="store_true", help="If set, will export the model.")
    args = parser.parse_args()
    assert args.model is not None, "Model must be set"
    return args


def main(args):
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
    input_fn = params.get("input_fn", "sequential_input")
    if input_fn == "sequential_input":
        input_fn = sequential_input
    elif input_fn == "generic_text":
        input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    # get current step
    current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
    logger.info(f"Current step {current_step}")

    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)
        if args.check_dataset:
            check_dataset(input_fn, params)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn, params, global_step=current_step)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(f"Are you sure you want to remove '{params['model_path']}' to start afresh?"):
            remove_gs_or_filepath(params["model_path"])
        else:
            exit()

    # Save config to logdir for experiment management
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(params["attention_types"])
    assert len(params["attention_types"]) == params["n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size", 1)  # Default to 1
    params["predict"] = args.predict
    params['model'] = params.get("model", "GPT")  # Default model selection to GPT since it's the only option for now
    params["export"] = args.export
    # Set sampling parameters
    params["sampling_use_entmax"] = args.entmax_sampling

    # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
    # moe layers are present
    params["slow_sampling"] = True if params["moe_layers"] is not None else False

    logger.info(f"params = {params}")

    # Get eval tasks from params
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params["predict_steps"] > 0 or params["eval_steps"] > 0 or len(
        eval_tasks) > 0

    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    def _make_task_estimator(task):
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["eval_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    if args.export:
        export_model(estimator, "export", params)
        return

    if args.predict:
        # Predict
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
        return

    def save_eval_results(task, eval_results):
        def as_python(x):
            if isinstance(x, numpy.generic):
                return x.item()
            return x
        eval_results = {k: as_python(v) for k, v in eval_results.items()}
        with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:
            json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)
            fh.write('\n')

    def run_eval():
        logger.info("Running evaluation...")
        eval_results = estimator.evaluate(
            input_fn=partial(input_fn, eval=True),
            steps=params["eval_steps"])
        logger.info(f"Eval results: {eval_results}")
        save_eval_results('validation', eval_results)

    def run_eval_tasks():
        for task in eval_tasks:
            logger.info(f"Starting evaluation task '{task}'")
            task_info = task_descriptors[task]["get_task_info_fn"](params)
            task_estimator = eval_task_estimators[task]
            task_input_fn = task_descriptors[task]["input_fn"]
            eval_results = task_estimator.evaluate(
                input_fn=task_input_fn,
                steps=task_info["n_steps"],
                name=task)
            logger.info(f"Eval task '{task}' results: {eval_results}")
            save_eval_results(task, eval_results)

    if args.eval:
        run_eval_tasks()
        if params["eval_steps"] > 0:
            run_eval()
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")

            if params["eval_steps"] > 0:
                run_eval()

            if eval_tasks:
                run_eval_tasks()

        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"])


if __name__ == "__main__":
    tf.disable_v2_behavior()
    args = parse_args()
    main(args)
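main() above pulls most of its settings out of the JSON file passed with --model (via fetch_model_params). Purely as a sketch of the kind of dictionary it expects, built only from keys the code above accesses and with placeholder values rather than recommended settings:

params = {
    "model_path": "gs://your-bucket/your-run",   # placeholder, not a real path
    "mesh_shape": "x:4,y:2",
    "layout": "batch:x",
    "n_ctx": 1024,
    "n_layer": 12,
    "attention_types": [[["global"], 12]],
    "train_batch_size": 8,
    "eval_batch_size": 8,
    "train_steps": 10000,
    "eval_steps": 0,
    "predict_steps": 0,
    "iterations": 500,
    "mlm_training": False,
    "moe_layers": None,
}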
model_fns.py
ADDED
@@ -0,0 +1,305 @@
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_estimator
import mesh_tensorflow.transformer as mtf_transformer
from optimizers import get_optimizer
from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params,
                   get_batch_size, auto_layout, auto_layout_and_mesh_shape)
from models.utils import biasmask_attn_weights
from tensorflow.python.ops import resources
from sample import sample_autoregressive
from models.gpt2 import gpt2
import math


def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                               variable_dtype=variable_dtype, context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {
            "inputs": inputs,
            "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat(
                    [tf.report_uninitialized_variables(),
                     resources.report_uninitialized_resources()],
                    axis=0,
                    name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=batch_dim,
            sequence_length=sequence_length_dict,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype, context=None)
        else:
            raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
                                                     inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
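The _bits_per_byte metric defined above converts the per-token cross-entropy (in nats) into bits per byte via the fixed 0.29335 factor. A quick arithmetic check of that conversion, using a made-up loss value:

import math

loss_nats_per_token = 2.5                                   # hypothetical eval loss
bits_per_byte = loss_nats_per_token * (0.29335 / math.log(2))
print(round(bits_per_byte, 3))                              # ~1.058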
models/activations.py
ADDED
@@ -0,0 +1,95 @@
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import random

BASE_FNS = {'gelu': mtf.gelu,
            'relu': mtf.relu,
            'sigmoid': mtf.sigmoid,
            'tanh': mtf.tanh,
            'selu': mtf.selu,
            'elu': mtf.elu,
            'abs': mtf.abs,
            'sin': mtf.sin,
            'cos': mtf.cos,
            'sign': mtf.sign,
            'silu': mtf.swish,
            'softplus': mtf.softplus
            }


def _arcsinh(x):
    return mtf.log(x + mtf.sqrt(1 + x ** 2))


def _var(x, init):
    return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                            initializer=tf.constant_initializer(init), dtype=x.dtype)


def _pos_var(x, val):
    return mtf.softplus(_var(x, 0)) + val


def _rrelu(x):
    negative_scale = random.random()
    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)


def _elish(x):
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(x)
    return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)


CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
              'id': lambda x: x,
              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
              'spike': lambda x: 1 / (1 + x ** 2),
              'spike2': lambda x: mtf.exp(-x ** 2),
              'tanhshrink': lambda x: x - mtf.tanh(x),
              'softsign': lambda x: x / (mtf.abs(x) + 1),
              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
              'rrelu': _rrelu,
              'elish': _elish,
              'arcsinh': _arcsinh,
              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
                      _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
              'cosid': lambda x: mtf.cos(x) - x,
              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
              'lisht': lambda x: x * mtf.tanh(x),
              'seagull': lambda x: mtf.log(1 + x ** 2),
              'snake': lambda x: x + mtf.sin(x) ** 2,
              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
              'softplusmone': lambda x: mtf.softplus(x) - 1
              }


def get_activation_fn(params):
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
        activation_fn = "gelu"

    if activation_fn in BASE_FNS:
        return BASE_FNS[activation_fn]

    if activation_fn in CUSTOM_FNS:
        return CUSTOM_FNS[activation_fn]

    raise ValueError('unknown activation function "activation_fn" in config')
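get_activation_fn above is keyed off an optional "activation_fn" entry in the model config. A minimal usage sketch (the params dict here is hypothetical):

from models.activations import get_activation_fn

params = {"activation_fn": "mish"}   # any key from BASE_FNS or CUSTOM_FNS
act = get_activation_fn(params)      # returns a callable that operates on mtf tensors
# act(h) can then be applied to a hidden-state tensor, e.g. inside the mlp layers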
models/gpt2/gpt2.py
ADDED
@@ -0,0 +1,217 @@
"""GPT-like model in Mesh-Tensorflow"""
import tensorflow.compat.v1 as tf
import mesh_tensorflow.transformer as mtf_transformer

from models.utils import parse_inputs, entmax_cross_entropy_with_logits
from models.layers import *


# --------------------------------------------------------------------------------
# TRANSFORMER BLOCK:

def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, pos_emb, variable_dtype, context=None):
    use_mlp_glu = params["mlp_glu"] == True
    use_scale_norm = params["scalenorm"] == True
    use_moe = exists(params["moe_layers"]) and (layer_num in params["moe_layers"])
    use_rezero = params["rezero"] == True
    macaron_attention = params["macaron"] == True

    def fn(x):
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

            if macaron_attention:
                mult = 0.5
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
                m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)

                x = x + (m * mult)
            else:
                mult = 1

            if attention_type != "none":
                res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
                a = attn(res_x, "attn", nx, attention_type=attention_type,
                         params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
                         variable_dtype=variable_dtype, context=context, pos_emb=pos_emb)
            else:
                a = x

            x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

            res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)

            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)
                moe_params.add_hparam("moe_min_expert_capacity", 1)
                moe_params.add_hparam("moe_use_experts_attention", False)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params,
                                                                           train=moe_train,
                                                                           mesh_shape=params["mesh_shape"],
                                                                           layout=params["layout"],
                                                                           activation=params.get("moe_activation", "relu"),
                                                                           variable_dtype=variable_dtype,
                                                                           num_microbatches=params["num_microbatches"])
                m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout")
            else:

                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)

                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)

                m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
                aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
            return x, aux_loss

    return fn


# --------------------------------------------------------------------------------
# GPT2 MODEL:

def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow."""

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = exists(params["axial_pos_emb"])
    use_rotary_emb = exists(params["rotary_emb"])

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    # Position encoding

    if use_rotary_emb:
        wpe = None
        layer_pos_emb = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)
    elif use_axial_pos_emb:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
        layer_pos_emb = None
    else:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
        layer_pos_emb = None

    if exists(wpe):
        with tf.variable_scope("pos_embd"):
            # Positional embedding
            position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                    context.position - 1)
            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
            if params["embed_dropout"] > 0 and params["mode"] == "train":
                pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
            h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         pos_emb=layer_pos_emb,
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
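When no_weight_tie is left off, the final projection in gpt2.model reuses the token embedding matrix wte, so the logits are the hidden states contracted against the embedding table. A small NumPy analogue of that einsum, with illustrative shapes only:

import numpy as np

batch, seq, n_embd, n_vocab = 2, 8, 16, 100
h = np.random.randn(batch, seq, n_embd)       # final hidden states
wte = np.random.randn(n_vocab, n_embd)        # token embedding table

logits = np.einsum("bse,ve->bsv", h, wte)     # same contraction as the mtf.einsum([h, wte], ...) above
assert logits.shape == (batch, seq, n_vocab)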
models/layers.py
ADDED
@@ -0,0 +1,357 @@
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import math
import mesh_tensorflow.transformer as mtf_transformer

from models.activations import get_activation_fn


# --------------------------------------------------------------------------------
# LAYERS:

sentinel = object()


def exists(x):
    return x is not None


def identity(x, *args, **kwargs):
    return x


def is_incremental_inference(context):
    return exists(context) and context.mode == "incremental"


def norm(x, axis, epsilon=1e-8):
    x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u")
    s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s")
    return x * mtf.rsqrt(s + epsilon)


def rezero(x, scope, dtype):
    with tf.variable_scope(scope):
        g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype)
        return x * g


def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    if axis is sentinel:
        axis = x.shape[-1]

    with tf.variable_scope(scope):
        g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

        x = norm(x, axis, epsilon)
        x = x * g
        return x


def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
    if axis is sentinel:
        axis = x.shape[-1]

    with tf.variable_scope(scope):
        n_state = x.shape[-1]

        g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)
        b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

        x = norm(x, axis, epsilon)
        x = x * g + b
        return x


def linear_attention(q, k, v):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.softmax(k, seq_dim)

    context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn


def causal_linear_attention(q, k, v, eps=1e-6):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.exp(k)

    cumulative_k = mtf.cumsum(k, seq_dim) + eps
    D_inv = 1. / mtf.einsum([q, cumulative_k], output_shape=[batch_dim, seq_dim, head_dim])

    context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    cumulative_context = mtf.cumsum(context, seq_dim)

    attn = mtf.einsum([q, cumulative_context, D_inv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn


def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):
    # nf = number of features
    if params["scale_by_depth"] and scale:
        # Scale by sqrt(num_layers), only happens at the final projection before a res block output
        w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"]))
    if params["scale_by_in"]:  # Scale by sqrt(num_input_features)
        w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size))  # Dimension is a namedtuple of (name, size)
    # Not in the variable_scope because mtf already has a variable_scope in it
    with tf.variable_scope("conv1d_main"):
        c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True,
                             kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev),
                             variable_dtype=variable_dtype,
                             )
        return c


def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """memory / key values from all attention paper"""

    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype,
                             )
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v


def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None, pos_emb=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
|
183 |
+
one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
|
184 |
+
inv_one_hot = 1.0 - one_hot
|
185 |
+
old_k, old_v = context.get_states(2)
|
186 |
+
k = old_k * inv_one_hot + k * one_hot
|
187 |
+
v = old_v * inv_one_hot + v * one_hot
|
188 |
+
|
189 |
+
if exists(context):
|
190 |
+
context.record_new_states([k, v])
|
191 |
+
|
192 |
+
if exists(pos_emb):
|
193 |
+
cos, sin = pos_emb
|
194 |
+
k = apply_rotary_emb(k, cos, sin)
|
195 |
+
|
196 |
+
if is_incremental_inference(context):
|
197 |
+
seq_dim = cos.shape.get_dim_by_name('sequence')
|
198 |
+
cos = mtf.gather(cos, context.position - 1, seq_dim)
|
199 |
+
sin = mtf.gather(sin, context.position - 1, seq_dim)
|
200 |
+
|
201 |
+
q = apply_rotary_emb(q, cos, sin)
|
202 |
+
|
203 |
+
with tf.variable_scope("attention"):
|
204 |
+
if attention_type == "local":
|
205 |
+
# `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
|
206 |
+
radius = params.get("local_attention_radius", 256)
|
207 |
+
|
208 |
+
if is_incremental_inference(context):
|
209 |
+
q *= one_hot
|
210 |
+
|
211 |
+
a = mtf_transformer.attention.local_attention_1d(
|
212 |
+
q, k, v,
|
213 |
+
length_dim=k.shape[1],
|
214 |
+
key_dim=dim_kv,
|
215 |
+
value_dim=dim_kv,
|
216 |
+
radius=radius,
|
217 |
+
length_dim_num_splits=1,
|
218 |
+
fully_autoregressive=params["causal"],
|
219 |
+
attention_kwargs={},
|
220 |
+
)
|
221 |
+
|
222 |
+
if is_incremental_inference(context):
|
223 |
+
a = mtf.gather(a, context.position - 1, dim_seq)
|
224 |
+
|
225 |
+
elif attention_type == "global":
|
226 |
+
|
227 |
+
# TODO: pass in fake context
|
228 |
+
# Broadcast mask bias across batch and heads
|
229 |
+
if exists(bias):
|
230 |
+
if not is_incremental_inference(context):
|
231 |
+
broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
|
232 |
+
else:
|
233 |
+
# In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
|
234 |
+
bias = mtf.gather(bias, context.position - 1, dim_seq)
|
235 |
+
broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])
|
236 |
+
|
237 |
+
# memory key / values, from all-attention paper
|
238 |
+
if use_num_mem_kv:
|
239 |
+
k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)
|
240 |
+
|
241 |
+
k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
|
242 |
+
v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)
|
243 |
+
|
244 |
+
attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0
|
245 |
+
|
246 |
+
a = mtf_transformer.attention.attention(
|
247 |
+
q, k, v,
|
248 |
+
memory_length_dim=memory_length_dim,
|
249 |
+
key_dim=dim_kv,
|
250 |
+
value_dim=dim_kv,
|
251 |
+
bias=broadcasted_bias,
|
252 |
+
dropout_rate=attn_dropout_rate
|
253 |
+
)
|
254 |
+
|
255 |
+
elif attention_type == "linear":
|
256 |
+
linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
|
257 |
+
a = linear_attn_fn(q, k, v)
|
258 |
+
|
259 |
+
else:
|
260 |
+
raise NotImplementedError("Unknown attention type {}!".format(attention_type))
|
261 |
+
|
262 |
+
with tf.variable_scope("compute_output"):
|
263 |
+
a = mtfparams.compute_output(a, x_shape)
|
264 |
+
|
265 |
+
with tf.variable_scope("compute_output_bias"):
|
266 |
+
b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
|
267 |
+
master_dtype=variable_dtype.master_dtype,
|
268 |
+
slice_dtype=variable_dtype.slice_dtype,
|
269 |
+
activation_dtype=variable_dtype.activation_dtype)
|
270 |
+
a += b
|
271 |
+
|
272 |
+
if params["mode"] == "train" and params["res_dropout"] > 0:
|
273 |
+
a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
|
274 |
+
return a
|
275 |
+
|
276 |
+
|
277 |
+
def mlp(x, scope, n_state, *, variable_dtype, params):
|
278 |
+
activation_fn = get_activation_fn(params)
|
279 |
+
with tf.variable_scope(scope):
|
280 |
+
nx = x.shape[-1]
|
281 |
+
h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params))
|
282 |
+
h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
|
283 |
+
if params["mode"] == "train" and params["res_dropout"] > 0:
|
284 |
+
h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
|
285 |
+
return h2
|
286 |
+
|
287 |
+
|
288 |
+
def mlp_glu(x, scope, n_state, *, variable_dtype, params):
|
289 |
+
activation_fn = get_activation_fn(params)
|
290 |
+
with tf.variable_scope(scope):
|
291 |
+
nx = x.shape[-1]
|
292 |
+
h = linear(x, "c_fc", n_state, params=params)
|
293 |
+
|
294 |
+
h, gate = mtf.split(h, h.shape[-1], 2)
|
295 |
+
h *= activation_fn(gate)
|
296 |
+
|
297 |
+
h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
|
298 |
+
if params["mode"] == "train" and params["res_dropout"] > 0:
|
299 |
+
h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
|
300 |
+
return h2
|
301 |
+
|
302 |
+
|
303 |
+
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
|
304 |
+
# Use axial position encoding
|
305 |
+
axial_dim_1, axial_dim_2 = params["axial_pos_emb"]
|
306 |
+
|
307 |
+
axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
|
308 |
+
dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]
|
309 |
+
|
310 |
+
axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
|
311 |
+
initializer=tf.random_normal_initializer(stddev=0.01),
|
312 |
+
master_dtype=variable_dtype.master_dtype,
|
313 |
+
slice_dtype=variable_dtype.slice_dtype,
|
314 |
+
activation_dtype=variable_dtype.activation_dtype)
|
315 |
+
|
316 |
+
axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
|
317 |
+
initializer=tf.random_normal_initializer(stddev=0.01),
|
318 |
+
master_dtype=variable_dtype.master_dtype,
|
319 |
+
slice_dtype=variable_dtype.slice_dtype,
|
320 |
+
activation_dtype=variable_dtype.activation_dtype)
|
321 |
+
|
322 |
+
axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
|
323 |
+
(axial_wpe_1, axial_wpe_2))
|
324 |
+
wpe = (axial_wpe_1 + axial_wpe_2) / 2
|
325 |
+
|
326 |
+
wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
|
327 |
+
|
328 |
+
return wpe
|
329 |
+
|
330 |
+
def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):
|
331 |
+
dtype = variable_dtype.master_dtype
|
332 |
+
dim_head = params["n_embd"] // params["n_head"]
|
333 |
+
|
334 |
+
dim_head = mtf.Dimension("features_per_head", dim_head)
|
335 |
+
half_dim_head = mtf.Dimension("half_features_per_head", dim_head.size // 2)
|
336 |
+
|
337 |
+
dim_range = mtf.range(mesh, half_dim_head, dtype) * 2 / dim_head.size
|
338 |
+
half_freqs = 1. / mtf.pow(mtf.constant(mesh, 10000, dtype = dtype), dim_range)
|
339 |
+
|
340 |
+
seq = mtf.range(mesh, sequence_dim, dtype)
|
341 |
+
half_freqs = mtf.einsum([half_freqs, seq], [sequence_dim, half_dim_head])
|
342 |
+
|
343 |
+
freqs = mtf.concat((half_freqs, half_freqs), half_dim_head.name)
|
344 |
+
freqs = mtf.rename_dimension(freqs, half_dim_head.name, dim_head.name)
|
345 |
+
return mtf.cos(freqs), mtf.sin(freqs)
|
346 |
+
|
347 |
+
def rotate_half(x):
|
348 |
+
dim_head_name = "features_per_head"
|
349 |
+
dim_head = x.shape.get_dim_by_name(dim_head_name)
|
350 |
+
half_dim_head_size = dim_head.size // 2
|
351 |
+
x1 = mtf.slice(x, 0, half_dim_head_size, dim_head_name)
|
352 |
+
x2 = mtf.slice(x, half_dim_head_size, half_dim_head_size, dim_head_name)
|
353 |
+
return mtf.concat((-x2, x1), dim_head.name)
|
354 |
+
|
355 |
+
def apply_rotary_emb(x, cos, sin):
|
356 |
+
rotated_x = rotate_half(x)
|
357 |
+
return x * cos + rotated_x * sin
|
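For reference, the rotation applied by rotate_half / apply_rotary_emb above can be checked outside Mesh TensorFlow with a small NumPy sketch (the *_np names and toy sizes below are illustrative, not part of this repo):

import numpy as np

def rotate_half_np(x):
    # split the head dimension in half and rotate the pair: (x1, x2) -> (-x2, x1)
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def apply_rotary_emb_np(x, cos, sin):
    # same formula as apply_rotary_emb: x * cos + rotate_half(x) * sin
    return x * cos + rotate_half_np(x) * sin

# toy sizes: dim_head=4, seq_len=3; frequencies built as in rotary_positional_emb
dim_head, seq_len = 4, 3
inv_freq = 1.0 / (10000 ** (np.arange(0, dim_head, 2) / dim_head))
freqs = np.einsum("i,j->ij", np.arange(seq_len), inv_freq)
freqs = np.concatenate((freqs, freqs), axis=-1)
x = np.random.randn(seq_len, dim_head)
out = apply_rotary_emb_np(x, np.cos(freqs), np.sin(freqs))
assert out.shape == x.shape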
models/utils.py
ADDED
@@ -0,0 +1,124 @@
import tensorflow as tf
import mesh_tensorflow as mtf
from functools import partial


def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,
                    n_iter=50):
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr

    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr

    return dX,


def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)

    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m


def entmax(x, alpha=1.3, dim=None, n_iter=50):
    kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)

    return mtf.custom_gradient(
        partial(entmax_forward, **kwargs),
        partial(entmax_backward, **kwargs),
        [x]
    )


def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    if targets.dtype.is_integer:
        # hard targets
        if (set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim])):
            raise ValueError(
                "softmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "softmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss


def sample_categorical(x, dim=None):
    dim = x.shape[-1] if dim is None else dim

    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)


def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    # The old mask_attn_weights applied directly to the QK;
    # this returns a bias that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    # n_src and n_dest are both the same, i.e equal to sequence length
    # We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T
    # Information flows from k and v (memory_length) to q (sequence)
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(i, j), dtype) * -1e10


def parse_inputs(mtf_features, other_features):
    # Parse inputs and labels from the mtf_features / other_features input dicts
    # All dimensions are defined inside model_fn for efficiency
    x = mtf_features["inputs"]

    batch_dim = x.shape[0]
    sequence_dim = x.shape[1]
    embd_dim = other_features["embd_dim"]
    vocab_dim = other_features["vocab_dim"]
    embed_sequence_dim = other_features["embed_sequence_dim"]

    return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim
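The bisection in entmax_forward above can be mirrored in plain NumPy to see what it converges to; this is a minimal sketch under the same alpha/n_iter defaults, with an entmax_np name invented for illustration:

import numpy as np

def entmax_np(x, alpha=1.3, n_iter=50):
    # bisection on the threshold tau, following entmax_forward step by step
    d = x.shape[-1]
    x = x * (alpha - 1)
    max_val = x.max(axis=-1, keepdims=True)
    tau_lo = max_val - 1.0                       # _gp(1, alpha) == 1
    tau_hi = max_val - (1.0 / d) ** (alpha - 1)  # _gp(1 / d, alpha)
    p = lambda tau: np.maximum(x - tau, 0) ** (1 / (alpha - 1))
    f_lo = p(tau_lo).sum(axis=-1, keepdims=True) - 1
    dm = tau_hi - tau_lo
    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = p(tau_m)
        f_m = p_m.sum(axis=-1, keepdims=True) - 1
        tau_lo = np.where(f_m * f_lo >= 0, tau_m, tau_lo)
    return p_m / p_m.sum(axis=-1, keepdims=True)

probs = entmax_np(np.array([1.0, 2.0, 3.0, 4.0]))
assert abs(probs.sum() - 1.0) < 1e-6  # the output is a probability distribution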
optimizers.py
ADDED
@@ -0,0 +1,176 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def clip_by_global_norm(grads, clip_norm):
    """Clip the grads by global norm."""
    global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
    multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
    clipped_grads = [None if t is None else t * multiplier for t in grads]
    return clipped_grads, global_norm


def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] * 0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype
        )
    else:
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"]
        )

    if params["gradient_clipping"] is not None:
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp


class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""

    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 variable_dtype=None):
        """Constructs a AdamWeightDecayOptimizer."""

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.variable_dtype = variable_dtype

    def apply_grad(self, grad, var):
        """See base class."""
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)

        assignments = []

        m = mtf.get_variable(
            var.mesh, var.name + "/adam_m", var.shape,
            initializer=tf.zeros_initializer(),
            # master_dtype=self.variable_dtype.master_dtype,
            # slice_dtype=self.variable_dtype.slice_dtype,
            # activation_dtype=self.variable_dtype.activation_dtype,
            trainable=False)

        v = mtf.get_variable(
            var.mesh, var.name + "/adam_v", var.shape,
            initializer=tf.zeros_initializer(),
            # master_dtype=self.variable_dtype.master_dtype,
            # slice_dtype=self.variable_dtype.slice_dtype,
            # activation_dtype=self.variable_dtype.activation_dtype,
            trainable=False)

        # Standard Adam update.
        next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
        next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

        update = next_m / (mtf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += mtf.to_float(var.value) * self.weight_decay_rate

        update_with_lr = self.learning_rate * update

        var_update = mtf.assign_sub(var, update_with_lr)

        assignments.extend(
            [var_update,
             mtf.assign(m, next_m),
             mtf.assign(v, next_v)])
        return assignments

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True
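The schedule logic in get_optimizer (linear decay to 10% of the initial LR, with linear warmup applied on top of the decayed value) can be summarized in a few lines of plain Python; the numbers below are illustrative, not defaults from any shipped config:

def lr_at_step(step, lr=6e-5, warmup_steps=3000, total_steps=400000):
    # linear decay from lr to 0.1 * lr over total_steps (polynomial_decay with power=1.0)
    decayed = lr - (lr - lr * 0.1) * min(step, total_steps) / total_steps
    # linear warmup scales the decayed value, as in the warmup branch above
    if step < warmup_steps:
        return decayed * step / warmup_steps
    return decayed

print(lr_at_step(0), lr_at_step(3000), lr_at_step(400000))  # 0.0, ~6e-5, 6e-6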
requirements.txt
ADDED
@@ -0,0 +1,18 @@
google-api-python-client
jsonlines
lm_dataformat
mesh-tensorflow==0.1.18
numpy
oauth2client
ortools
pytest
sacred
tensorflow==2.5.0
tensorflow-datasets==3.2.1
tokenizers==0.9.4
transformers==4.1.1
tpunicorn
absl-py
ftfy
sacred
pymongo
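Assuming a fresh virtual environment, the pinned dependencies above can be installed with:

pip install -r requirements.txt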
run_experiment.py
ADDED
@@ -0,0 +1,265 @@
import atexit
import sacred
import argparse
import time
import math
import subprocess
import shutil
import os
import json
import threading
import requests
import glob
from configs import fetch_model_params
import socket
import queue
import sys
import signal


parser = argparse.ArgumentParser()
parser.add_argument('--tpu', type=str, required=True)  # Name of TPU to train on, if any
parser.add_argument('--model', type=str, required=True)  # JSON file that contains model parameters
parser.add_argument('--experiment_name', type=str, required=True)  # name of experiment (will show up in omniboard)
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
parser.add_argument('--autostack', action="store_false")
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
parser.add_argument('--heartbeat_timeout', type=int, default=1800)  # kill and restart if nothing logged to tensorboard in this many seconds
args = parser.parse_args()

params = fetch_model_params(args.model)

ex = sacred.Experiment(args.experiment_name)
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))


def get_open_port(lo=8000, hi=8100):
    for i in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', i)) != 0:
                return i


def train_thread(args, tpu, id, q):
    print('starting training on', tpu)

    # pass binary flags through
    opts = ''
    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
        if args.__getattribute__(flag):
            opts += ' --' + flag

    for flag in ['autostack', ]:
        if not args.__getattribute__(flag):
            opts += ' --' + flag

    cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
    print('Running:', cmd)
    proc = subprocess.Popen(cmd, shell=True)

    # poll until it's exited
    while proc.poll() is None:
        time.sleep(60)
        try:
            nq, *nargs = q.get_nowait()
            if nq == 'kill':
                print('train thread received kill signal from logging thread')
                # first send SIGTERM
                proc.terminate()

                time.sleep(60)

                # if it still hasn't exited, we send SIGKILL
                if proc.poll() is None:
                    print('SIGTERM not successful, sending SIGKILL')
                    proc.kill()

        except queue.Empty:
            pass

    print('exited training!')
    if proc.returncode == 0:
        print('exited gracefully')
        os.kill(os.getpid(), signal.SIGINT)
        return

    if args.no_delete_tpu:
        print('recreate done, exiting train_thread - not killing tpu!')
        return
    print("Recreating {} in 60sec...".format(tpu))
    time.sleep(60)
    os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
    print('recreate done, exiting train_thread')

    # clear out queue
    while True:
        try:
            q.get_nowait()
            print('dropped request in queue after pu recreate')
        except queue.Empty:
            break


def get_json(uri, params=None, timeout=15):
    resp = requests.get(uri, params=params, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def get_tag_sets(base_uri):
    j = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
    assert isinstance(j, dict)
    return {
        run: j[run].keys()
        for run in j.keys()
    }


def get_scalar_data(base_uri, run, tag):
    j = get_json(f'{base_uri}/data/plugin/scalars/scalars', {'experiment': '', 'run': run, 'tag': tag})
    assert isinstance(j, list)
    return j


def get_run_data(port):
    base_uri = f'http://localhost:{port}/'
    r = {}
    try:
        tag_sets = get_tag_sets(base_uri)
        runs = tag_sets.keys()
        if '.' in runs:
            if 'loss' in tag_sets['.']:
                r['loss'] = get_scalar_data(base_uri, '.', 'loss')
        if 'eval' in runs:
            if 'loss' in tag_sets['eval']:
                r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
        if 'eval_lambada' in runs:
            if 'lambada_acc' in tag_sets['eval_lambada']:
                r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
            if 'lambada_log_ppl' in tag_sets['eval_lambada']:
                r['lambada_ppl'] = [
                    [t, s, math.exp(lp)]
                    for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
                ]
    except:
        import traceback
        traceback.print_exc()
    return r


@ex.main
def main(_run):
    print('Starting run', _run._id)
    print('experiment main invoked with argv:', " ".join(sys.argv))
    print('WARNING: please remember to remove old metric log files from the model directory.')

    os.makedirs('run_configs', exist_ok=True)
    shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))

    tensorboard_port = get_open_port()
    print('Tensorboard at port:', tensorboard_port)
    print('Tensorboard url: ', 'http://eleutherai.bmk.sh:' + str(tensorboard_port))
    os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port, params["model_path"], tensorboard_port,))
    atexit.register(goodbye, _run._id)

    curr_step = {}
    seen_predictions = set()

    heartbeat_timeout = args.initial_heartbeat_timeout * 2
    while True:
        last_tb_log_time = time.time()
        start_time = time.time()
        q = queue.Queue()
        trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
        trainthd.start()

        while trainthd.is_alive():
            time.sleep(60)

            if start_time + args.initial_heartbeat_timeout < time.time():
                # after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower
                heartbeat_timeout = args.heartbeat_timeout

            print('Polling tensorboard for metrics...')
            data = get_run_data(tensorboard_port)
            for k in data.keys():
                for ts, step, val in data[k]:
                    if step <= curr_step.get(k, -1):
                        continue
                    _run.log_scalar(k, val, step)
                    if k == 'loss':
                        _run.log_scalar('tb_ts', ts, step)
                        print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))

                    # found something new, so logging!
                    last_tb_log_time = time.time()

                    curr_step[k] = step

            for f in glob.glob('predictions_{}_*'.format(_run._id)):
                if f in seen_predictions:
                    continue
                print('collecting prediction file', f)
                ex.add_artifact(f)

                seen_predictions.add(f)

            # collect eval metrics from jsonl
            if os.path.exists(f'eval_{_run._id}.jsonl'):
                with open(f'eval_{_run._id}.jsonl') as fh:
                    for line in fh:
                        ob = json.loads(line)
                        val_step = ob['global_step']
                        val_task = ob['task']
                        for metr in ob.keys():
                            k = 'fs.' + val_task + '.' + metr
                            if metr in ['task', 'global_step']: continue
                            if val_step <= curr_step.get(k, -1): continue
                            _run.log_scalar(k, ob[metr], val_step)
                            curr_step[k] = val_step

            if time.time() - last_tb_log_time > heartbeat_timeout:
                # the run hasn't logged in a while, so we restart it
                q.put(('kill',))

                # give training thread some time to do its thing and recreate tpu
                while trainthd.is_alive():
                    print('logging thread waiting for killing stalled run and for tpu recreate to finish')
                    time.sleep(60)

                # reset heartbeat timeout to initial
                heartbeat_timeout = args.initial_heartbeat_timeout
                last_tb_log_time = time.time()

        if args.no_delete_tpu:
            break


def goodbye(id):
    print("You are now leaving the Python sector.")
    print("Sie verlassen den pythonischen Sektor.")

    os.system("screen -S tensorboard_{} -X quit".format(id))


if __name__ == '__main__':
    for file in glob.glob("**/*", recursive=True):
        if file.split('.')[-1] in ['py']:
            print('Adding', file, 'to sacred')
            ex.add_source_file(file)

    ex.add_config({
        'tpu_name': args.tpu,
        **params
    })

    ex.run()
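run_experiment.py wraps main.py with sacred logging, a tensorboard scraper, and automatic TPU recreation on stalled runs. A typical invocation, with an illustrative TPU name and one of the configs shipped in configs/, might look like the line below; it also assumes a MongoDB instance is reachable at 127.0.0.1:27017 for the QueuedMongoObserver:

python3 run_experiment.py --tpu my-tpu-v3-8 --model gpt3_small_256 --experiment_name gpt3-small-test --steps_per_checkpoint 5000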
sample.py
ADDED
@@ -0,0 +1,218 @@
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import mesh_tensorflow.transformer as mtf_transformer

from models.utils import entmax, sample_categorical
from models.gpt2 import gpt2


def sample_autoregressive(partial_sequences,
                          other_features,
                          params,
                          stop_at_token=50256,
                          max_steps=None,
                          temperature=0.9,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          sampling_use_entmax=False,
                          bos_id=50256,
                          ):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing what
    needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      stop_at_token: an optional integer eos id. Stop when we produce it.
      max_steps: an optional integer, the max number of steps to decode.
      temperature: an optional floating point value between 0.0 and 1.0 0.0
        means argmax, 1.0 means sample according to predicted distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      encoder_inputs: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations when
        decoding, one per each input layer + the embedding layer
      never_end: a boolean - if set, then avoid generating stop_at_token
      remove_partial_sequences: a boolean - whether to remove the partial
        sequences from the output
      sampling_keep_top_k: an integer - if not -1, only sample from the top k
        logits.
      bos_id: beginning of sequence id

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """

    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim)  # Gets position where zero padding starts

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Vanilla autoregressive model - each position can see previous positions.
        # Think this feeds in to the loop fn and tells each position where it can attend to?
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally
    # The 'first part' context records initial states of k / v / x

    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part)

        if not has_partial_sequences:
            initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    if not has_partial_sequences:
        partial_sequences_eos_count = 0

    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end, mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context)

        if not sampling_use_entmax:
            # By default, do top_k sampling of 0.9
            if sampling_keep_top_k == -2:
                sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError("sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits, n=sampling_keep_top_k,
                    reduced_dim=other_features["vocab_dim"])
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, other_features["vocab_dim"], temperature)
        else:
            ids_this_step = sample_categorical(entmax(logits))

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
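The in-place token update in body_fn (write the sampled id at the current position and keep everything else) is just a one-hot blend; a toy NumPy version of that single step, with made-up ids:

import numpy as np

ids = np.array([12, 7, 0, 0])   # partially filled sequence, 0 is padding
position, sampled_id = 2, 99
one_hot = (np.arange(len(ids)) == position).astype(ids.dtype)
ids = (1 - one_hot) * ids + sampled_id * one_hot
assert list(ids) == [12, 7, 99, 0]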
tasks.py
ADDED
@@ -0,0 +1,116 @@
import os.path
import json
import requests
import numpy as np
import ftfy
from data.encoders import fetch_encoder, encode
import tensorflow as tf
import re
from functools import partial

lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
normalization = 'NFKC'


# Note: this task is called "lambada" but it really refers to OpenAI's version
# of the task, which actually differs in some ways from the task described in
# the original paper. So, strictly speaking, accuracy values from this task
# should not be compared to accuracy values from the original lambada task.
# For more information, see
# https://github.com/openai/gpt-2/issues/131

def lambada_create_tokens_data(params, path):
    with open(path, 'w') as f:
        req = requests.get(lambada_src_uri)
        req.raise_for_status()
        jsons = [json.loads(l) for l in req.iter_lines()]
        texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
        enc = fetch_encoder(params)
        arrays = [encode(enc, t) for t in texts]
        json.dump(arrays, f)
        return arrays


def lambada_read_or_create_tokens_data(params, path):
    # if you tell me where the file should go, i will helpfully create it for you
    if not os.path.exists(path):
        return lambada_create_tokens_data(params, path)
    with open(path) as f:
        return json.load(f)


def bin_pack(params, tokens_data):
    eos_token = params['eos_id']
    n_ctx = params['n_ctx']
    dummy_token = 1
    pad_batch_size = params['eval_batch_size']
    bins = []
    for a in tokens_data:
        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
            bins.append([])
        bins[-1] += a
        bins[-1].append(eos_token)
    while len(bins) % pad_batch_size != 0:
        bins.append([])
    bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
    for i, b in enumerate(bins):
        bins_array[i, 0:len(b)] = b
    return bins_array


def lambada_init(params):
    ds_configs = params['dataset_configs']
    l = [
        ds_configs[ds_id].get('lambada_tokens_path', "./lambada.json")
        for ds_id, _, _, _ in params['datasets']
    ]
    assert len(l) > 0, 'lambada_tokens_path not found in the dataset config'
    lt_path = l[0]
    assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'

    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    params['lambada_tokens_path'] = lt_path
    params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']


def lambada_get_task_info(params):
    return {
        'n_steps': params['lambada_n_steps'],
    }


# The LAMBADA evaluation code looks at the logits of each position just before an eos_token
def lambada_input(params):
    eos_token = 50256 if params['n_vocab'] >= 50257 else 0
    n_ctx = params['n_ctx']
    lt_path = params['lambada_tokens_path']
    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    dataset = tf.data.Dataset.from_tensor_slices(bins_array)

    def _get_output(bin):
        bin = tf.cast(bin, dtype=tf.int32)
        indexes = tf.range(n_ctx)
        results = tf.gather(bin, (indexes + 1) % n_ctx)
        eos_next_positions = tf.math.equal(tf.gather(bin, (indexes + 2) % n_ctx), eos_token)
        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
        bin = tf.reshape(bin, [n_ctx])
        bin = tf.cast(bin, dtype=tf.int32)
        output = tf.reshape(output, [n_ctx])
        output = tf.cast(output, dtype=tf.int32)
        return bin, output

    dataset = dataset.map(_get_output)
    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
    dataset = dataset.repeat()
    return dataset


task_descriptors = {
    'lambada': {
        'init_fn': lambada_init,
        'get_task_info_fn': lambada_get_task_info,
        'input_fn': lambada_input,
    }
}
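bin_pack greedily concatenates tokenized examples (each followed by eos_id) into rows of length n_ctx, pads the row count up to a multiple of eval_batch_size, and fills unused positions with the dummy token 1. A toy illustration, assuming bin_pack is imported from this module and using made-up params:

params = {'eos_id': 0, 'n_ctx': 8, 'eval_batch_size': 2}
tokens_data = [[5, 6], [7, 8, 9], [10, 11, 12, 13]]
print(bin_pack(params, tokens_data))
# [[ 5  6  0  7  8  9  0  1]
#  [10 11 12 13  0  1  1  1]]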
test_models.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
import traceback
|
3 |
+
import logging
|
4 |
+
from collections import defaultdict
|
5 |
+
from contextlib import contextmanager
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
tf.compat.v1.enable_eager_execution()
|
9 |
+
import mesh_tensorflow as mtf
|
10 |
+
from mesh_tensorflow import placement_mesh_impl
|
11 |
+
|
12 |
+
from inputs import mlm_sample_text
|
13 |
+
from models.gpt2 import gpt2
|
14 |
+
from models.utils import biasmask_attn_weights, entmax, sample_categorical
|
15 |
+
|
16 |
+
from sample import sample_autoregressive
|
17 |
+
|
18 |
+
# helper functions
|
19 |
+
|
20 |
+
@contextmanager
|
21 |
+
def not_raises(exception):
|
22 |
+
try:
|
23 |
+
yield
|
24 |
+
except exception:
|
25 |
+
logging.error(traceback.format_exc())
|
26 |
+
raise pytest.fail("DID RAISE {0}".format(exception))
|
27 |
+
|
28 |
+
# fixtures
|
29 |
+
|
30 |
+
params = defaultdict(lambda: None, {
|
31 |
+
"n_head": 1,
|
32 |
+
"n_ctx": 4,
|
33 |
+
"n_embd": 2,
|
34 |
+
"n_vocab": 256,
|
35 |
+
"embed_dropout": 0.,
|
36 |
+
"n_layer": 2,
|
37 |
+
"num_microbatches": 1,
|
38 |
+
"train_batch_size": 1,
|
39 |
+
"causal": True,
|
40 |
+
"attention_types": ['global', 'local'],
|
41 |
+
"res_dropout": 0.1,
|
42 |
+
"rotary_emb": True,
|
43 |
+
"activation_function": "gelu",
|
44 |
+
"moe_layers": (1,),
|
45 |
+
"num_mem_kv": 16,
|
46 |
+
"no_weight_tie": True,
|
47 |
+
"moe_params": {
|
48 |
+
'moe_dropout_rate': 0.0
|
49 |
+
},
|
50 |
+
"mesh_shape": [],
|
51 |
+
"layout": {},
|
52 |
+
"local_attention_radius": 128,
|
53 |
+
"share_parameters": True,
|
54 |
+
"rezero": True
|
55 |
+
})
|
56 |
+
|
57 |
+
# tests
|
58 |
+
|
59 |
+
def test_model():
|
60 |
+
graph = mtf.Graph()
|
61 |
+
mesh = mtf.Mesh(graph, "my_mesh")
|
62 |
+
|
63 |
+
seq_len = params["n_ctx"]
|
64 |
+
|
65 |
+
batch_dim = mtf.Dimension("batch", 1)
|
66 |
+
sequence_dim = mtf.Dimension("sequence", seq_len)
|
67 |
+
|
68 |
+
features = {
|
69 |
+
'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
|
70 |
+
'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
|
71 |
+
}
|
72 |
+
|
73 |
+
# create mask
|
74 |
+
|
75 |
+
num_mem_kv = params.get('num_mem_kv', 0)
|
76 |
+
length_dim = mtf.Dimension('sequence', seq_len)
|
77 |
+
memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
|
78 |
+
embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
|
79 |
+
embd_dim = mtf.Dimension("embd", params["n_embd"])
|
80 |
+
vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
|
81 |
+
|
82 |
+
other_features = {}
|
83 |
+
variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
|
84 |
+
|
85 |
+
other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
|
86 |
+
other_features["embd_dim"] = embd_dim
|
87 |
+
other_features["vocab_dim"] = vocab_dim
|
88 |
+
other_features["embed_sequence_dim"] = embed_sequence_dim
|
89 |
+
other_features["memory_length_dim"] = memory_length_dim
|
90 |
+
|
91 |
+
with not_raises(Exception):
|
92 |
+
logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)
|
93 |
+
|
94 |
+
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
|
95 |
+
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
|
96 |
+
logits = lowering.export_to_tf_tensor(logits)


def test_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask

    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)

# mlm

mlm_params = defaultdict(lambda: None, {
    "n_head": 1,
    "n_ctx": 4,
    "n_embd": 1,
    "n_vocab": 256,
    "embed_dropout": 0.,
    "n_layer": 2,
    "num_microbatches": 1,
    "train_batch_size": 1,
    "attention_types": ['global', 'local'],
    "res_dropout": 0.1,
    "mesh_shape": [],
    "layout": {},
    "share_parameters": True,
    "mlm_training": True,
    "mlm_mask_id": 3,
    "mlm_cls_token_id": 4,
    "mlm_random_token_prob": 0.1
})

def test_mlm_sample_text():
    document = tf.random.normal((16,))
    with not_raises(Exception):
        features, labels = mlm_sample_text(mlm_params, document, random_documents = True)
        assert features.shape == (mlm_params['n_ctx'],)

# entmax

def test_entmax():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    length = mtf.Dimension("tensor_length", 8)
    tensor = mtf.range(mesh, length, tf.float32)
    output = entmax(tensor)
    grad = mtf.gradients([output], [tensor])[0]
    sample = sample_categorical(output, length)

    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    sample = lowering.export_to_tf_tensor(sample)
    grad = lowering.export_to_tf_tensor(grad)
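    # Exporting both the sample and the gradient checks that entmax not only
    # builds in the forward direction but also has a gradient that can be
    # constructed and lowered.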
utils.py
ADDED
@@ -0,0 +1,291 @@
import re
from urllib.parse import urlparse
from shutil import rmtree
import logging
import os
from pathlib import Path
import sys
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
import mesh_tensorflow as mtf
from data.encoders import fetch_encoder


def setup_logging(args):
    Path("logs").mkdir(exist_ok=True)
    tf.logging.set_verbosity(logging.INFO)
    tf.get_logger().propagate = False  # Remove double log on console
    name = os.path.splitext(os.path.basename(args.model))[0]
    handlers = [
        logging.FileHandler(f"logs/{name}.log"),
        logging.StreamHandler(sys.stdout)
    ]
    logger = logging.getLogger("tensorflow")
    logger.handlers = handlers
    return logger


def get_batch_size(params):
    return params[f"{params['mode']}_batch_size"]
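# Illustrative example (hypothetical params): with
# params = {"mode": "train", "train_batch_size": 8} this returns 8 -- the batch
# size is looked up for whichever mode is currently active.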


def add_mode_to_params(params, mode):
    if mode == tf.estimator.ModeKeys.PREDICT:
        params["mode"] = "predict"
    elif mode == tf.estimator.ModeKeys.EVAL:
        params["mode"] = "eval"
    elif mode == tf.estimator.ModeKeys.TRAIN:
        params["mode"] = "train"
    else:
        raise ValueError(f"Invalid mode {mode}")
    return params


def simd_mesh_setup(params, mesh_shape, layout_rules):
    """Constructs SimdMesh function - instructions on how to evenly split tensors across all TPU cores"""

    num_hosts = params["context"].num_hosts
    host_placement_fn = params["context"].tpu_host_placement_function
    device_list = [host_placement_fn(host_id=i) for i in range(num_hosts)]
    tf.logging.info(f"device_list = {device_list}")

    # TODO: Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica

    # Worker 0 caches all the TPU binaries
    worker0_mem = replica_cache_size * params["context"].num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, params["context"].device_assignment)

    return var_placer, mesh_impl


def remove_batch_from_layout(layout):
    """
    The tf-mesh layout splits across batch size, remove it.
    Useful for prediction steps, when you no longer want large batches.

    :param layout: string describing tf-mesh layout
    :return: layout minus batch dimension
    """
    layout = layout.split(',')
    ret_layout = ""
    for i in layout:
        if "batch" in i:
            pass
        else:
            ret_layout += f"{i},"
    return ret_layout[:-1]
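# Illustrative example (made-up layout string):
#   remove_batch_from_layout("batch:x,heads:y") -> "heads:y"
# Every comma-separated entry mentioning "batch" is dropped and the rest rejoined.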


def yes_or_no(question):
    while True:
        reply = str(input(question+' (y/n): ')).lower().strip()
        if reply[:1] == 'y':
            return True
        if reply[:1] == 'n':
            return False


def remove_gs_or_filepath(path):
    parsed_url = urlparse(path)
    if parsed_url.scheme == "gs":
        os.system(f"gsutil rm -rf {path}")
        return
    rmtree(path)


def save_config(params_dict, logdir):
    print(f"Saving config to {logdir}")
    text = "{\n\n"
    total_params = len(params_dict)
    for count, key in enumerate(params_dict):
        config_value = str(params_dict[key])
        if re.search('[a-zA-Z]', config_value):
            if config_value.lower() != 'true':
                if config_value.lower() != 'false':
                    if config_value[0] != '[':
                        # TODO: Making a manual exception for parsing epsilon right now since it's the only number in
                        # scientific notation. Should fix this.
                        if key != "epsilon":
                            config_value = f'"{config_value}"'
        if count == total_params - 1:
            text += f'"{str(key)}"' + ' : ' + config_value + '\n\n'
        else:
            text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n'
    text += '\n\n}'
    sess = tf.InteractiveSession()
    summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
    summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
    text = sess.run(summary_op)
    summary_writer.add_summary(text, 0)
    summary_writer.flush()
    summary_writer.close()
    tf.reset_default_graph()
    print('Done!')
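# save_config serializes the params dict into JSON-ish text and writes it as a
# TensorBoard text summary under {logdir}/config, so a run's configuration can be
# inspected next to its training curves.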


def expand_attention_types_params(params_list):
    newlist = []
    for item in params_list:
        for _ in range(item[1]):
            newlist.extend(item[0])
    return newlist
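# Illustrative example: [[["global", "local"], 2]] expands to
# ["global", "local", "global", "local"] -- each [pattern, n] pair repeats its
# pattern n times, yielding one attention type per layer.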


def get_n_trainable_vars(graph):
    """
    Gets number of trainable vars in an MTF model.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    total_parameters = 0
    for variable in graph.trainable_variables:
        shape = variable.shape.dims
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.size
        total_parameters += variable_parameters
    print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")


def print_dim_names(graph):
    """
    Print names of all Dimensions
    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    all_dim_names = []
    for variable in graph.all_variables:
        names = variable.shape.dimension_names
        all_dim_names.append(names)

    # Print all unique dim names in the graph
    all_dim_names = [item for sublist in all_dim_names for item in sublist]  # Flatten all dims
    unique_dims = list(set(all_dim_names))
    print("ALL DIM NAMES:")
    for dim_name in unique_dims:
        print(dim_name)
    print('\n')


def get_graph_info(graph):
    """
    Wrapper fn that calculates the number of trainable vars in an MTF graph & prints all dim names
    TODO: how to get un-trainable dim-names too, batch etc.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    get_n_trainable_vars(graph)
    print_dim_names(graph)


def loss_denominator(targets, num_microbatches):
    """Denominator applied to losses.

    This is usually the size of the targets tensor (omitting ensemble
    dimensions). Alternatively, it is an override value passed to the
    class constructor.

    Args:
        targets: a mtf.Tensor
        num_microbatches: an integer - greater than one if the step has been
            serialized into multiple microbatches to save memory.
    Returns:
        a float
    """
    ret = float(targets.shape.size) * num_microbatches
    return float(ret)
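# Illustrative example: for a targets tensor of shape [batch=2, sequence=1024] and
# num_microbatches=2 the denominator is 2 * 1024 * 2 = 4096.0, since
# targets.shape.size is the total number of elements in the targets tensor.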


def check_dataset(input_fn, params, global_step=None):
    tf.enable_eager_execution()
    if global_step is not None:
        dataset = input_fn(params, global_step=global_step)
    else:
        dataset = input_fn(params)
    dataset_iter = dataset.make_one_shot_iterator()
    tensor, _ = next(dataset_iter)
    enc = fetch_encoder(params)

    for p in tensor[:1]:
        txt = enc.decode(p)

    print('-' * 50)
    print(txt[:500], '\n\n...\n\n', txt[-500:])
    print('-' * 50)
    exit()
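# check_dataset is a manual debugging helper: it eagerly pulls one batch from
# input_fn, decodes the first example back to text with the configured encoder,
# prints its head and tail, and then terminates the process.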


def auto_layout(graph, mesh_shape, logits, loss):
    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
    quit()


def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
    layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,
                                                                  [logits, loss], max_mesh_shape_dimensions=4)
    print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \
          f"\nRe-initialize graph with selected layout & mesh shape")
    quit()


def create_host_call(model_dir):
    """Construct a host_call writing scalar summaries.

    Borrowed from t2t.

    Args:
        model_dir: String containing path to the model directory
    Returns:
        (fn, args) Pair to be called by TPUEstimator as the host_call.
    """

    graph = tf.get_default_graph()
    # A list of (name, lowered tensor) tuples
    summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)

    def maybe_cast(tensor):
        assert tensor.shape.is_compatible_with([]), tensor.name
        if tensor.dtype == tf.int64:
            return tf.to_int32(tensor)
        if tensor.dtype == tf.bfloat16:
            return tf.cast(tensor, tf.float32)
        return tensor

    reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]

    # When no supported summaries are found, don't create host_call. Otherwise,
    # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
    # it, eventually causing hang.
    if not reshaped_tensors:
        return None

    def host_call_fn(global_step, *args):
        """Training host call. Creates scalar summaries for training metrics."""
        # This function is executed on the CPU and should not directly reference
        # any Tensors in the rest of the `model_fn`. To pass Tensors from the
        # model to the `model_fn`, provide as part of the `host_call`.
        global_step = tf.cast(global_step[0], tf.int64)
        with tf2.summary.create_file_writer(model_dir).as_default():
            # We cannot directly use any tensor from summaries, because each
            # tensor here must be a concat of multiple tensors from all shards.
            # Therefore, we rely on the assumption that args will have the same
            # length as summaries, and all tensors in args will have the same
            # order as summaries.
            assert len(args) == len(summaries)
            for i, tensor in enumerate(args):
                name = summaries[i][0]
                tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
        return tf.summary.all_v2_summary_ops()

    global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
    return host_call_fn, [global_step_t] + reshaped_tensors


def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
    return sorted(l, key = alphanum_key)
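# Illustrative example (hypothetical checkpoint names):
#   natural_sort(["ckpt-10", "ckpt-2", "ckpt-9"]) -> ["ckpt-2", "ckpt-9", "ckpt-10"]
# whereas plain sorted() would order "ckpt-10" before "ckpt-2" because it compares
# the digits as characters.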