Andrew DalPino committed on
Commit 3325763 · 1 Parent(s): 0cc4ecd

Use Fineweb instead of Openwebtext

Files changed (4)
  1. README.md +27 -12
  2. data.py +48 -32
  3. model_sizing.ipynb +27 -22
  4. pre-train.py +23 -9
README.md CHANGED
@@ -14,19 +14,32 @@ tags:
14
  ---
15
  # LightGPT
16
 
17
- LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can generate text, answer questions, summarize documents, and more. A unique feature of LightGPT is that it allows you to train larger models on smaller hardware by taking advantage of memory optimizations wherever possible.
18
 
19
  ## Features
20
 
21
  - **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and output layer share weight matrices resulting in a buy-one-get-one-free deal on trainable parameters.
22
 
23
- - **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious VRAM. With zero-redundancy distributed pre-training using fully-sharded data-parallel (FSDP), activation checkpointing, and automatic mixed precision, you'll be able to train larger models by accepting a relatively small amount of communication and computational overhead.
24
 
25
  - **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.
26
27
  ## Install Project Dependencies
28
 
29
- Project dependencies are specified in the `requirements.txt` file. You can install them with [pip](https://pip.pypa.io/en/stable/) using the following command from the project root. I recommend using a virtual environment such as venv to keep package dependencies on your system tidy.
30
 
31
  ```
32
  python -m venv ./.venv
@@ -38,7 +51,7 @@ pip install -r requirements.txt
38
 
39
  ## Pre-training
40
 
41
- For the pre-training corpus we use the Openwebtext dataset which consists of about 9B high-quality token sequences gathered from the worldwide web. In addition, you can add as much pre-training data as you like with a custom dataloader. If you'd just like to start training right away, the default settings should work on most single-GPU systems with 12G of VRAM or more.
42
 
43
  ```
44
  python pre-train.py
@@ -88,26 +101,28 @@ Soon ...
88
 
89
  | Argument | Default | Type | Description |
90
|---|---|---|---|
91
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
92
  | --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
93
  | --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
94
  | --learning_rate | 5e-4 | float | The global step size taken after every gradient accumulation step. |
95
  | --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
96
- | --num_epochs | 2145 | int | The number of epochs to train for. |
97
  | --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
98
  | --block_size | 1024 | int | The number of tokens within the context window for every sample. |
99
  | --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
100
  | --num_attention_heads | 16 | int | The number of attention heads within every block. |
101
- | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
102
  | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
103
- | --activation_checkpointing | False | bool | Should we use activation checkpointing? |
104
- | --ddp_sharding_level | 2 | (0, 2, 3) | int | The level of sharding to use for DDP training. |
105
  | --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
106
- | --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
107
- | --dataset_path | "./dataset" | string | The path to the dataset files on disk. |
108
- | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
109
  | --resume | False | bool | Should we resume training from the last checkpoint? |
110
- | --device | "cuda" | string | The device to run the computation on. |
111
  | --seed | None | int | The seed for the random number generator. |
112
 
113
  ### Instruction-tuning Arguments
 
14
  ---
15
  # LightGPT
16
 
17
+ LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can answer questions, summarize documents, chat, and more. A unique feature of LightGPT is that you can train larger models on smaller hardware by progressively enabling memory-saving features at train time, such as activation checkpointing, mixed precision, and zero-redundancy distributed pre-training with fully-sharded data parallel (FSDP).
18
 
19
  ## Features
20
 
21
  - **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and output layer share weight matrices resulting in a buy-one-get-one-free deal on trainable parameters.
22
 
23
+ - **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious GPU memory. With zero-redundancy distributed pre-training using fully-sharded data-parallel (FSDP), activation checkpointing, and automatic mixed precision, you'll be able to train larger models by accepting a relatively small amount of overhead (see the sketch after this list).
24
 
25
  - **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.
26
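The memory-saving combination named in the feature above maps onto standard PyTorch APIs. The snippet below is only a minimal sketch of that combination, not the project's actual wiring in `pre-train.py`; it assumes a process group has already been initialized (for example via `torchrun`), and the `wrap_for_low_memory` helper is a hypothetical name used here for illustration.

```
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy


def wrap_for_low_memory(model: torch.nn.Module) -> FullyShardedDataParallel:
    # Shard gradients and optimizer state across ranks (ZeRO-style redundancy removal).
    assert dist.is_initialized(), "call dist.init_process_group() first"

    return FullyShardedDataParallel(
        model,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    )


# Forward and backward passes then run under automatic mixed precision:
#
#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#       loss = model(inputs, targets)
```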
 
27
+ ## Default Configurations
28
+
29
+ Below is a table of recommended default model training configurations, but feel free to experiment with settings on your own. See the `model_sizing.ipynb` notebook to estimate the memory and compute requirements for your model configuration; a rough parameter-count check is also sketched after the table.
30
+
31
+ | Name | Vocab. Size | Block Size | Embedding Dim. | Attn. Heads | Layers | Params | Train Tokens |
32
+ |---|---|---|---|---|---|---|---|
33
+ | Small | 50,257 | 1024 | 1024 | 16 | 32 | 454M | 10B |
34
+ | Medium | 50,257 | 1024 | 2048 | 32 | 32 | 1.7B | 20B |
35
+ | Large | 100,275 | 2048 | 4096 | 64 | 32 | 6.8B | 100B |
36
+ | X-large | 100,275 | 2048 | 4096 | 64 | 64 | 13B | 350B |
37
+ | XX-large | 200,017 | 4096 | 8192 | 128 | 64 | 53B | 1T |
38
+ | XXX-large | 200,017 | 4096 | 8192 | 128 | 128 | 105B | 3T |
39
+
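The parameter counts above can be cross-checked with a quick back-of-the-envelope estimate. The sketch below assumes the architecture described in the Features section (tied input/output embeddings, no biases, no positional embeddings) plus the usual GPT block layout of 4·d² attention weights and an 8·d² feed-forward block per layer, and it ignores the small normalization weights; `model_sizing.ipynb` does the full accounting.

```
def approx_params(vocabulary_size: int, embedding_dimensions: int, num_hidden_layers: int) -> int:
    embedding = vocabulary_size * embedding_dimensions  # shared with the output layer
    per_block = 12 * embedding_dimensions**2  # 4d^2 attention + 8d^2 MLP

    return embedding + num_hidden_layers * per_block


print(f"{approx_params(50_257, 1024, 32) / 1e6:.0f}M")   # ~454M, the "Small" configuration
print(f"{approx_params(200_017, 8192, 128) / 1e9:.0f}B")  # ~105B, the "XXX-large" configuration
```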
40
  ## Install Project Dependencies
41
 
42
+ Project dependencies are specified in the `requirements.txt` file. You can install them with [pip](https://pip.pypa.io/en/stable/) using the following command from the project root. We recommend using a virtual environment such as `venv` to keep package dependencies on your system tidy.
43
 
44
  ```
45
  python -m venv ./.venv
 
51
 
52
  ## Pre-training
53
 
54
+ For the pre-training corpus we use the Fineweb dataset, which consists of about 15T high-quality tokens gathered from the web. The dataset also comes in three smaller subsets (10BT, 100BT, and 350BT token samples) for training smaller models. If you'd like to start training right away, the default settings should work on most single-GPU systems with 12GB of VRAM or more.
55
 
56
  ```
57
  python pre-train.py
 
101
 
102
  | Argument | Default | Type | Description |
103
  |---|---|---|---|
104
+ | --dataset_subset | "sample-10BT" | str | The subset of the Fineweb dataset to train on. Options are `sample-10BT`, `sample-100BT`, and `sample-350BT`. Set to `None` to train on the full 15T token dataset. |
105
+ | --token_encoding | "r50k_base" | str | The encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `cl100k_base`, and `o200k_base`. |
106
+ | --dataset_path | "./dataset" | str | The path to the preprocessed dataset files on disk. |
107
+ | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
108
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
109
  | --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
110
  | --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
111
  | --learning_rate | 5e-4 | float | The global step size taken after every gradient accumulation step. |
112
  | --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
113
+ | --num_epochs | 2384 | int | The number of epochs to train for. |
114
  | --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
115
  | --block_size | 1024 | int | The number of tokens within the context window for every sample. |
116
  | --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
117
  | --num_attention_heads | 16 | int | The number of attention heads within every block. |
118
+ | --num_hidden_layers | 32 | int | The number of attention/MLP blocks within the hidden layer of the network. |
119
  | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
120
+ | --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization at the cost of about 30% more runtime per epoch. |
121
+ | --ddp_sharding_level | 2 | int | The level of sharding to use for DDP training. Options are 2 or 3 for partial and full sharding respectively, or 0 for no sharding. |
122
  | --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
123
+ | --checkpoint_path | "./out/checkpoint.pt" | str | The path to the checkpoint file on disk. |
124
  | --resume | False | bool | Should we resume training from the last checkpoint? |
125
+ | --device | "cuda" | str | The device to run the computation on. |
126
  | --seed | None | int | The seed for the random number generator. |
127
 
128
  ### Instruction-tuning Arguments
data.py CHANGED
@@ -18,56 +18,68 @@ from torch.nn.utils.rnn import pad_sequence
18
  from tqdm import tqdm
19
 
20
 
21
- class Openwebtext(IterableDataset):
22
- DATASET_NAME = "openwebtext"
23
-
24
- FILE_PREFIX = DATASET_NAME
25
-
26
- TRAIN_FILENAME = f"{FILE_PREFIX}-train.bin"
27
- TEST_FILENAME = f"{FILE_PREFIX}-test.bin"
28
 
29
  TEST_SPLIT_PROPORTION = 0.005
30
  NUM_SHARDS = 1024
31
 
32
- ENCODING = "r50k_base"
33
-
34
  PADDING_INDEX = -100
35
 
36
  def __init__(
37
  self,
38
- root_path: str,
39
- train: bool = True,
 
40
  tokens_per_sample: int = 1024,
41
  samples_per_epoch: int = 4096,
 
42
  num_processes: int = 8,
43
  ):
44
  super().__init__()
45
46
  if tokens_per_sample < 1:
47
  raise ValueError(f"Tokens per sample must be greater than 0.")
48
 
49
  if samples_per_epoch < 1:
50
  raise ValueError(f"Samples per epoch must be greater than 0.")
51
 
52
- train_path = path.join(root_path, self.TRAIN_FILENAME)
53
- test_path = path.join(root_path, self.TEST_FILENAME)
54
 
55
- self.tokenizer = tiktoken.get_encoding(self.ENCODING)
56
 
57
  if not path.exists(train_path) or not path.exists(test_path):
58
- tokenized_splits = (
59
- load_dataset(self.DATASET_NAME, num_proc=num_processes, split="train")
60
- .train_test_split(test_size=self.TEST_SPLIT_PROPORTION, shuffle=True)
61
- .map(
62
- self.tokenize,
63
- desc="Tokenizing",
64
- remove_columns=["text"],
65
- num_proc=num_processes,
66
- )
67
  )
68
 
69
  for split, dataset in tokenized_splits.items():
70
- bin_path = path.join(root_path, f"{self.FILE_PREFIX}-{split}.bin")
71
 
72
  total_length = np.sum(dataset["length"], dtype=np.uint64)
73
 
@@ -92,9 +104,7 @@ class Openwebtext(IterableDataset):
92
 
93
  bin_out.flush()
94
 
95
- bin_file_path = path.join(
96
- root_path, self.TRAIN_FILENAME if train else self.TEST_FILENAME
97
- )
98
 
99
  memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")
100
 
@@ -140,8 +150,6 @@ class Openwebtext(IterableDataset):
140
  class Alpaca(Dataset):
141
  DATASET_NAME = "tatsu-lab/alpaca"
142
 
143
- ENCODING = "r50k_base"
144
-
145
  PADDING_INDEX = -100
146
 
147
  PROMPT_TEMPLATE = (
@@ -162,7 +170,12 @@ class Alpaca(Dataset):
162
 
163
  RESPONSE_TEMPLATE = "{output}"
164
 
165
- def __init__(self, max_tokens_per_sample: int = 1024, mask_input: bool = True):
166
  super().__init__()
167
 
168
  if max_tokens_per_sample < 1:
@@ -170,9 +183,12 @@ class Alpaca(Dataset):
170
  f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
171
  )
172
 
173
- self.dataset = load_dataset(self.DATASET_NAME, split="train")
 
174
 
175
- self.tokenizer = tiktoken.get_encoding(self.ENCODING)
 
 
176
 
177
  self.max_tokens_per_sample = max_tokens_per_sample
178
  self.mask_input = mask_input
 
18
  from tqdm import tqdm
19
 
20
 
21
+ class Fineweb(IterableDataset):
22
+ DATASET_NAME = "HuggingFaceFW/fineweb"
23
 
24
  TEST_SPLIT_PROPORTION = 0.005
25
  NUM_SHARDS = 1024
26
 
 
 
27
  PADDING_INDEX = -100
28
 
29
  def __init__(
30
  self,
31
+ root_path: str = "./dataset",
32
+ subset: str | None = "sample-10BT",
33
+ split: str = "train",
34
  tokens_per_sample: int = 1024,
35
  samples_per_epoch: int = 4096,
36
+ token_encoding: str = "r50k_base",
37
  num_processes: int = 8,
38
  ):
39
  super().__init__()
40
 
41
+ if subset is not None:
42
+ if subset not in ("sample-10BT", "sample-100BT", "sample-350BT"):
43
+ raise ValueError(f"Invalid subset, {subset} given.")
44
+
45
+ if split not in ("train", "test"):
46
+ raise ValueError(f"Split must be either train or test, {split} given.")
47
+
48
  if tokens_per_sample < 1:
49
  raise ValueError(f"Tokens per sample must be greater than 0.")
50
 
51
  if samples_per_epoch < 1:
52
  raise ValueError(f"Samples per epoch must be greater than 0.")
53
 
54
+ if token_encoding not in ("r50k_base", "cl100k_base", "o200k_base"):
55
+ raise ValueError(f"Invalid token encoding, {token_encoding} given.")
56
 
57
+ self.tokenizer = tiktoken.get_encoding(token_encoding)
58
+
59
+ dataset_name = f"fineweb-{subset}" if subset is not None else "fineweb"
60
+
61
+ train_path = path.join(root_path, f"{dataset_name}-train-{token_encoding}.bin")
62
+ test_path = path.join(root_path, f"{dataset_name}-test-{token_encoding}.bin")
63
 
64
  if not path.exists(train_path) or not path.exists(test_path):
65
+ dataset = load_dataset(
66
+ self.DATASET_NAME,
67
+ name=subset,
68
+ num_proc=num_processes,
69
+ split="train",
70
+ ).map(
71
+ self.tokenize,
72
+ desc="Tokenizing",
73
+ remove_columns=["text"],
74
+ num_proc=num_processes,
75
+ )
76
+
77
+ tokenized_splits = dataset.train_test_split(
78
+ test_size=self.TEST_SPLIT_PROPORTION
79
  )
80
 
81
  for split, dataset in tokenized_splits.items():
82
+ bin_path = train_path if split == "train" else test_path
83
 
84
  total_length = np.sum(dataset["length"], dtype=np.uint64)
85
 
 
104
 
105
  bin_out.flush()
106
 
107
+ bin_file_path = train_path if split == "train" else test_path
 
 
108
 
109
  memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")
110
 
 
150
  class Alpaca(Dataset):
151
  DATASET_NAME = "tatsu-lab/alpaca"
152
 
 
 
153
  PADDING_INDEX = -100
154
 
155
  PROMPT_TEMPLATE = (
 
170
 
171
  RESPONSE_TEMPLATE = "{output}"
172
 
173
+ def __init__(
174
+ self,
175
+ max_tokens_per_sample: int = 1024,
176
+ token_encoding: str = "r50k_base",
177
+ mask_input: bool = True,
178
+ ):
179
  super().__init__()
180
 
181
  if max_tokens_per_sample < 1:
 
183
  f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
184
  )
185
 
186
+ if token_encoding not in ("r50k_base", "cl100k_base", "o200k_base"):
187
+ raise ValueError(f"Invalid token encoding, {token_encoding} given.")
188
 
189
+ self.tokenizer = tiktoken.get_encoding(token_encoding)
190
+
191
+ self.dataset = load_dataset(self.DATASET_NAME, split="train")
192
 
193
  self.max_tokens_per_sample = max_tokens_per_sample
194
  self.mask_input = mask_input
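For orientation, here is a minimal sketch of how the new `Fineweb` dataset might be consumed, using only the constructor arguments visible in the diff above; wrapping it in a `DataLoader` is an assumption about typical usage rather than code from this commit.

```
from torch.utils.data import DataLoader

from data import Fineweb

training = Fineweb(
    root_path="./dataset",
    subset="sample-10BT",  # or "sample-100BT" / "sample-350BT", or None for the full corpus
    split="train",
    tokens_per_sample=1024,
    samples_per_epoch=4096,
    token_encoding="r50k_base",
    num_processes=8,
)

# Fineweb is an IterableDataset, so the DataLoader streams samples without shuffling.
loader = DataLoader(training, batch_size=1)
```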
model_sizing.ipynb CHANGED
@@ -9,7 +9,7 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": 35,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
@@ -17,7 +17,9 @@
17
  "vocabulary_size = 50257\n",
18
  "embedding_dimensions = 1024\n",
19
  "num_attention_heads = 16\n",
20
- "num_hidden_layers = 32"
 
 
21
  ]
22
  },
23
  {
@@ -29,7 +31,7 @@
29
  },
30
  {
31
  "cell_type": "code",
32
- "execution_count": 36,
33
  "metadata": {},
34
  "outputs": [
35
  {
@@ -95,14 +97,14 @@
95
  },
96
  {
97
  "cell_type": "code",
98
- "execution_count": 37,
99
  "metadata": {},
100
  "outputs": [
101
  {
102
  "name": "stdout",
103
  "output_type": "stream",
104
  "text": [
105
- "Total gigabytes: 1.82\n"
106
  ]
107
  }
108
  ],
@@ -113,7 +115,7 @@
113
  "\n",
114
  "total_gigabytes = total_bytes / 1e9\n",
115
  "\n",
116
- "print(f\"Total gigabytes: {total_gigabytes:,.2f}\")"
117
  ]
118
  },
119
  {
@@ -125,7 +127,7 @@
125
  },
126
  {
127
  "cell_type": "code",
128
- "execution_count": 54,
129
  "metadata": {},
130
  "outputs": [
131
  {
@@ -220,7 +222,7 @@
220
  },
221
  {
222
  "cell_type": "code",
223
- "execution_count": 55,
224
  "metadata": {},
225
  "outputs": [
226
  {
@@ -246,7 +248,7 @@
246
  },
247
  {
248
  "cell_type": "code",
249
- "execution_count": 56,
250
  "metadata": {},
251
  "outputs": [
252
  {
@@ -272,7 +274,7 @@
272
  },
273
  {
274
  "cell_type": "code",
275
- "execution_count": 65,
276
  "metadata": {},
277
  "outputs": [
278
  {
@@ -300,7 +302,7 @@
300
  },
301
  {
302
  "cell_type": "code",
303
- "execution_count": 96,
304
  "metadata": {},
305
  "outputs": [
306
  {
@@ -308,8 +310,10 @@
308
  "output_type": "stream",
309
  "text": [
310
  "RTX A2000 MFU: 17.29%\n",
 
311
  "RTX 3090 MFU: 22.99%\n",
312
- "A100 SXM MFU: 37.16%\n"
 
313
  ]
314
  }
315
  ],
@@ -332,8 +336,10 @@
332
  "\n",
333
  "devices = [\n",
334
  " Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
 
335
  " Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
336
  " Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
 
337
  "]\n",
338
  "\n",
339
  "for device in devices:\n",
@@ -344,31 +350,30 @@
344
  "cell_type": "markdown",
345
  "metadata": {},
346
  "source": [
347
- "Now, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation. Note that these results shown here are a best-case scenario and neglect to factor in other overhead."
348
  ]
349
  },
350
  {
351
  "cell_type": "code",
352
- "execution_count": 97,
353
  "metadata": {},
354
  "outputs": [
355
  {
356
  "name": "stdout",
357
  "output_type": "stream",
358
  "text": [
359
- "Total tokens: 8,994,885,755\n",
360
- "Epochs required: 2,145\n",
361
  "\n",
362
- "RTX A2000: 1187.25 seconds/epoch, 29.48 days required\n",
363
- "RTX 3090: 199.80 seconds/epoch, 4.96 days required\n",
364
- "A100 SXM: 56.57 seconds/epoch, 1.40 days required\n"
 
 
365
  ]
366
  }
367
  ],
368
  "source": [
369
- "num_training_tokens = 8994885755\n",
370
- "samples_per_epoch = 4096\n",
371
- "\n",
372
  "num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
373
  "\n",
374
  "print(f\"Total tokens: {num_training_tokens:,}\")\n",
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 252,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
 
17
  "vocabulary_size = 50257\n",
18
  "embedding_dimensions = 1024\n",
19
  "num_attention_heads = 16\n",
20
+ "num_hidden_layers = 32\n",
21
+ "num_training_tokens = 10e9\n",
22
+ "samples_per_epoch = 4096"
23
  ]
24
  },
25
  {
 
31
  },
32
  {
33
  "cell_type": "code",
34
+ "execution_count": 253,
35
  "metadata": {},
36
  "outputs": [
37
  {
 
97
  },
98
  {
99
  "cell_type": "code",
100
+ "execution_count": 254,
101
  "metadata": {},
102
  "outputs": [
103
  {
104
  "name": "stdout",
105
  "output_type": "stream",
106
  "text": [
107
+ "Total gigabytes: 1.82G\n"
108
  ]
109
  }
110
  ],
 
115
  "\n",
116
  "total_gigabytes = total_bytes / 1e9\n",
117
  "\n",
118
+ "print(f\"Total gigabytes: {total_gigabytes:,.2f}G\")"
119
  ]
120
  },
121
  {
 
127
  },
128
  {
129
  "cell_type": "code",
130
+ "execution_count": 255,
131
  "metadata": {},
132
  "outputs": [
133
  {
 
222
  },
223
  {
224
  "cell_type": "code",
225
+ "execution_count": 256,
226
  "metadata": {},
227
  "outputs": [
228
  {
 
248
  },
249
  {
250
  "cell_type": "code",
251
+ "execution_count": 257,
252
  "metadata": {},
253
  "outputs": [
254
  {
 
274
  },
275
  {
276
  "cell_type": "code",
277
+ "execution_count": 258,
278
  "metadata": {},
279
  "outputs": [
280
  {
 
302
  },
303
  {
304
  "cell_type": "code",
305
+ "execution_count": 259,
306
  "metadata": {},
307
  "outputs": [
308
  {
 
310
  "output_type": "stream",
311
  "text": [
312
  "RTX A2000 MFU: 17.29%\n",
313
+ "RTX A4000 MFU: 19.00%\n",
314
  "RTX 3090 MFU: 22.99%\n",
315
+ "A100 SXM MFU: 37.16%\n",
316
+ "HGX A100 MFU: 37.16%\n"
317
  ]
318
  }
319
  ],
 
336
  "\n",
337
  "devices = [\n",
338
  " Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
339
+ " Device(\"RTX A4000\", 153.4e12, 9.1 * total_roundtrip_flops),\n",
340
  " Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
341
  " Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
342
+ " Device(\"HGX A100\", 4992e12, 579.2 * total_roundtrip_flops),\n",
343
  "]\n",
344
  "\n",
345
  "for device in devices:\n",
 
350
  "cell_type": "markdown",
351
  "metadata": {},
352
  "source": [
353
+ "Finally, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation. Note that these results shown here are a theoretical scenario and do not factor in additional overhead such as activation checkpointing or network latency."
354
  ]
355
  },
356
  {
357
  "cell_type": "code",
358
+ "execution_count": 260,
359
  "metadata": {},
360
  "outputs": [
361
  {
362
  "name": "stdout",
363
  "output_type": "stream",
364
  "text": [
365
+ "Total tokens: 10,000,000,000.0\n",
366
+ "Epochs required: 2,384\n",
367
  "\n",
368
+ "RTX A2000: 1187.25 seconds/epoch, 32.76 days required\n",
369
+ "RTX A4000: 450.11 seconds/epoch, 12.42 days required\n",
370
+ "RTX 3090: 199.80 seconds/epoch, 5.51 days required\n",
371
+ "A100 SXM: 56.57 seconds/epoch, 1.56 days required\n",
372
+ "HGX A100: 7.07 seconds/epoch, 0.20 days required\n"
373
  ]
374
  }
375
  ],
376
  "source": [
 
 
 
377
  "num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
378
  "\n",
379
  "print(f\"Total tokens: {num_training_tokens:,}\")\n",
pre-train.py CHANGED
@@ -20,7 +20,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
20
  from torchmetrics.text import Perplexity
21
 
22
  from model import GPT
23
- from data import Openwebtext
24
 
25
  from tqdm import tqdm
26
 
@@ -38,25 +38,35 @@ DDP_BACKEND = "nccl"
38
  def main():
39
  parser = ArgumentParser(description="Pre-train the GPT.")
40
41
  parser.add_argument("--batch_size", default=1, type=int)
42
  parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
43
  parser.add_argument("--samples_per_epoch", default=4096, type=int)
44
  parser.add_argument("--learning_rate", default=1e-2, type=float)
45
  parser.add_argument("--max_gradient_norm", default=1.0, type=float)
46
  parser.add_argument("--dropout", default=0.1, type=float)
47
- parser.add_argument("--num_epochs", default=2140, type=int)
48
  parser.add_argument("--block_size", default=1024, type=int)
49
  parser.add_argument("--embedding_dimensions", default=1024, type=int)
50
  parser.add_argument("--num_attention_heads", default=16, type=int)
51
  parser.add_argument("--num_hidden_layers", default=32, type=int)
52
  parser.add_argument("--activation_checkpointing", action="store_true")
53
- parser.add_argument("--ddp_sharding_level", default=2, choices=[0, 2, 3])
54
  parser.add_argument("--eval_interval", default=10, type=int)
55
  parser.add_argument("--checkpoint_interval", default=20, type=int)
56
  parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
57
  parser.add_argument("--resume", action="store_true")
58
- parser.add_argument("--dataset_path", default="./dataset", type=str)
59
- parser.add_argument("--num_dataset_processes", default=8, type=int)
60
  parser.add_argument("--device", default="cuda", type=str)
61
  parser.add_argument("--seed", default=None, type=int)
62
 
@@ -139,18 +149,22 @@ def main():
139
  torch.manual_seed(args.seed)
140
  random.seed(args.seed)
141
 
142
- training = Openwebtext(
143
  root_path=args.dataset_path,
144
- train=True,
 
145
  tokens_per_sample=args.block_size,
146
  samples_per_epoch=args.samples_per_epoch,
 
147
  num_processes=args.num_dataset_processes,
148
  )
149
- testing = Openwebtext(
150
  root_path=args.dataset_path,
151
- train=False,
 
152
  tokens_per_sample=args.block_size,
153
  samples_per_epoch=args.samples_per_epoch,
 
154
  num_processes=args.num_dataset_processes,
155
  )
156
 
 
20
  from torchmetrics.text import Perplexity
21
 
22
  from model import GPT
23
+ from data import Fineweb
24
 
25
  from tqdm import tqdm
26
 
 
38
  def main():
39
  parser = ArgumentParser(description="Pre-train the GPT.")
40
 
41
+ parser.add_argument(
42
+ "--dataset_subset",
43
+ default="sample-10BT",
44
+ choices=("sample-10BT", "sample-100BT", "sample-350BT", None),
45
+ )
46
+ parser.add_argument(
47
+ "--token_encoding",
48
+ default="r50k_base",
49
+ choices=("r50k_base", "cl100k_base", "o200k_base"),
50
+ )
51
+ parser.add_argument("--dataset_path", default="./dataset", type=str)
52
+ parser.add_argument("--num_dataset_processes", default=8, type=int)
53
  parser.add_argument("--batch_size", default=1, type=int)
54
  parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
55
  parser.add_argument("--samples_per_epoch", default=4096, type=int)
56
  parser.add_argument("--learning_rate", default=1e-2, type=float)
57
  parser.add_argument("--max_gradient_norm", default=1.0, type=float)
58
  parser.add_argument("--dropout", default=0.1, type=float)
59
+ parser.add_argument("--num_epochs", default=2384, type=int)
60
  parser.add_argument("--block_size", default=1024, type=int)
61
  parser.add_argument("--embedding_dimensions", default=1024, type=int)
62
  parser.add_argument("--num_attention_heads", default=16, type=int)
63
  parser.add_argument("--num_hidden_layers", default=32, type=int)
64
  parser.add_argument("--activation_checkpointing", action="store_true")
65
+ parser.add_argument("--ddp_sharding_level", default=2, choices=(0, 2, 3))
66
  parser.add_argument("--eval_interval", default=10, type=int)
67
  parser.add_argument("--checkpoint_interval", default=20, type=int)
68
  parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
69
  parser.add_argument("--resume", action="store_true")
 
 
70
  parser.add_argument("--device", default="cuda", type=str)
71
  parser.add_argument("--seed", default=None, type=int)
72
 
 
149
  torch.manual_seed(args.seed)
150
  random.seed(args.seed)
151
 
152
+ training = Fineweb(
153
  root_path=args.dataset_path,
154
+ subset=args.dataset_subset,
155
+ split="train",
156
  tokens_per_sample=args.block_size,
157
  samples_per_epoch=args.samples_per_epoch,
158
+ token_encoding=args.token_encoding,
159
  num_processes=args.num_dataset_processes,
160
  )
161
+ testing = Fineweb(
162
  root_path=args.dataset_path,
163
+ subset=args.dataset_subset,
164
+ split="test",
165
  tokens_per_sample=args.block_size,
166
  samples_per_epoch=args.samples_per_epoch,
167
+ token_encoding=args.token_encoding,
168
  num_processes=args.num_dataset_processes,
169
  )
170
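The new `--ddp_sharding_level` flag accepts 0, 2, or 3, which presumably select the corresponding FSDP sharding strategies. The mapping below is an illustrative sketch of that interpretation, not logic taken from this diff; the actual resolution happens elsewhere in `pre-train.py`.

```
from torch.distributed.fsdp import ShardingStrategy

# Hypothetical mapping from --ddp_sharding_level to an FSDP ShardingStrategy.
SHARDING_STRATEGIES = {
    0: ShardingStrategy.NO_SHARD,       # plain replication, no sharding
    2: ShardingStrategy.SHARD_GRAD_OP,  # shard gradients and optimizer state
    3: ShardingStrategy.FULL_SHARD,     # also shard the parameters themselves
}

sharding_strategy = SHARDING_STRATEGIES[2]  # the default --ddp_sharding_level
```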