Andrew DalPino committed on
Commit fc4824e · 1 Parent(s): 624da87

A little nicer
Files changed (7):
  1. README.md +2 -2
  2. beam_search.py +1 -1
  3. data.py +1 -1
  4. export_model.ipynb +25 -15
  5. instruction-tune.py +4 -1
  6. model.py +24 -21
  7. pre-train.py +4 -1
README.md CHANGED
@@ -17,7 +17,7 @@ LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the

  ## Features

- - **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. Despite having no positional embeddings (NoPE), LightGPT performs better in terms of context-length generalization than relative embeddings offering good performance up to 2X the trained block size.
+ - **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by training only the parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. Despite having no positional embeddings (NoPE), LightGPT performs better at context-length generalization than relative embeddings, offering good performance even at 2X the trained context window.

  - **Low Memory Utilization**: LightGPT lets you progressively employ training-time memory optimizations such as fully-sharded data-parallel (FSDP), activation checkpointing, mixed precision, and low-memory optimizer updates that allow you to train larger models on smaller hardware.

@@ -83,7 +83,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16
  | Argument | Default | Type | Description |
  |---|---|---|---|
  | --dataset_subset | "sample-10BT" | str | The subset of the Fineweb dataset to train on. Options are `sample-10BT`, `sample-100BT`, and `sample-350BT`. Set to `None` to train on the full 15T token dataset. |
- | --token_encoding | "r50k_base" | str | The encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `cl100k_base`, and `o200k_base`. |
+ | --token_encoding | "r50k_base" | str | The Tiktoken encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `p50k_base`, `cl100k_base`, and `o200k_base`. |
  | --dataset_path | "./datasets" | str | The path to the preprocessed dataset files on disk. |
  | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
  | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
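The `--token_encoding` values map one-to-one onto tiktoken encodings. A minimal sketch (assuming the `tiktoken` package is installed) of how the choice affects vocabulary size and sequence length:

```python
import tiktoken

# Each supported --token_encoding value names a tiktoken encoding.
# Larger vocabularies trade a bigger embedding table for shorter sequences.
for name in ("r50k_base", "p50k_base", "cl100k_base", "o200k_base"):
    encoding = tiktoken.get_encoding(name)

    print(name, encoding.n_vocab, len(encoding.encode("Hello, world!")))
```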
beam_search.py CHANGED
@@ -15,7 +15,7 @@ import tiktoken

  def main():
      parser = ArgumentParser(
-         description="Generate text from the model given a prompt.",
+         description="Use a greedy search strategy to generate candidate sequences.",
      )

      parser.add_argument(
data.py CHANGED
@@ -92,7 +92,7 @@ class Fineweb(IterableDataset):

          index = 0

-         for i in tqdm(range(self.NUM_SHARDS), desc="Writing"):
+         for i in tqdm(range(self.NUM_SHARDS), desc="Saving"):
              batch = dataset.shard(
                  num_shards=self.NUM_SHARDS, index=i, contiguous=True
              ).with_format("numpy")
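For reference, the loop above leans on Hugging Face `datasets` sharding. A toy sketch of the same pattern with hypothetical data (not the actual Fineweb pipeline):

```python
from datasets import Dataset

NUM_SHARDS = 4

# Contiguous sharding slices the dataset into NUM_SHARDS consecutive
# chunks so each one can be converted to NumPy and saved independently.
dataset = Dataset.from_dict({"tokens": list(range(100))})

for i in range(NUM_SHARDS):
    shard = dataset.shard(num_shards=NUM_SHARDS, index=i, contiguous=True)
    batch = shard.with_format("numpy")

    print(i, batch["tokens"].shape)
```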
export_model.ipynb CHANGED
@@ -9,11 +9,11 @@
   },
   {
    "cell_type": "code",
-    "execution_count": 56,
+    "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
-     "model_name = \"lightgpt-small-turbo\"\n",
+     "model_name = \"lightgpt-small\"\n",
     "checkpoint_path = \"./checkpoints/checkpoint.pt\"\n",
     "lora_path = None # \"./checkpoints/lora_instruction.pt\"\n",
     "exports_path = \"./exports\""
@@ -28,14 +28,18 @@
   },
   {
    "cell_type": "code",
-    "execution_count": 57,
+    "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Base checkpoint loaded successfully\n"
+     "ename": "TypeError",
+     "evalue": "GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[3], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GPT, GPTWithLoRA\n\u001b[1;32m 5\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(checkpoint_path, map_location\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, weights_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 7\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_args\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m model \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(model)\n\u001b[1;32m 11\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
+      "\u001b[0;31mTypeError\u001b[0m: GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'"
     ]
    }
   ],
@@ -179,8 +183,8 @@
     " model,\n",
     " example_input,\n",
     " onnx_path,\n",
-     " input_names=[\"input_tokens\"],\n",
-     " output_names=[\"output\"],\n",
+     " input_names=[\"input_tokens\", \"labels\"],\n",
+     " output_names=[\"logits\"],\n",
     " dynamo=True,\n",
     ")\n",
     "\n",
@@ -226,14 +230,18 @@
   },
   {
    "cell_type": "code",
-    "execution_count": 88,
+    "execution_count": null,
    "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Looking good\n"
+     "ename": "NameError",
+     "evalue": "name 'onnx_path' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtesting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m assert_allclose\n\u001b[0;32m----> 7\u001b[0m session \u001b[38;5;241m=\u001b[39m onnxruntime\u001b[38;5;241m.\u001b[39mInferenceSession(\u001b[43monnx_path\u001b[49m, providers\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCPUExecutionProvider\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m onnx_input \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_tokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: example_input\u001b[38;5;241m.\u001b[39mnumpy()}\n\u001b[1;32m 11\u001b[0m output \u001b[38;5;241m=\u001b[39m session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, onnx_input)\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'onnx_path' is not defined"
     ]
    }
   ],
@@ -242,6 +250,8 @@
     "\n",
     "import numpy as np\n",
     "\n",
+     "from numpy.testing import assert_allclose\n",
+     "\n",
     "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
     "\n",
     "onnx_input = {\"input_tokens\": example_input.numpy()}\n",
@@ -251,7 +261,7 @@
     "onnx_output = output[0]\n",
     "pytorch_output = np.array(example_output.detach())\n",
     "\n",
-     "np.testing.assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-03)\n",
+     "assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-03)\n",
     "\n",
     "print(\"Looking good\")"
   ]
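The changed cells follow a standard export-then-verify round trip: export with named inputs and outputs, run the graph under onnxruntime, and compare against the PyTorch output. A self-contained sketch with a tiny stand-in module (the module and path here are hypothetical, and `dynamo=True` requires a recent PyTorch; exact input naming can vary by exporter version):

```python
import onnxruntime
import torch

from numpy.testing import assert_allclose

# Tiny stand-in for the GPT checkpoint.
model = torch.nn.Linear(8, 4)
example_input = torch.randn(1, 8)

onnx_path = "model.onnx"  # hypothetical export location

torch.onnx.export(
    model,
    (example_input,),
    onnx_path,
    input_names=["input_tokens"],
    output_names=["logits"],
    dynamo=True,
)

# Verify that the ONNX graph reproduces the PyTorch output.
session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

input_name = session.get_inputs()[0].name  # robust to exporter naming differences

onnx_output = session.run(None, {input_name: example_input.numpy()})[0]
pytorch_output = model(example_input).detach().numpy()

# The tolerances mirror the notebook's check.
assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-3)
```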
instruction-tune.py CHANGED
@@ -96,7 +96,10 @@ def main():
          shuffle=False,
      )

-     model = GPT(**model_args, activation_checkpointing=args.activation_checkpointing)
+     model = GPT(**model_args)
+
+     if args.activation_checkpointing:
+         model.enable_activation_checkpointing()

      model = torch.compile(model)
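Moving the flag out of the constructor keeps `model_args` purely architectural, so a checkpoint's saved arguments can be re-passed verbatim at load time (the notebook's `GPT(**checkpoint["model_args"])` call). A sketch of that round trip, assuming the surrounding script's `GPT`, `model`, `model_args`, and `args` are in scope:

```python
import torch

# Save: model_args now holds only architecture hyperparameters.
checkpoint = {"model_args": model_args, "model": model.state_dict()}
torch.save(checkpoint, "./checkpoints/checkpoint.pt")

# Load: rebuild the model from the saved args, then opt in to
# training-time behavior such as activation checkpointing per run.
checkpoint = torch.load("./checkpoints/checkpoint.pt", map_location="cpu", weights_only=True)

model = GPT(**checkpoint["model_args"])
model.load_state_dict(checkpoint["model"])

if args.activation_checkpointing:
    model.enable_activation_checkpointing()
```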
 
model.py CHANGED
@@ -23,7 +23,7 @@ from torch.nn import (

  from torch.nn.functional import softmax, log_softmax
  from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
- from torch.utils.checkpoint import checkpoint
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint


  class GPT(Module):
@@ -31,27 +31,32 @@ class GPT(Module):

      def __init__(
          self,
-         block_size: int = 1024,
-         embedding_dimensions: int = 1024,
-         num_heads: int = 16,
-         num_layers: int = 24,
-         feed_forward_ratio: int = 4,
-         dropout: float = 0.1,
-         activation_checkpointing: bool = False,
-         vocabulary_size: int = 50257,
-         padding_index: int = -100,
-         eos_index: int = 50256,
+         block_size: int,
+         embedding_dimensions: int,
+         num_heads: int,
+         num_layers: int,
+         feed_forward_ratio: int,
+         dropout: float,
+         vocabulary_size: int,
+         padding_index: int,
+         eos_index: int,
      ):
          super().__init__()

+         if block_size < 1:
+             raise ValueError(f"Block size must be greater than 0, {block_size} given.")
+
+         if num_layers <= 0:
+             raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
+
+         if feed_forward_ratio not in (1, 2, 4):
+             raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")
+
          if vocabulary_size <= 0:
              raise ValueError(
                  f"Vocabulary size must be greater than 0, {vocabulary_size} given."
              )

-         if num_layers <= 0:
-             raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
-
          token_embeddings = Embedding(
              vocabulary_size, embedding_dimensions, padding_idx=padding_index
          )
@@ -80,10 +85,7 @@
              ]
          )

-         if activation_checkpointing:
-             self.checkpoint = partial(checkpoint, use_reentrant=False)
-         else:
-             self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)
+         self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)

          self.output_norm = RMSNorm(embedding_dimensions)
          self.output_layer = output_layer
@@ -98,6 +100,9 @@
      def num_trainable_params(self) -> int:
          return sum(param.numel() for param in self.parameters() if param.requires_grad)

+     def enable_activation_checkpointing(self) -> None:
+         self.checkpoint = partial(torch_checkpoint, use_reentrant=False)
+
      def forward(
          self, x: Tensor, y: Tensor | None = None
      ) -> tuple[Tensor, Tensor | None]:
@@ -292,9 +297,7 @@ class GPTWithLoRA(Module):
      to the intermediate layers of the network.
      """

-     def __init__(
-         self, model: GPT, rank: int = 8, alpha: float = 1.0, dropout: float = 0.05
-     ):
+     def __init__(self, model: GPT, rank: int, alpha: float, dropout: float):
          super().__init__()

          if rank <= 0:
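The refactor makes checkpointing a post-construction toggle: `forward` always routes layer calls through `self.checkpoint`, and enabling checkpointing merely swaps that callable. A minimal self-contained sketch of the pattern (toy modules, not the repo's GPT):

```python
from functools import partial

import torch

from torch.utils.checkpoint import checkpoint as torch_checkpoint


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, attention_mask):
        return self.linear(x)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.layer = TinyBlock()

        # Default: call the layer directly, keeping activations in memory.
        self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)

    def enable_activation_checkpointing(self) -> None:
        # Swap in torch's checkpoint, which discards intermediate
        # activations and recomputes them during the backward pass.
        self.checkpoint = partial(torch_checkpoint, use_reentrant=False)

    def forward(self, x):
        return self.checkpoint(self.layer, x, None)


model = TinyModel()
model.enable_activation_checkpointing()

out = model(torch.randn(2, 4, requires_grad=True))
out.sum().backward()
```

Keeping the default path a plain passthrough means the model pays no overhead unless checkpointing is explicitly enabled.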
pre-train.py CHANGED
@@ -196,7 +196,10 @@ def main():
          "eos_index": tokenizer.eot_token,
      }

-     model = GPT(**model_args, activation_checkpointing=args.activation_checkpointing)
+     model = GPT(**model_args)
+
+     if args.activation_checkpointing:
+         model.enable_activation_checkpointing()

      print("Compiling model")
      model = torch.compile(model)