Andrew DalPino committed
Commit fc4824e · 1 parent: 624da87

A little nicer

Browse files:
- README.md +2 -2
- beam_search.py +1 -1
- data.py +1 -1
- export_model.ipynb +25 -15
- instruction-tune.py +4 -1
- model.py +24 -21
- pre-train.py +4 -1
README.md CHANGED

@@ -17,7 +17,7 @@ LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the

 ## Features

-- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. Despite having no positional embeddings (NoPE), LightGPT performs better
+- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. Despite having no positional embeddings (NoPE), LightGPT performs better at context-length generalization than relative embeddings, offering good performance even at 2X the trained context window.

 - **Low Memory Utilization**: LightGPT lets you progressively employ training-time memory optimizations such as fully-sharded data-parallel (FSDP), activation checkpointing, mixed precision, and low-memory optimizer updates that allow you to train larger models on smaller hardware.

@@ -83,7 +83,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16

 | Argument | Default | Type | Description |
 |---|---|---|---|
 | --dataset_subset | "sample-10BT" | str | The subset of the Fineweb dataset to train on. Options are `sample-10BT`, `sample-100BT`, and `sample-350BT`. Set to `None` to train on the full 15T token dataset. |
-| --token_encoding | "r50k_base" | str | The encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `cl100k_base`, and `o200k_base`. |
+| --token_encoding | "r50k_base" | str | The Tiktoken encoding scheme to use when tokenizing the dataset. Options include `r50k_base`, `p50k_base`, `cl100k_base`, and `o200k_base`. |
 | --dataset_path | "./datasets" | str | The path to the preprocessed dataset files on disk. |
 | --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
 | --batch_size | 1 | int | The number of samples to pass through the network at a time. |
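As a quick illustration of what `--token_encoding` selects, here is a minimal Tiktoken round trip. The sample text and the choice of `r50k_base` are placeholders, not from the repo; only `tokenizer.eot_token` being used as the model's `eos_index` comes from the diff (see pre-train.py below).

```python
import tiktoken

# Any of the supported encodings: r50k_base, p50k_base, cl100k_base, o200k_base.
tokenizer = tiktoken.get_encoding("r50k_base")

text = "LightGPT is a lightweight GPT model."  # hypothetical sample text

# encode_ordinary ignores special tokens, which suits raw corpus text.
tokens = tokenizer.encode_ordinary(text)

# The end-of-text id is what pre-train.py passes to the model as eos_index.
tokens.append(tokenizer.eot_token)

print(len(tokens), tokens)
```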
beam_search.py CHANGED

@@ -15,7 +15,7 @@ import tiktoken

 def main():
     parser = ArgumentParser(
-        description="
+        description="Use a greedy search strategy to generate candidate sequences.",
     )

     parser.add_argument(
data.py CHANGED

@@ -92,7 +92,7 @@ class Fineweb(IterableDataset):

         index = 0

-        for i in tqdm(range(self.NUM_SHARDS), desc="
+        for i in tqdm(range(self.NUM_SHARDS), desc="Saving"):
             batch = dataset.shard(
                 num_shards=self.NUM_SHARDS, index=i, contiguous=True
             ).with_format("numpy")
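The loop above walks contiguous shards of a Hugging Face dataset. A minimal runnable sketch of the same pattern on a toy dataset (the `Dataset.from_dict` input and shard count are invented for illustration):

```python
from datasets import Dataset
from tqdm import tqdm

NUM_SHARDS = 4

# Toy stand-in for the tokenized Fineweb dataset.
dataset = Dataset.from_dict({"tokens": [[1, 2, 3]] * 100})

for i in tqdm(range(NUM_SHARDS), desc="Saving"):
    # contiguous=True slices the dataset into consecutive chunks instead of
    # striding, so each shard can be written out as one sequential block.
    batch = dataset.shard(
        num_shards=NUM_SHARDS, index=i, contiguous=True
    ).with_format("numpy")

    print(i, len(batch))
```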
export_model.ipynb CHANGED

@@ -9,11 +9,11 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
-   "model_name = \"lightgpt-small
+   "model_name = \"lightgpt-small\"\n",
    "checkpoint_path = \"./checkpoints/checkpoint.pt\"\n",
    "lora_path = None # \"./checkpoints/lora_instruction.pt\"\n",
    "exports_path = \"./exports\""

@@ -28,14 +28,18 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
-    "
-    "
-    "
+    "ename": "TypeError",
+    "evalue": "GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'",
+    "output_type": "error",
+    "traceback": [
+     "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+     "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+     "Cell \u001b[0;32mIn[3], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GPT, GPTWithLoRA\n\u001b[1;32m 5\u001b[0m checkpoint \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(checkpoint_path, map_location\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, weights_only\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 7\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPT\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_args\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m model \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(model)\n\u001b[1;32m 11\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
+     "\u001b[0;31mTypeError\u001b[0m: GPT.__init__() missing 1 required positional argument: 'feed_forward_ratio'"
    ]
   }
  ],

@@ -179,8 +183,8 @@
    "    model,\n",
    "    example_input,\n",
    "    onnx_path,\n",
-   "    input_names=[\"input_tokens\"],\n",
-   "    output_names=[\"
+   "    input_names=[\"input_tokens\", \"labels\"],\n",
+   "    output_names=[\"logits\"],\n",
    "    dynamo=True,\n",
    ")\n",
    "\n",

@@ -226,14 +230,18 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": null,
   "metadata": {},
   "outputs": [
    {
-    "
-    "
-    "
+    "ename": "NameError",
+    "evalue": "name 'onnx_path' is not defined",
+    "output_type": "error",
+    "traceback": [
+     "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+     "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+     "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtesting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m assert_allclose\n\u001b[0;32m----> 7\u001b[0m session \u001b[38;5;241m=\u001b[39m onnxruntime\u001b[38;5;241m.\u001b[39mInferenceSession(\u001b[43monnx_path\u001b[49m, providers\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCPUExecutionProvider\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 9\u001b[0m onnx_input \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_tokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: example_input\u001b[38;5;241m.\u001b[39mnumpy()}\n\u001b[1;32m 11\u001b[0m output \u001b[38;5;241m=\u001b[39m session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, onnx_input)\n",
+     "\u001b[0;31mNameError\u001b[0m: name 'onnx_path' is not defined"
    ]
   }
  ],

@@ -242,6 +250,8 @@
   "\n",
   "import numpy as np\n",
   "\n",
+  "from numpy.testing import assert_allclose\n",
+  "\n",
   "session = onnxruntime.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
   "\n",
   "onnx_input = {\"input_tokens\": example_input.numpy()}\n",

@@ -251,7 +261,7 @@
   "onnx_output = output[0]\n",
   "pytorch_output = np.array(example_output.detach())\n",
   "\n",
-  "
+  "assert_allclose(pytorch_output, onnx_output, rtol=1e-2, atol=1e-03)\n",
   "\n",
   "print(\"Looking good\")"
  ]
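Taken together, the notebook changes export the model through the Dynamo-based ONNX exporter and then verify the ONNX Runtime output against PyTorch. A condensed standalone sketch of that round trip follows; the tiny `Linear` stand-in model, shapes, and file path are assumptions for illustration (LightGPT exports its GPT module with token inputs), and `dynamo=True` assumes a recent PyTorch with the torch.export-based exporter:

```python
import numpy as np
import onnxruntime
import torch

from numpy.testing import assert_allclose

# Hypothetical stand-in for the GPT checkpoint being exported.
model = torch.nn.Linear(16, 4).eval()

example_input = torch.randn(1, 16)
onnx_path = "model.onnx"

with torch.no_grad():
    example_output = model(example_input)

# Export via the torch.export-based (Dynamo) ONNX exporter.
torch.onnx.export(
    model,
    (example_input,),
    onnx_path,
    input_names=["input_tokens"],
    output_names=["logits"],
    dynamo=True,
)

# Run the exported graph on CPU and compare against the PyTorch output.
session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

output = session.run(None, {"input_tokens": example_input.numpy()})

# Loose tolerances absorb small numeric differences between runtimes.
assert_allclose(np.array(example_output), output[0], rtol=1e-2, atol=1e-3)

print("Looking good")
```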
instruction-tune.py CHANGED

@@ -96,7 +96,10 @@ def main():
         shuffle=False,
     )

-    model = GPT(**model_args
+    model = GPT(**model_args)
+
+    if args.activation_checkpointing:
+        model.enable_activation_checkpointing()

     model = torch.compile(model)
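Both training scripts now gate checkpointing behind `args.activation_checkpointing` rather than a constructor argument. A plausible argparse sketch of that flag (the wiring is an assumption; only the attribute name comes from the diff):

```python
from argparse import ArgumentParser

parser = ArgumentParser(description="Instruction-tune the model.")

# Hypothetical flag definition; the diff only shows args.activation_checkpointing
# being read after parsing.
parser.add_argument(
    "--activation_checkpointing",
    action="store_true",
    help="Recompute activations during the backward pass to save memory.",
)

args = parser.parse_args()

if args.activation_checkpointing:
    print("Activation checkpointing will be enabled after the model is constructed.")
```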
model.py CHANGED

@@ -23,7 +23,7 @@ from torch.nn import (

 from torch.nn.functional import softmax, log_softmax
 from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
-from torch.utils.checkpoint import checkpoint
+from torch.utils.checkpoint import checkpoint as torch_checkpoint


 class GPT(Module):

@@ -31,27 +31,32 @@ class GPT(Module):

     def __init__(
         self,
-        block_size: int
-        embedding_dimensions: int
-        num_heads: int
-        num_layers: int
-        feed_forward_ratio: int
-        dropout: float
-        eos_index: int = 50256,
+        block_size: int,
+        embedding_dimensions: int,
+        num_heads: int,
+        num_layers: int,
+        feed_forward_ratio: int,
+        dropout: float,
+        vocabulary_size: int,
+        padding_index: int,
+        eos_index: int,
     ):
         super().__init__()

+        if block_size < 1:
+            raise ValueError(f"Block size must be greater than 0, {block_size} given.")
+
+        if num_layers <= 0:
+            raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
+
+        if feed_forward_ratio not in (1, 2, 4):
+            raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")
+
         if vocabulary_size <= 0:
             raise ValueError(
                 f"Vocabulary size must be greater than 0, {vocabulary_size} given."
             )

-        if num_layers <= 0:
-            raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
-
         token_embeddings = Embedding(
             vocabulary_size, embedding_dimensions, padding_idx=padding_index
         )

@@ -80,10 +85,7 @@ class GPT(Module):
             ]
         )

-        if activation_checkpointing:
-            self.checkpoint = partial(checkpoint, use_reentrant=False)
-        else:
-            self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)
+        self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)

         self.output_norm = RMSNorm(embedding_dimensions)
         self.output_layer = output_layer

@@ -98,6 +100,9 @@ class GPT(Module):
     def num_trainable_params(self) -> int:
         return sum(param.numel() for param in self.parameters() if param.requires_grad)

+    def enable_activation_checkpointing(self) -> None:
+        self.checkpoint = partial(torch_checkpoint, use_reentrant=False)
+
     def forward(
         self, x: Tensor, y: Tensor | None = None
     ) -> tuple[Tensor, Tensor | None]:

@@ -292,9 +297,7 @@ class GPTWithLoRA(Module):
     to the intermediate layers of the network.
     """

-    def __init__(
-        self, model: GPT, rank: int = 8, alpha: float = 1.0, dropout: float = 0.05
-    ):
+    def __init__(self, model: GPT, rank: int, alpha: float, dropout: float):
         super().__init__()

         if rank <= 0:
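The pattern introduced here defaults `self.checkpoint` to a plain pass-through and lets `enable_activation_checkpointing()` swap in `torch.utils.checkpoint` after construction. A minimal runnable sketch of the same idea on a toy stack of layers (layer shapes are invented; LightGPT's layers also take an attention mask):

```python
from functools import partial

import torch

from torch.nn import Linear, Module, ModuleList
from torch.utils.checkpoint import checkpoint as torch_checkpoint


class TinyStack(Module):
    def __init__(self, num_layers: int):
        super().__init__()

        if num_layers <= 0:
            raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")

        self.layers = ModuleList([Linear(8, 8) for _ in range(num_layers)])

        # Default: call each layer directly, no recomputation.
        self.checkpoint = lambda layer, x: layer(x)

    def enable_activation_checkpointing(self) -> None:
        # Swap in torch.utils.checkpoint so activations are recomputed during
        # the backward pass, trading compute for memory.
        self.checkpoint = partial(torch_checkpoint, use_reentrant=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.checkpoint(layer, x)

        return x


model = TinyStack(num_layers=2)
model.enable_activation_checkpointing()

out = model(torch.randn(4, 8, requires_grad=True))
out.sum().backward()
```

Deferring the swap to a method keeps checkpointing out of the constructor signature, which is why `model_args` in the training scripts no longer carries an `activation_checkpointing` entry.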
pre-train.py CHANGED

@@ -196,7 +196,10 @@ def main():
         "eos_index": tokenizer.eot_token,
     }

-    model = GPT(**model_args
+    model = GPT(**model_args)
+
+    if args.activation_checkpointing:
+        model.enable_activation_checkpointing()

     print("Compiling model")
     model = torch.compile(model)