Andrew DalPino committed
Commit 0cc4ecd · 1 Parent(s): 19b8dfb

Add MFU estimation for Ampere GPUs

Files changed (3):
  1. instruction-tune.py +1 -1
  2. model.py +2 -2
  3. model_sizing.ipynb +60 -25
instruction-tune.py CHANGED
@@ -26,7 +26,7 @@ def main():
     parser.add_argument("--base_model_path", default="./out/checkpoint.pt", type=str)
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
-    parser.add_argument("--learning_rate", default=1e-2, type=float)
+    parser.add_argument("--learning_rate", default=5e-4, type=float)
     parser.add_argument("--mask_input", default=True, type=bool)
     parser.add_argument("--rank", default=8, type=int)
     parser.add_argument("--alpha", default=1.0, type=float)
model.py CHANGED
@@ -1,5 +1,5 @@
 from math import sqrt, exp
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import partial
 from typing import Iterator, Self
 
@@ -210,7 +210,7 @@ class GPT(Module):
         if beam_width <= 0:
             raise ValueError(f"Beam width must be greater than 0, {beam_width} given.")
 
-        @dataclass(order=True)
+        @dataclass
         class Candidate:
             log_probability: float
             tokens: Tensor
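Aside: dropping order=True here is more than cosmetic. A dataclass ordering compares fields in declaration order, so two Candidates tied on log_probability would fall through to comparing their tokens Tensors, and coercing a multi-element tensor comparison to a single bool raises a RuntimeError. If the beam search needs ranked candidates, sorting on the scalar field alone is enough; a minimal sketch with hypothetical values, not code from this repo:

from dataclasses import dataclass

import torch
from torch import Tensor

@dataclass
class Candidate:
    log_probability: float
    tokens: Tensor

# Hypothetical beam contents, for illustration only.
candidates = [
    Candidate(-2.3, torch.tensor([1, 7, 4])),
    Candidate(-0.9, torch.tensor([1, 2, 8])),
]

# Rank on the scalar field alone; with order=True, a tie on
# log_probability would compare the tokens tensors and raise.
best_first = sorted(candidates, key=lambda c: c.log_probability, reverse=True)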
model_sizing.ipynb CHANGED
@@ -1,5 +1,12 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Welcome! In this notebook we aim to estimate the compute and memory requirements needed to train a theoretical model architecture using LightGPT. We'll start by defining the parameters of the architecture."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 35,
@@ -17,7 +24,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "First, we'll estimate the total number of parameters in the network."
+    "Next, we'll estimate the total number of trainable parameters in the network."
    ]
   },
   {
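The notebook derives the parameter count analytically from the architecture's hyperparameters. For comparison, once a model is instantiated, PyTorch can produce the exact figure; a minimal sketch, where count_trainable_parameters and the model argument are illustrative names rather than anything from the notebook:

from torch import nn

def count_trainable_parameters(model: nn.Module) -> int:
    # Tally the elements of every tensor the optimizer will update.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)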
  "cell_type": "markdown",
268
  "metadata": {},
269
  "source": [
270
+ "Now, let's estimate the number of FLOPs using the method in the PaLM paper by Chowdhery, et al. Then, we'll compare the PaLM estimation with our own as a sanity check."
271
  ]
272
  },
273
  {
 
295
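For reference, the estimator from the PaLM paper (Chowdhery et al., Appendix B) counts roughly 6 matmul FLOPs per parameter per token for the combined forward and backward passes, plus a parameter-free term for the attention matmuls. A sketch of that formula; the function and variable names are illustrative, not the notebook's identifiers:

def palm_train_flops_per_token(n_params, n_layers, n_heads, head_dim, seq_len):
    # 6 * N: forward (2N) plus backward (4N) matmul FLOPs over the weights.
    # 12 * L * H * Q * T: attention score and attention-weighted-value
    # matmuls across all layers and heads at context length T.
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * seq_len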
  "cell_type": "markdown",
296
  "metadata": {},
297
  "source": [
298
+ "The two estimates seem pretty similar so let's move on to estimating the model FLOPs utilization (MFU) by comparing some observed throughput data for various GPUs to their advertised theoretical maximum throughput."
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 96,
304
+ "metadata": {},
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "RTX A2000 MFU: 17.29%\n",
311
+ "RTX 3090 MFU: 22.99%\n",
312
+ "A100 SXM MFU: 37.16%\n"
313
+ ]
314
+ }
315
+ ],
316
+ "source": [
317
+ "from dataclasses import dataclass\n",
318
+ "\n",
319
+ "@dataclass\n",
320
+ "class Device:\n",
321
+ " name: str\n",
322
+ " advertised_flops: float\n",
323
+ " actual_flops: float\n",
324
+ "\n",
325
+ " @property\n",
326
+ " def mfu(self) -> float:\n",
327
+ " return self.actual_flops / self.advertised_flops\n",
328
+ "\n",
329
+ " @property\n",
330
+ " def percentage_utilization(self) -> float:\n",
331
+ " return self.mfu * 100\n",
332
+ "\n",
333
+ "devices = [\n",
334
+ " Device(\"RTX A2000\", 63.9e12, 3.45 * total_roundtrip_flops),\n",
335
+ " Device(\"RTX 3090\", 285.5e12, 20.5 * total_roundtrip_flops),\n",
336
+ " Device(\"A100 SXM\", 624.0e12, 72.4 * total_roundtrip_flops),\n",
337
+ "]\n",
338
+ "\n",
339
+ "for device in devices:\n",
340
+ " print(f\"{device.name} MFU: {device.percentage_utilization:.2f}%\")\n"
341
  ]
342
  },
343
  {
344
  "cell_type": "markdown",
345
  "metadata": {},
346
  "source": [
347
+ "Now, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation. Note that these results shown here are a best-case scenario and neglect to factor in other overhead."
348
  ]
349
  },
350
  {
351
  "cell_type": "code",
352
+ "execution_count": 97,
353
  "metadata": {},
354
  "outputs": [
355
  {
 
359
  "Total tokens: 8,994,885,755\n",
360
  "Epochs required: 2,145\n",
361
  "\n",
362
+ "RTX A2000: 1187.25 seconds/epoch, 29.48 days required\n",
363
+ "RTX 3090: 199.80 seconds/epoch, 4.96 days required\n",
364
+ "A100 SXM: 56.57 seconds/epoch, 1.40 days required\n"
365
  ]
366
  }
367
  ],
368
  "source": [
 
 
 
 
 
 
369
  "num_training_tokens = 8994885755\n",
370
  "samples_per_epoch = 4096\n",
371
  "\n",
 
374
  "print(f\"Total tokens: {num_training_tokens:,}\")\n",
375
  "print(f\"Epochs required: {num_epochs_required:,}\", end=\"\\n\\n\")\n",
376
  "\n",
377
+ "for device in devices:\n",
378
+ " seconds_per_epoch = samples_per_epoch * total_roundtrip_flops / device.actual_flops\n",
 
 
 
 
 
 
 
 
379
  "\n",
380
  " days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
381
  "\n",
382
+ " print(f\"{device.name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\")"
383
  ]
384
  }
385
  ],
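A quick sanity check on the printed figures: each actual_flops entry is a samples-per-second throughput multiplied by total_roundtrip_flops, so the per-sample FLOPs cancel and seconds/epoch reduces to samples_per_epoch divided by throughput. (The MFU output also pins down total_roundtrip_flops: 0.1729 × 63.9e12 / 3.45 ≈ 3.2e12 FLOPs per sample.) A sketch that reproduces the timings from the throughputs alone, under that reading of the diff:

# Samples/second throughputs taken from the Device entries above.
throughputs = {"RTX A2000": 3.45, "RTX 3090": 20.5, "A100 SXM": 72.4}

samples_per_epoch = 4096
num_epochs_required = 2145

for name, samples_per_second in throughputs.items():
    # The FLOPs terms cancel:
    # (samples/epoch * FLOPs/sample) / (samples/s * FLOPs/sample)
    seconds_per_epoch = samples_per_epoch / samples_per_second
    days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24
    print(f"{name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required")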