Andrew DalPino
committed on
Commit · 3325763
1 Parent(s): 0cc4ecd
Use Fineweb instead of Openwebtext
Browse files:
- README.md +27 -12
- data.py +48 -32
- model_sizing.ipynb +27 -22
- pre-train.py +23 -9
README.md
CHANGED
@@ -14,19 +14,32 @@ tags:
|
|
14 |
---
|
15 |
# LightGPT
|
16 |
|
17 |
-
LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can
|
18 |
|
19 |
## Features
|
20 |
|
21 |
- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and output layer share weight matrices resulting in a buy-one-get-one-free deal on trainable parameters.
|
22 |
|
23 |
-
- **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious
|
24 |
|
25 |
- **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.
|
26 |
|
27 |
## Install Project Dependencies
|
28 |
|
29 |
-
Project dependencies are specified in the `requirements.txt` file. You can install them with [pip](https://pip.pypa.io/en/stable/) using the following command from the project root.
|
30 |
|
31 |
```
|
32 |
python -m venv ./.venv
|
@@ -38,7 +51,7 @@ pip install -r requirements.txt
|
|
38 |
|
39 |
## Pre-training
|
40 |
|
41 |
-
For the pre-training corpus we use the
|
42 |
|
43 |
```
|
44 |
python pre-train.py
|
@@ -88,26 +101,28 @@ Soon ...
|
|
88 |
|
89 |
| Argument | Default | Type | Description |
|
90 |
|---|---|---|---|
|
|
91 |
| --batch_size | 1 | int | The number of samples to pass through the network at a time. |
|
92 |
| --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
|
93 |
| --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
|
94 |
| --learning_rate | 5e-4 | float | The global step size taken after every gradient accumulation step. |
|
95 |
| --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
|
96 |
-
| --num_epochs |
|
97 |
| --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
|
98 |
| --block_size | 1024 | int | The number of tokens within the context window for every sample. |
|
99 |
| --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
|
100 |
| --num_attention_heads | 16 | int | The number of attention heads within every block. |
|
101 |
-
| --num_hidden_layers |
|
102 |
| --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
|
103 |
-
| --activation_checkpointing | False | bool | Should we use activation checkpointing? |
|
104 |
-
| --ddp_sharding_level | 2 |
|
105 |
| --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
|
106 |
-
| --checkpoint_path | "./out/checkpoint.pt" |
|
107 |
-
| --dataset_path | "./dataset" | string | The path to the dataset files on disk. |
|
108 |
-
| --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
|
109 |
| --resume | False | bool | Should we resume training from the last checkpoint? |
|
110 |
-
| --device | "cuda" |
|
111 |
| --seed | None | int | The seed for the random number generator. |
|
112 |
|
113 |
### Instruction-tuning Arguments
|
|
|
14 |
---
|
15 |
# LightGPT
|
16 |
|
17 |
+
LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can answer questions, summarize documents, chat, and more. A unique feature of LightGPT is that you can train larger models on smaller hardware by progressively enabling memory-saving features at train time, such as activation checkpointing, mixed precision, and ZeRO-redundancy distributed pre-training using fully-sharded data parallel (FSDP).
|
18 |
|
19 |
## Features
|
20 |
|
21 |
- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and output layer share weight matrices resulting in a buy-one-get-one-free deal on trainable parameters.
|
22 |
|
23 |
+
- **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious GPU memory. With zero-redundancy distributed pre-training using fully-sharded data parallel (FSDP), activation checkpointing, and automatic mixed precision, you'll be able to train larger models while accepting only a relatively small amount of runtime overhead (see the sketch after this list).
|
24 |
|
25 |
- **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.
|
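As a rough illustration of how these memory-saving features fit together in stock PyTorch, here is a minimal sketch using a stand-in module rather than LightGPT's own `GPT` class; the real wiring lives in `pre-train.py` and is controlled by its flags:

```
# Minimal sketch only: a toy module demonstrating activation checkpointing and
# automatic mixed precision; the model, sizes, and dtypes here are stand-ins.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Stand-in for a transformer block."""

    def __init__(self, dims: int):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(dims, 4 * dims), nn.GELU(), nn.Linear(4 * dims, dims)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activation checkpointing: recompute activations during the backward
        # pass instead of storing them, trading extra compute for less memory.
        return x + checkpoint(self.mlp, x, use_reentrant=False)


model = nn.Sequential(*(TinyBlock(64) for _ in range(4)))

x = torch.randn(8, 64)

# Automatic mixed precision: run the forward pass in a lower-precision dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).square().mean()

loss.backward()

# For ZeRO-redundancy sharding, the model would additionally be wrapped in
# torch.distributed.fsdp.FullyShardedDataParallel after init_process_group(),
# which is what the --ddp_sharding_level flag of pre-train.py selects.
```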
26 |
|
27 |
+
## Default Configurations
|
28 |
+
|
29 |
+
Below is a table of recommended default model training configurations, but feel free to experiment with settings on your own. See the `model_sizing.ipynb` notebook to estimate the memory and compute requirements for your model configuration.
|
30 |
+
|
31 |
+
| Name | Vocab. Size | Block Size | Embedding Dim. | Attn. Heads | Layers | Params | Train Tokens |
|
32 |
+
|---|---|---|---|---|---|---|---|
|
33 |
+
| Small | 50,257 | 1024 | 1024 | 16 | 32 | 454M | 10B |
|
34 |
+
| Medium | 50,257 | 1024 | 2048 | 32 | 32 | 1.7B | 20B |
|
35 |
+
| Large | 100,275 | 2048 | 4096 | 64 | 32 | 6.8B | 100B |
|
36 |
+
| X-large | 100,275 | 2048 | 4096 | 64 | 64 | 13B | 350B |
|
37 |
+
| XX-large | 200,017 | 4096 | 8192 | 128 | 64 | 53B | 1T |
|
38 |
+
| XXX-large | 200,017 | 4096 | 8192 | 128 | 128 | 105B | 3T |
|
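As a sanity check on the parameter counts above, a back-of-the-envelope estimate reproduces them closely. The formula below assumes tied input/output embeddings and no biases (as described under Features), plus a 4x MLP expansion with normalization weights ignored; those last two are assumptions rather than facts taken from `model.py`:

```
# Rough parameter-count estimate for the configuration table above.
# Assumptions: tied embeddings, no biases, 4x MLP expansion, norms ignored.
def approx_params(vocab_size: int, embed_dim: int, num_layers: int) -> int:
    embedding = vocab_size * embed_dim                 # shared with the output layer
    per_layer = 4 * embed_dim**2 + 8 * embed_dim**2    # attention (q, k, v, out) + MLP
    return embedding + num_layers * per_layer


print(f"{approx_params(50257, 1024, 32) / 1e6:.0f}M")  # ~454M (Small)
print(f"{approx_params(50257, 2048, 32) / 1e9:.1f}B")  # ~1.7B (Medium)
```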
39 |
+
|
40 |
## Install Project Dependencies
|
41 |
|
42 |
+
Project dependencies are specified in the `requirements.txt` file. You can install them with [pip](https://pip.pypa.io/en/stable/) using the following commands from the project root. We recommend using a virtual environment such as `venv` to keep package dependencies on your system tidy.
|
43 |
|
44 |
```
|
45 |
python -m venv ./.venv
|
|
|
51 |
|
52 |
## Pre-training
|
53 |
|
54 |
+
For the pre-training corpus we use the Fineweb dataset, which consists of about 15T high-quality tokens gathered from the web. The dataset has been split into three subsets (10BT, 100BT, and 350BT token versions) for training smaller models. If you'd like to start training right away, the default settings should work on most single-GPU systems with 12GB of VRAM or more.
|
55 |
|
56 |
```
|
57 |
python pre-train.py
|
|
|
101 |
|
102 |
| Argument | Default | Type | Description |
|
103 |
|---|---|---|---|
|
104 |
+
| --dataset_subset | "sample-10BT" | str | The subset of the Fineweb dataset to train on. Options are `sample-10BT`, `sample-100BT`, and `sample-350BT`. Set to `None` to train on the full 15T token dataset. |
|
105 |
+
| --token_encoding | "r50k_base" | str | The encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `cl100k_base`, and `o200k_base`. |
|
106 |
+
| --dataset_path | "./dataset" | str | The path to the preprocessed dataset files on disk. |
|
107 |
+
| --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
|
108 |
| --batch_size | 1 | int | The number of samples to pass through the network at a time. |
|
109 |
| --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
|
110 |
| --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
|
111 |
| --learning_rate | 5e-4 | float | The global step size taken after every gradient accumulation step. |
|
112 |
| --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
|
113 |
+
| --num_epochs | 2384 | int | The number of epochs to train for. |
|
114 |
| --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
|
115 |
| --block_size | 1024 | int | The number of tokens within the context window for every sample. |
|
116 |
| --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
|
117 |
| --num_attention_heads | 16 | int | The number of attention heads within every block. |
|
118 |
+
| --num_hidden_layers | 32 | int | The number of attention/MLP blocks within the hidden layer of the network. |
|
119 |
| --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
|
120 |
+
| --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization at the cost of about 30% more runtime per epoch. |
|
121 |
+
| --ddp_sharding_level | 2 | int | The level of sharding to use for DDP training. Options are 2 or 3 for partial and full sharding respectively, or 0 for no sharding. |
|
122 |
| --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
|
123 |
+
| --checkpoint_path | "./out/checkpoint.pt" | str | The path to the checkpoint file on disk. |
|
|
|
|
|
124 |
| --resume | False | bool | Should we resume training from the last checkpoint? |
|
125 |
+
| --device | "cuda" | str | The device to run the computation on. |
|
126 |
| --seed | None | int | The seed for the random number generator. |
|
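For example, a single-GPU run over the 10BT Fineweb sample with the GPT-2 tokenizer and activation checkpointing enabled might look like the following; the flag values are illustrative rather than prescriptive:

```
python pre-train.py --dataset_subset="sample-10BT" \
    --token_encoding="r50k_base" \
    --batch_size=1 \
    --gradient_accumulation_steps=128 \
    --activation_checkpointing
```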
127 |
|
128 |
### Instruction-tuning Arguments
|
data.py
CHANGED
@@ -18,56 +18,68 @@ from torch.nn.utils.rnn import pad_sequence
|
|
18 |
from tqdm import tqdm
|
19 |
|
20 |
|
21 |
-
class
|
22 |
-
DATASET_NAME = "
|
23 |
-
|
24 |
-
FILE_PREFIX = DATASET_NAME
|
25 |
-
|
26 |
-
TRAIN_FILENAME = f"{FILE_PREFIX}-train.bin"
|
27 |
-
TEST_FILENAME = f"{FILE_PREFIX}-test.bin"
|
28 |
|
29 |
TEST_SPLIT_PROPORTION = 0.005
|
30 |
NUM_SHARDS = 1024
|
31 |
|
32 |
-
ENCODING = "r50k_base"
|
33 |
-
|
34 |
PADDING_INDEX = -100
|
35 |
|
36 |
def __init__(
|
37 |
self,
|
38 |
-
root_path: str,
|
39 |
-
|
|
|
40 |
tokens_per_sample: int = 1024,
|
41 |
samples_per_epoch: int = 4096,
|
|
|
42 |
num_processes: int = 8,
|
43 |
):
|
44 |
super().__init__()
|
45 |
|
|
46 |
if tokens_per_sample < 1:
|
47 |
raise ValueError(f"Tokens per sample must be greater than 0.")
|
48 |
|
49 |
if samples_per_epoch < 1:
|
50 |
raise ValueError(f"Samples per epoch must be greater than 0.")
|
51 |
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
self.tokenizer = tiktoken.get_encoding(
|
56 |
|
57 |
if not path.exists(train_path) or not path.exists(test_path):
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
)
|
68 |
|
69 |
for split, dataset in tokenized_splits.items():
|
70 |
-
bin_path =
|
71 |
|
72 |
total_length = np.sum(dataset["length"], dtype=np.uint64)
|
73 |
|
@@ -92,9 +104,7 @@ class Openwebtext(IterableDataset):
|
|
92 |
|
93 |
bin_out.flush()
|
94 |
|
95 |
-
bin_file_path =
|
96 |
-
root_path, self.TRAIN_FILENAME if train else self.TEST_FILENAME
|
97 |
-
)
|
98 |
|
99 |
memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")
|
100 |
|
@@ -140,8 +150,6 @@ class Openwebtext(IterableDataset):
|
|
140 |
class Alpaca(Dataset):
|
141 |
DATASET_NAME = "tatsu-lab/alpaca"
|
142 |
|
143 |
-
ENCODING = "r50k_base"
|
144 |
-
|
145 |
PADDING_INDEX = -100
|
146 |
|
147 |
PROMPT_TEMPLATE = (
|
@@ -162,7 +170,12 @@ class Alpaca(Dataset):
|
|
162 |
|
163 |
RESPONSE_TEMPLATE = "{output}"
|
164 |
|
165 |
-
def __init__(
|
|
166 |
super().__init__()
|
167 |
|
168 |
if max_tokens_per_sample < 1:
|
@@ -170,9 +183,12 @@ class Alpaca(Dataset):
|
|
170 |
f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
|
171 |
)
|
172 |
|
173 |
-
|
|
|
174 |
|
175 |
-
self.tokenizer = tiktoken.get_encoding(
|
|
|
|
|
176 |
|
177 |
self.max_tokens_per_sample = max_tokens_per_sample
|
178 |
self.mask_input = mask_input
|
|
|
18 |
from tqdm import tqdm
|
19 |
|
20 |
|
21 |
+
class Fineweb(IterableDataset):
|
22 |
+
DATASET_NAME = "HuggingFaceFW/fineweb"
|
|
23 |
|
24 |
TEST_SPLIT_PROPORTION = 0.005
|
25 |
NUM_SHARDS = 1024
|
26 |
|
|
|
|
|
27 |
PADDING_INDEX = -100
|
28 |
|
29 |
def __init__(
|
30 |
self,
|
31 |
+
root_path: str = "./dataset",
|
32 |
+
subset: str | None = "sample-10BT",
|
33 |
+
split: str = "train",
|
34 |
tokens_per_sample: int = 1024,
|
35 |
samples_per_epoch: int = 4096,
|
36 |
+
token_encoding: str = "r50k_base",
|
37 |
num_processes: int = 8,
|
38 |
):
|
39 |
super().__init__()
|
40 |
|
41 |
+
if subset != None:
|
42 |
+
if subset not in ("sample-10BT", "sample-100BT", "sample-350BT"):
|
43 |
+
raise ValueError(f"Invalid subset, {subset} given.")
|
44 |
+
|
45 |
+
if split not in ("train", "test"):
|
46 |
+
raise ValueError(f"Split must be either train or test, {split} given.")
|
47 |
+
|
48 |
if tokens_per_sample < 1:
|
49 |
raise ValueError(f"Tokens per sample must be greater than 0.")
|
50 |
|
51 |
if samples_per_epoch < 1:
|
52 |
raise ValueError(f"Samples per epoch must be greater than 0.")
|
53 |
|
54 |
+
if token_encoding not in ("r50k_base", "cl100k_base", "o200k_base"):
|
55 |
+
raise ValueError(f"Invalid token encoding, {token_encoding} given.")
|
56 |
|
57 |
+
self.tokenizer = tiktoken.get_encoding(token_encoding)
|
58 |
+
|
59 |
+
dataset_name = f"fineweb-{subset}" if subset != None else "fineweb"
|
60 |
+
|
61 |
+
train_path = path.join(root_path, f"{dataset_name}-train-{token_encoding}.bin")
|
62 |
+
test_path = path.join(root_path, f"{dataset_name}-test-{token_encoding}.bin")
|
63 |
|
64 |
if not path.exists(train_path) or not path.exists(test_path):
|
65 |
+
dataset = load_dataset(
|
66 |
+
self.DATASET_NAME,
|
67 |
+
name=subset,
|
68 |
+
num_proc=num_processes,
|
69 |
+
split="train",
|
70 |
+
).map(
|
71 |
+
self.tokenize,
|
72 |
+
desc="Tokenizing",
|
73 |
+
remove_columns=["text"],
|
74 |
+
num_proc=num_processes,
|
75 |
+
)
|
76 |
+
|
77 |
+
tokenized_splits = dataset.train_test_split(
|
78 |
+
test_size=self.TEST_SPLIT_PROPORTION
|
79 |
)
|
80 |
|
81 |
for split, dataset in tokenized_splits.items():
|
82 |
+
bin_path = train_path if split == "train" else test_path
|
83 |
|
84 |
total_length = np.sum(dataset["length"], dtype=np.uint64)
|
85 |
|
|
|
104 |
|
105 |
bin_out.flush()
|
106 |
|
107 |
+
bin_file_path = train_path if split == "train" else test_path
|
|
|
|
|
108 |
|
109 |
memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")
|
110 |
|
|
|
150 |
class Alpaca(Dataset):
|
151 |
DATASET_NAME = "tatsu-lab/alpaca"
|
152 |
|
|
|
|
|
153 |
PADDING_INDEX = -100
|
154 |
|
155 |
PROMPT_TEMPLATE = (
|
|
|
170 |
|
171 |
RESPONSE_TEMPLATE = "{output}"
|
172 |
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
max_tokens_per_sample: int = 1024,
|
176 |
+
token_encoding: str = "r50k_base",
|
177 |
+
mask_input: bool = True,
|
178 |
+
):
|
179 |
super().__init__()
|
180 |
|
181 |
if max_tokens_per_sample < 1:
|
|
|
183 |
f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
|
184 |
)
|
185 |
|
186 |
+
if token_encoding not in ("r50k_base", "cl100k_base", "o200k_base"):
|
187 |
+
raise ValueError(f"Invalid token encoding, {token_encoding} given.")
|
188 |
|
189 |
+
self.tokenizer = tiktoken.get_encoding(token_encoding)
|
190 |
+
|
191 |
+
self.dataset = load_dataset(self.DATASET_NAME, split="train")
|
192 |
|
193 |
self.max_tokens_per_sample = max_tokens_per_sample
|
194 |
self.mask_input = mask_input
|
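A quick usage sketch of the new `Fineweb` dataset class follows; the constructor arguments match those introduced in this diff, while the `DataLoader` settings are illustrative:

```
# Illustrative only: build the Fineweb IterableDataset added in this commit and
# wrap it in a DataLoader, roughly mirroring what pre-train.py does.
from torch.utils.data import DataLoader

from data import Fineweb

training = Fineweb(
    root_path="./dataset",
    subset="sample-10BT",  # "sample-100BT", "sample-350BT", or None for the full corpus
    split="train",
    tokens_per_sample=1024,
    samples_per_epoch=4096,
    token_encoding="r50k_base",
    num_processes=8,
)  # the first construction tokenizes the corpus and caches .bin files under root_path

loader = DataLoader(training, batch_size=1)  # IterableDataset, so no shuffle argument

batch = next(iter(loader))
```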
model_sizing.ipynb
CHANGED
@@ -9,7 +9,7 @@
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
-
"execution_count":
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
@@ -17,7 +17,9 @@
|
|
17 |
"vocabulary_size = 50257\n",
|
18 |
"embedding_dimensions = 1024\n",
|
19 |
"num_attention_heads = 16\n",
|
20 |
-
"num_hidden_layers = 32"
|
|
21 |
]
|
22 |
},
|
23 |
{
|
@@ -29,7 +31,7 @@
|
|
29 |
},
|
30 |
{
|
31 |
"cell_type": "code",
|
32 |
-
"execution_count":
|
33 |
"metadata": {},
|
34 |
"outputs": [
|
35 |
{
|
@@ -95,14 +97,14 @@
|
|
95 |
},
|
96 |
{
|
97 |
"cell_type": "code",
|
98 |
-
"execution_count":
|
99 |
"metadata": {},
|
100 |
"outputs": [
|
101 |
{
|
102 |
"name": "stdout",
|
103 |
"output_type": "stream",
|
104 |
"text": [
|
105 |
-
"Total gigabytes: 1.
|
106 |
]
|
107 |
}
|
108 |
],
|
@@ -113,7 +115,7 @@
|
|
113 |
"\n",
|
114 |
"total_gigabytes = total_bytes / 1e9\n",
|
115 |
"\n",
|
116 |
-
"print(f\"Total gigabytes: {total_gigabytes:,.2f}\")"
|
117 |
]
|
118 |
},
|
119 |
{
|
@@ -125,7 +127,7 @@
|
|
125 |
},
|
126 |
{
|
127 |
"cell_type": "code",
|
128 |
-
"execution_count":
|
129 |
"metadata": {},
|
130 |
"outputs": [
|
131 |
{
|
@@ -220,7 +222,7 @@
|
|
220 |
},
|
221 |
{
|
222 |
"cell_type": "code",
|
223 |
-
"execution_count":
|
224 |
"metadata": {},
|
225 |
"outputs": [
|
226 |
{
|
@@ -246,7 +248,7 @@
|
|
246 |
},
|
247 |
{
|
248 |
"cell_type": "code",
|
249 |
-
"execution_count":
|
250 |
"metadata": {},
|
251 |
"outputs": [
|
252 |
{
|
@@ -272,7 +274,7 @@
|
|
272 |
},
|
273 |
{
|
274 |
"cell_type": "code",
|
275 |
-
"execution_count":
|
276 |
"metadata": {},
|
277 |
"outputs": [
|
278 |
{
|
@@ -300,7 +302,7 @@
|
|
300 |
},
|
301 |
{
|
302 |
"cell_type": "code",
|
303 |
-
"execution_count":
|
304 |
"metadata": {},
|
305 |
"outputs": [
|
306 |
{
|
@@ -308,8 +310,10 @@
|
|
308 |
"output_type": "stream",
|
309 |
"text": [
|
310 |
"RTX A2000 MFU: 17.29%\n",
|
|
|
311 |
"RTX 3090 MFU: 22.99%\n",
|
312 |
-
"A100 SXM MFU: 37.16%\n"
|
|
|
313 |
]
|
314 |
}
|
315 |
],
|
@@ -332,8 +336,10 @@
|
|
332 |
"\n",
|
333 |
"devices = [\n",
|
334 |
" Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
|
|
|
335 |
" Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
|
336 |
" Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
|
|
|
337 |
"]\n",
|
338 |
"\n",
|
339 |
"for device in devices:\n",
|
@@ -344,31 +350,30 @@
|
|
344 |
"cell_type": "markdown",
|
345 |
"metadata": {},
|
346 |
"source": [
|
347 |
-
"
|
348 |
]
|
349 |
},
|
350 |
{
|
351 |
"cell_type": "code",
|
352 |
-
"execution_count":
|
353 |
"metadata": {},
|
354 |
"outputs": [
|
355 |
{
|
356 |
"name": "stdout",
|
357 |
"output_type": "stream",
|
358 |
"text": [
|
359 |
-
"Total tokens:
|
360 |
-
"Epochs required: 2,
|
361 |
"\n",
|
362 |
-
"RTX A2000: 1187.25 seconds/epoch,
|
363 |
-
"RTX
|
364 |
-
"
|
|
|
|
|
365 |
]
|
366 |
}
|
367 |
],
|
368 |
"source": [
|
369 |
-
"num_training_tokens = 8994885755\n",
|
370 |
-
"samples_per_epoch = 4096\n",
|
371 |
-
"\n",
|
372 |
"num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
|
373 |
"\n",
|
374 |
"print(f\"Total tokens: {num_training_tokens:,}\")\n",
|
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
+
"execution_count": 252,
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
|
|
17 |
"vocabulary_size = 50257\n",
|
18 |
"embedding_dimensions = 1024\n",
|
19 |
"num_attention_heads = 16\n",
|
20 |
+
"num_hidden_layers = 32\n",
|
21 |
+
"num_training_tokens = 10e9\n",
|
22 |
+
"samples_per_epoch = 4096"
|
23 |
]
|
24 |
},
|
25 |
{
|
|
|
31 |
},
|
32 |
{
|
33 |
"cell_type": "code",
|
34 |
+
"execution_count": 253,
|
35 |
"metadata": {},
|
36 |
"outputs": [
|
37 |
{
|
|
|
97 |
},
|
98 |
{
|
99 |
"cell_type": "code",
|
100 |
+
"execution_count": 254,
|
101 |
"metadata": {},
|
102 |
"outputs": [
|
103 |
{
|
104 |
"name": "stdout",
|
105 |
"output_type": "stream",
|
106 |
"text": [
|
107 |
+
"Total gigabytes: 1.82G\n"
|
108 |
]
|
109 |
}
|
110 |
],
|
|
|
115 |
"\n",
|
116 |
"total_gigabytes = total_bytes / 1e9\n",
|
117 |
"\n",
|
118 |
+
"print(f\"Total gigabytes: {total_gigabytes:,.2f}G\")"
|
119 |
]
|
120 |
},
|
121 |
{
|
|
|
127 |
},
|
128 |
{
|
129 |
"cell_type": "code",
|
130 |
+
"execution_count": 255,
|
131 |
"metadata": {},
|
132 |
"outputs": [
|
133 |
{
|
|
|
222 |
},
|
223 |
{
|
224 |
"cell_type": "code",
|
225 |
+
"execution_count": 256,
|
226 |
"metadata": {},
|
227 |
"outputs": [
|
228 |
{
|
|
|
248 |
},
|
249 |
{
|
250 |
"cell_type": "code",
|
251 |
+
"execution_count": 257,
|
252 |
"metadata": {},
|
253 |
"outputs": [
|
254 |
{
|
|
|
274 |
},
|
275 |
{
|
276 |
"cell_type": "code",
|
277 |
+
"execution_count": 258,
|
278 |
"metadata": {},
|
279 |
"outputs": [
|
280 |
{
|
|
|
302 |
},
|
303 |
{
|
304 |
"cell_type": "code",
|
305 |
+
"execution_count": 259,
|
306 |
"metadata": {},
|
307 |
"outputs": [
|
308 |
{
|
|
|
310 |
"output_type": "stream",
|
311 |
"text": [
|
312 |
"RTX A2000 MFU: 17.29%\n",
|
313 |
+
"RTX A4000 MFU: 19.00%\n",
|
314 |
"RTX 3090 MFU: 22.99%\n",
|
315 |
+
"A100 SXM MFU: 37.16%\n",
|
316 |
+
"HGX A100 MFU: 37.16%\n"
|
317 |
]
|
318 |
}
|
319 |
],
|
|
|
336 |
"\n",
|
337 |
"devices = [\n",
|
338 |
" Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
|
339 |
+
" Device(\"RTX A4000\", 153.4e12, 9.1 * total_roundtrip_flops),\n",
|
340 |
" Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
|
341 |
" Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
|
342 |
+
" Device(\"HGX A100\", 4992e12, 579.2 * total_roundtrip_flops),\n",
|
343 |
"]\n",
|
344 |
"\n",
|
345 |
"for device in devices:\n",
|
|
|
350 |
"cell_type": "markdown",
|
351 |
"metadata": {},
|
352 |
"source": [
|
353 |
+
"Finally, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation. Note that these results shown here are a theoretical scenario and do not factor in additional overhead such as activation checkpointing or network latency."
|
354 |
]
|
355 |
},
|
356 |
{
|
357 |
"cell_type": "code",
|
358 |
+
"execution_count": 260,
|
359 |
"metadata": {},
|
360 |
"outputs": [
|
361 |
{
|
362 |
"name": "stdout",
|
363 |
"output_type": "stream",
|
364 |
"text": [
|
365 |
+
"Total tokens: 10,000,000,000.0\n",
|
366 |
+
"Epochs required: 2,384\n",
|
367 |
"\n",
|
368 |
+
"RTX A2000: 1187.25 seconds/epoch, 32.76 days required\n",
|
369 |
+
"RTX A4000: 450.11 seconds/epoch, 12.42 days required\n",
|
370 |
+
"RTX 3090: 199.80 seconds/epoch, 5.51 days required\n",
|
371 |
+
"A100 SXM: 56.57 seconds/epoch, 1.56 days required\n",
|
372 |
+
"HGX A100: 7.07 seconds/epoch, 0.20 days required\n"
|
373 |
]
|
374 |
}
|
375 |
],
|
376 |
"source": [
|
|
|
|
|
|
|
377 |
"num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
|
378 |
"\n",
|
379 |
"print(f\"Total tokens: {num_training_tokens:,}\")\n",
|
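The new `--num_epochs` default of 2384 in `pre-train.py` follows directly from the arithmetic in this notebook; a standalone check:

```
# Standalone check of the epochs-required estimate from the notebook above.
num_training_tokens = 10e9  # Fineweb sample-10BT
samples_per_epoch = 4096
block_size = 1024           # tokens per sample

tokens_per_epoch = samples_per_epoch * block_size  # 4,194,304

print(round(num_training_tokens / tokens_per_epoch))  # 2384
```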
pre-train.py
CHANGED
@@ -20,7 +20,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
|
|
20 |
from torchmetrics.text import Perplexity
|
21 |
|
22 |
from model import GPT
|
23 |
-
from data import
|
24 |
|
25 |
from tqdm import tqdm
|
26 |
|
@@ -38,25 +38,35 @@ DDP_BACKEND = "nccl"
|
|
38 |
def main():
|
39 |
parser = ArgumentParser(description="Pre-train the GPT.")
|
40 |
|
41 |
parser.add_argument("--batch_size", default=1, type=int)
|
42 |
parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
|
43 |
parser.add_argument("--samples_per_epoch", default=4096, type=int)
|
44 |
parser.add_argument("--learning_rate", default=1e-2, type=float)
|
45 |
parser.add_argument("--max_gradient_norm", default=1.0, type=float)
|
46 |
parser.add_argument("--dropout", default=0.1, type=float)
|
47 |
-
parser.add_argument("--num_epochs", default=
|
48 |
parser.add_argument("--block_size", default=1024, type=int)
|
49 |
parser.add_argument("--embedding_dimensions", default=1024, type=int)
|
50 |
parser.add_argument("--num_attention_heads", default=16, type=int)
|
51 |
parser.add_argument("--num_hidden_layers", default=32, type=int)
|
52 |
parser.add_argument("--activation_checkpointing", action="store_true")
|
53 |
-
parser.add_argument("--ddp_sharding_level", default=2, choices=
|
54 |
parser.add_argument("--eval_interval", default=10, type=int)
|
55 |
parser.add_argument("--checkpoint_interval", default=20, type=int)
|
56 |
parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
|
57 |
parser.add_argument("--resume", action="store_true")
|
58 |
-
parser.add_argument("--dataset_path", default="./dataset", type=str)
|
59 |
-
parser.add_argument("--num_dataset_processes", default=8, type=int)
|
60 |
parser.add_argument("--device", default="cuda", type=str)
|
61 |
parser.add_argument("--seed", default=None, type=int)
|
62 |
|
@@ -139,18 +149,22 @@ def main():
|
|
139 |
torch.manual_seed(args.seed)
|
140 |
random.seed(args.seed)
|
141 |
|
142 |
-
training =
|
143 |
root_path=args.dataset_path,
|
144 |
-
|
|
|
145 |
tokens_per_sample=args.block_size,
|
146 |
samples_per_epoch=args.samples_per_epoch,
|
|
|
147 |
num_processes=args.num_dataset_processes,
|
148 |
)
|
149 |
-
testing =
|
150 |
root_path=args.dataset_path,
|
151 |
-
|
|
|
152 |
tokens_per_sample=args.block_size,
|
153 |
samples_per_epoch=args.samples_per_epoch,
|
|
|
154 |
num_processes=args.num_dataset_processes,
|
155 |
)
|
156 |
|
|
|
20 |
from torchmetrics.text import Perplexity
|
21 |
|
22 |
from model import GPT
|
23 |
+
from data import Fineweb
|
24 |
|
25 |
from tqdm import tqdm
|
26 |
|
|
|
38 |
def main():
|
39 |
parser = ArgumentParser(description="Pre-train the GPT.")
|
40 |
|
41 |
+
parser.add_argument(
|
42 |
+
"--dataset_subset",
|
43 |
+
default="sample-10BT",
|
44 |
+
choices=("sample-10BT", "sample-100BT", "sample-350BT", None),
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--token_encoding",
|
48 |
+
default="r50k_base",
|
49 |
+
choices=("r50k_base", "cl100k_base", "o200k_base"),
|
50 |
+
)
|
51 |
+
parser.add_argument("--dataset_path", default="./dataset", type=str)
|
52 |
+
parser.add_argument("--num_dataset_processes", default=8, type=int)
|
53 |
parser.add_argument("--batch_size", default=1, type=int)
|
54 |
parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
|
55 |
parser.add_argument("--samples_per_epoch", default=4096, type=int)
|
56 |
parser.add_argument("--learning_rate", default=1e-2, type=float)
|
57 |
parser.add_argument("--max_gradient_norm", default=1.0, type=float)
|
58 |
parser.add_argument("--dropout", default=0.1, type=float)
|
59 |
+
parser.add_argument("--num_epochs", default=2384, type=int)
|
60 |
parser.add_argument("--block_size", default=1024, type=int)
|
61 |
parser.add_argument("--embedding_dimensions", default=1024, type=int)
|
62 |
parser.add_argument("--num_attention_heads", default=16, type=int)
|
63 |
parser.add_argument("--num_hidden_layers", default=32, type=int)
|
64 |
parser.add_argument("--activation_checkpointing", action="store_true")
|
65 |
+
parser.add_argument("--ddp_sharding_level", default=2, choices=(0, 2, 3))
|
66 |
parser.add_argument("--eval_interval", default=10, type=int)
|
67 |
parser.add_argument("--checkpoint_interval", default=20, type=int)
|
68 |
parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
|
69 |
parser.add_argument("--resume", action="store_true")
|
|
70 |
parser.add_argument("--device", default="cuda", type=str)
|
71 |
parser.add_argument("--seed", default=None, type=int)
|
72 |
|
|
|
149 |
torch.manual_seed(args.seed)
|
150 |
random.seed(args.seed)
|
151 |
|
152 |
+
training = Fineweb(
|
153 |
root_path=args.dataset_path,
|
154 |
+
subset=args.dataset_subset,
|
155 |
+
split="train",
|
156 |
tokens_per_sample=args.block_size,
|
157 |
samples_per_epoch=args.samples_per_epoch,
|
158 |
+
token_encoding=args.token_encoding,
|
159 |
num_processes=args.num_dataset_processes,
|
160 |
)
|
161 |
+
testing = Fineweb(
|
162 |
root_path=args.dataset_path,
|
163 |
+
subset=args.dataset_subset,
|
164 |
+
split="test",
|
165 |
tokens_per_sample=args.block_size,
|
166 |
samples_per_epoch=args.samples_per_epoch,
|
167 |
+
token_encoding=args.token_encoding,
|
168 |
num_processes=args.num_dataset_processes,
|
169 |
)
|
170 |
|
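The body of `main()` that consumes `--ddp_sharding_level` is not shown in this diff. A plausible mapping onto the `ShardingStrategy` enum already imported at the top of `pre-train.py` might look like the following hypothetical sketch, which is not necessarily the repository's actual implementation:

```
# Hypothetical mapping of --ddp_sharding_level onto FSDP sharding strategies.
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy


def sharding_strategy(level: int) -> ShardingStrategy:
    if level == 3:
        return ShardingStrategy.FULL_SHARD     # shard params, gradients, and optimizer state
    if level == 2:
        return ShardingStrategy.SHARD_GRAD_OP  # shard gradients and optimizer state only
    return ShardingStrategy.NO_SHARD           # level 0: no sharding

# e.g. model = FullyShardedDataParallel(model, sharding_strategy=sharding_strategy(args.ddp_sharding_level))
```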