Andrew DalPino committed
Commit 0cc4ecd · 1 Parent(s): 19b8dfb

Add MFU estimation for Ampere GPUs

Files changed:
- instruction-tune.py +1 -1
- model.py +2 -2
- model_sizing.ipynb +60 -25
instruction-tune.py CHANGED

@@ -26,7 +26,7 @@ def main():
     parser.add_argument("--base_model_path", default="./out/checkpoint.pt", type=str)
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
-    parser.add_argument("--learning_rate", default=
+    parser.add_argument("--learning_rate", default=5e-4, type=float)
     parser.add_argument("--mask_input", default=True, type=bool)
     parser.add_argument("--rank", default=8, type=int)
     parser.add_argument("--alpha", default=1.0, type=float)
model.py CHANGED

@@ -1,5 +1,5 @@
 from math import sqrt, exp
-from dataclasses import dataclass
+from dataclasses import dataclass
 from functools import partial
 from typing import Iterator, Self
 

@@ -210,7 +210,7 @@ class GPT(Module):
        if beam_width <= 0:
            raise ValueError(f"Beam width must be greater than 0, {beam_width} given.")

-        @dataclass
+        @dataclass
        class Candidate:
            log_probability: float
            tokens: Tensor
model_sizing.ipynb CHANGED

@@ -1,5 +1,12 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Welcome! In this notebook we aim to estimate the compute and memory requirements needed to train a theoretical model architecture using LightGPT. We'll start by first defining the parameters of the architecture."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 35,

@@ -17,7 +24,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "
+    "Next, we'll estimate the total number of trainable parameters in the network."
    ]
   },
   {

@@ -260,7 +267,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now, let's estimate the
+    "Now, let's estimate the number of FLOPs using the method in the PaLM paper by Chowdhery, et al. Then, we'll compare the PaLM estimation with our own as a sanity check."
    ]
   },
   {

@@ -288,19 +295,61 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The estimates seem pretty similar so let's move on
+    "The two estimates seem pretty similar so let's move on to estimating the model FLOPs utilization (MFU) by comparing some observed throughput data for various GPUs to their advertised theoretical maximum throughput."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 96,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "RTX A2000 MFU: 17.29%\n",
+      "RTX 3090 MFU: 22.99%\n",
+      "A100 SXM MFU: 37.16%\n"
+     ]
+    }
+   ],
+   "source": [
+    "from dataclasses import dataclass\n",
+    "\n",
+    "@dataclass\n",
+    "class Device:\n",
+    "    name: str\n",
+    "    advertised_flops: float\n",
+    "    actual_flops: float\n",
+    "\n",
+    "    @property\n",
+    "    def mfu(self) -> float:\n",
+    "        return self.actual_flops / self.advertised_flops\n",
+    "\n",
+    "    @property\n",
+    "    def percentage_utilization(self) -> float:\n",
+    "        return self.mfu * 100\n",
+    "\n",
+    "devices = [\n",
+    "    Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
+    "    Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
+    "    Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
+    "]\n",
+    "\n",
+    "for device in devices:\n",
+    "    print(f\"{device.name} MFU: {device.percentage_utilization:.2f}%\")\n"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation
+    "Now, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation. Note that these results shown here are a best-case scenario and neglect to factor in other overhead."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 97,
    "metadata": {},
    "outputs": [
    {

@@ -310,19 +359,13 @@
      "Total tokens: 8,994,885,755\n",
      "Epochs required: 2,145\n",
      "\n",
-      "RTX A2000:
-      "
-      "
+      "RTX A2000: 1187.25 seconds/epoch, 29.48 days required\n",
+      "RTX 3090: 199.80 seconds/epoch, 4.96 days required\n",
+      "A100 SXM: 56.57 seconds/epoch, 1.40 days required\n"
     ]
    }
   ],
   "source": [
-    "RTX_A2000_BF16_FLOPS_PER_SECOND = 63.9e12\n",
-    "A100_SXM_BF16_FLOPS_PER_SECOND = 624.0e12\n",
-    "HGX_B100_BF16_FLOPS_PER_SECOND = 28000e12\n",
-    "\n",
-    "MODEL_FLOPS_UTILIZATION = 0.3\n",
-    "\n",
     "num_training_tokens = 8994885755\n",
     "samples_per_epoch = 4096\n",
     "\n",

@@ -331,20 +374,12 @@
     "print(f\"Total tokens: {num_training_tokens:,}\")\n",
     "print(f\"Epochs required: {num_epochs_required:,}\", end=\"\\n\\n\")\n",
     "\n",
-    "gpus = {\n",
-    "    \"RTX A2000\": RTX_A2000_BF16_FLOPS_PER_SECOND,\n",
-    "    \"A100 SXM\": A100_SXM_BF16_FLOPS_PER_SECOND,\n",
-    "    \"HGX B100\": HGX_B100_BF16_FLOPS_PER_SECOND,\n",
-    "}\n",
-    "\n",
-    "for name, flops_per_second in gpus.items():\n",
-    "    flops_per_second *= MODEL_FLOPS_UTILIZATION\n",
-    "\n",
-    "    seconds_per_epoch = samples_per_epoch * total_roundtrip_flops / flops_per_second\n",
+    "for device in devices:\n",
+    "    seconds_per_epoch = samples_per_epoch * total_roundtrip_flops / device.actual_flops\n",
     "\n",
     "    days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
     "\n",
-    "    print(f\"{name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\")"
+    "    print(f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\")"
    ]
   }
  ],
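
The notebook hunks above lean on two quantities defined in cells this diff doesn't touch: total_roundtrip_flops (the FLOPs per training sample) and the PaLM estimate it is sanity-checked against. Below is a minimal sketch of that PaLM-style estimate; only the 6N + 12·L·H·Q·T formula comes from the PaLM paper (Chowdhery et al., Appendix B), while every architecture number in the snippet is a hypothetical placeholder rather than LightGPT's actual configuration.

# Minimal sketch of the PaLM-style FLOPs estimate referenced in the notebook.
# All architecture numbers below are hypothetical placeholders.

num_layers = 12      # L: transformer blocks (assumed)
num_heads = 12       # H: attention heads (assumed)
head_dim = 64        # Q: dimension per attention head (assumed)
seq_len = 1024       # T: tokens per training sample (assumed)
num_params = 124e6   # N: non-embedding parameter count (assumed)

# PaLM, Appendix B: forward + backward FLOPs per token ~ 6N + 12 * L * H * Q * T.
flops_per_token = 6 * num_params + 12 * num_layers * num_heads * head_dim * seq_len

# Per-sample cost, playing the role of the notebook's total_roundtrip_flops.
flops_per_sample = seq_len * flops_per_token

# MFU compares achieved FLOP/s against the advertised peak, e.g. an observed
# 3.45 samples/sec on a device advertised at 63.9e12 BF16 FLOP/s:
mfu = 3.45 * flops_per_sample / 63.9e12

print(f"FLOPs per token:  {flops_per_token:.3e}")
print(f"FLOPs per sample: {flops_per_sample:.3e}")
print(f"MFU: {mfu * 100:.2f}%")

With the real dimensions from the notebook's first cell substituted in, this estimate should land close to total_roundtrip_flops, which is what the "two estimates seem pretty similar" cell asserts.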