dq158 commited on
Commit
84d6b29
1 Parent(s): 88bccbd

Training in progress, epoch 1, checkpoint

Browse files
last-checkpoint/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ ---
4
+ ## Training procedure
5
+
6
+ ### Framework versions
7
+
8
+
9
+ - PEFT 0.5.0
last-checkpoint/adapter_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_mapping": null,
3
+ "base_model_name_or_path": "google/flan-t5-xxl",
4
+ "bias": "none",
5
+ "fan_in_fan_out": false,
6
+ "inference_mode": true,
7
+ "init_lora_weights": true,
8
+ "layers_pattern": null,
9
+ "layers_to_transform": null,
10
+ "lora_alpha": 32,
11
+ "lora_dropout": 0.1,
12
+ "modules_to_save": null,
13
+ "peft_type": "LORA",
14
+ "r": 8,
15
+ "revision": null,
16
+ "target_modules": [
17
+ "q",
18
+ "v"
19
+ ],
20
+ "task_type": "SEQ_2_SEQ_LM"
21
+ }
last-checkpoint/adapter_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7eaa51169fb5f5f33b328f26090dcd19a47bde0a9efe64e78856dbbe04a07e7a
3
+ size 888
last-checkpoint/global_step1581/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccc367dbb9740751e85fd2958b4867e548bfad92bc52db0c411c1d02943344cf
3
+ size 56626640
last-checkpoint/global_step1581/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9f583059aebf5c348baa213b71780e3beb0ed5bea73b3eaaae551a3ee1fa3b6
3
+ size 56626640
last-checkpoint/global_step1581/zero_pp_rank_0_mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5d9a7306e11004f5b07be5adcf241a253323781fe1b13c03e09f6e9c51b1a4f
3
+ size 11136132566
last-checkpoint/global_step1581/zero_pp_rank_1_mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a71dd45f808f455f2743152cc5c2c5389f3f01dc02239d02a8986dbc29229ccf
3
+ size 11136132374
last-checkpoint/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step1581
last-checkpoint/rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f232d8d672d6ff9e3cc5ac4ba64234b9f002d5e144a43ee183dc01871e2f10e5
3
+ size 14512
last-checkpoint/rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e25f573f6bcb03f0bbfe93da192a1ab49ba692dcab33095de44fc91e5e4216a5
3
+ size 14512
last-checkpoint/trainer_state.json CHANGED
@@ -1,1250 +1,56 @@
1
  {
2
- "best_metric": 1.5654487609863281,
3
- "best_model_checkpoint": "dq158/pingusPongus/checkpoint-6323",
4
- "epoch": 13.0,
5
  "eval_steps": 500,
6
- "global_step": 82199,
7
  "is_hyper_param_search": false,
8
  "is_local_process_zero": true,
9
  "is_world_process_zero": true,
10
  "log_history": [
11
- {
12
- "epoch": 0.08,
13
- "learning_rate": 5e-06,
14
- "loss": 1.8585,
15
- "step": 500
16
- },
17
- {
18
- "epoch": 0.16,
19
- "learning_rate": 4.999805607800008e-06,
20
- "loss": 1.823,
21
- "step": 1000
22
- },
23
- {
24
- "epoch": 0.24,
25
- "learning_rate": 4.999222461430692e-06,
26
- "loss": 1.8388,
27
- "step": 1500
28
- },
29
  {
30
  "epoch": 0.32,
31
- "learning_rate": 4.998250651579336e-06,
32
- "loss": 1.8372,
33
- "step": 2000
34
- },
35
- {
36
- "epoch": 0.4,
37
- "learning_rate": 4.996890329375747e-06,
38
- "loss": 1.8066,
39
- "step": 2500
40
- },
41
- {
42
- "epoch": 0.47,
43
- "learning_rate": 4.995141706368742e-06,
44
- "loss": 1.8485,
45
- "step": 3000
46
- },
47
- {
48
- "epoch": 0.55,
49
- "learning_rate": 4.993005054493262e-06,
50
- "loss": 1.8243,
51
- "step": 3500
52
  },
53
  {
54
  "epoch": 0.63,
55
- "learning_rate": 4.990480706028073e-06,
56
- "loss": 1.8278,
57
- "step": 4000
58
- },
59
- {
60
- "epoch": 0.71,
61
- "learning_rate": 4.987569053544098e-06,
62
- "loss": 1.8126,
63
- "step": 4500
64
- },
65
- {
66
- "epoch": 0.79,
67
- "learning_rate": 4.98427054984336e-06,
68
- "loss": 1.8277,
69
- "step": 5000
70
- },
71
- {
72
- "epoch": 0.87,
73
- "learning_rate": 4.980585707888573e-06,
74
- "loss": 1.8475,
75
- "step": 5500
76
  },
77
  {
78
  "epoch": 0.95,
79
- "learning_rate": 4.976515100723365e-06,
80
- "loss": 1.8441,
81
- "step": 6000
82
  },
83
  {
84
  "epoch": 1.0,
85
  "eval_bleu": 1.0,
86
  "eval_brevity_penalty": 1.0,
87
  "eval_length_ratio": 1.0,
88
- "eval_loss": 1.5654487609863281,
89
- "eval_precisions": [
90
- 1.0,
91
- 1.0,
92
- 1.0,
93
- 1.0
94
- ],
95
- "eval_reference_length": 52412,
96
- "eval_runtime": 683.1649,
97
- "eval_samples_per_second": 4.115,
98
- "eval_steps_per_second": 1.029,
99
- "eval_translation_length": 52412,
100
- "step": 6323
101
- },
102
- {
103
- "epoch": 1.03,
104
- "learning_rate": 4.972059361383162e-06,
105
- "loss": 1.8281,
106
- "step": 6500
107
- },
108
- {
109
- "epoch": 1.11,
110
- "learning_rate": 4.9672191827967395e-06,
111
- "loss": 1.8431,
112
- "step": 7000
113
- },
114
- {
115
- "epoch": 1.19,
116
- "learning_rate": 4.961995317678472e-06,
117
- "loss": 1.8261,
118
- "step": 7500
119
- },
120
- {
121
- "epoch": 1.27,
122
- "learning_rate": 4.9563885784112645e-06,
123
- "loss": 1.8253,
124
- "step": 8000
125
- },
126
- {
127
- "epoch": 1.34,
128
- "learning_rate": 4.950399836920221e-06,
129
- "loss": 1.847,
130
- "step": 8500
131
- },
132
- {
133
- "epoch": 1.42,
134
- "learning_rate": 4.944030024537049e-06,
135
- "loss": 1.8209,
136
- "step": 9000
137
- },
138
- {
139
- "epoch": 1.5,
140
- "learning_rate": 4.937280131855223e-06,
141
- "loss": 1.8153,
142
- "step": 9500
143
- },
144
- {
145
- "epoch": 1.58,
146
- "learning_rate": 4.930151208575933e-06,
147
- "loss": 1.8591,
148
- "step": 10000
149
- },
150
- {
151
- "epoch": 1.66,
152
- "learning_rate": 4.9226443633448426e-06,
153
- "loss": 1.7892,
154
- "step": 10500
155
- },
156
- {
157
- "epoch": 1.74,
158
- "learning_rate": 4.91476076357968e-06,
159
- "loss": 1.8601,
160
- "step": 11000
161
- },
162
- {
163
- "epoch": 1.82,
164
- "learning_rate": 4.906501635288687e-06,
165
- "loss": 1.8231,
166
- "step": 11500
167
- },
168
- {
169
- "epoch": 1.9,
170
- "learning_rate": 4.8978682628799575e-06,
171
- "loss": 1.7805,
172
- "step": 12000
173
- },
174
- {
175
- "epoch": 1.98,
176
- "learning_rate": 4.888861988961698e-06,
177
- "loss": 1.8429,
178
- "step": 12500
179
- },
180
- {
181
- "epoch": 2.0,
182
- "eval_bleu": 1.0,
183
- "eval_brevity_penalty": 1.0,
184
- "eval_length_ratio": 1.0,
185
- "eval_loss": 1.5659914016723633,
186
- "eval_precisions": [
187
- 1.0,
188
- 1.0,
189
- 1.0,
190
- 1.0
191
- ],
192
- "eval_reference_length": 52542,
193
- "eval_runtime": 676.0746,
194
- "eval_samples_per_second": 4.158,
195
- "eval_steps_per_second": 1.04,
196
- "eval_translation_length": 52542,
197
- "step": 12646
198
- },
199
- {
200
- "epoch": 2.06,
201
- "learning_rate": 4.879484214133427e-06,
202
- "loss": 1.838,
203
- "step": 13000
204
- },
205
- {
206
- "epoch": 2.14,
207
- "learning_rate": 4.8697363967681696e-06,
208
- "loss": 1.8539,
209
- "step": 13500
210
- },
211
- {
212
- "epoch": 2.21,
213
- "learning_rate": 4.8596200527856564e-06,
214
- "loss": 1.8081,
215
- "step": 14000
216
- },
217
- {
218
- "epoch": 2.29,
219
- "learning_rate": 4.849136755416576e-06,
220
- "loss": 1.8135,
221
- "step": 14500
222
- },
223
- {
224
- "epoch": 2.37,
225
- "learning_rate": 4.838288134957921e-06,
226
- "loss": 1.8273,
227
- "step": 15000
228
- },
229
- {
230
- "epoch": 2.45,
231
- "learning_rate": 4.827075878519448e-06,
232
- "loss": 1.8316,
233
- "step": 15500
234
- },
235
- {
236
- "epoch": 2.53,
237
- "learning_rate": 4.815501729761316e-06,
238
- "loss": 1.8101,
239
- "step": 16000
240
- },
241
- {
242
- "epoch": 2.61,
243
- "learning_rate": 4.803567488622915e-06,
244
- "loss": 1.8257,
245
- "step": 16500
246
- },
247
- {
248
- "epoch": 2.69,
249
- "learning_rate": 4.791275011042958e-06,
250
- "loss": 1.8036,
251
- "step": 17000
252
- },
253
- {
254
- "epoch": 2.77,
255
- "learning_rate": 4.778626208670853e-06,
256
- "loss": 1.8235,
257
- "step": 17500
258
- },
259
- {
260
- "epoch": 2.85,
261
- "learning_rate": 4.765623048569417e-06,
262
- "loss": 1.8133,
263
- "step": 18000
264
- },
265
- {
266
- "epoch": 2.93,
267
- "learning_rate": 4.752267552908968e-06,
268
- "loss": 1.8316,
269
- "step": 18500
270
- },
271
- {
272
- "epoch": 3.0,
273
- "eval_bleu": 1.0,
274
- "eval_brevity_penalty": 1.0,
275
- "eval_length_ratio": 1.0,
276
- "eval_loss": 1.5685368776321411,
277
- "eval_precisions": [
278
- 1.0,
279
- 1.0,
280
- 1.0,
281
- 1.0
282
- ],
283
- "eval_reference_length": 52485,
284
- "eval_runtime": 677.2148,
285
- "eval_samples_per_second": 4.151,
286
- "eval_steps_per_second": 1.038,
287
- "eval_translation_length": 52485,
288
- "step": 18969
289
- },
290
- {
291
- "epoch": 3.0,
292
- "learning_rate": 4.738561798652854e-06,
293
- "loss": 1.7842,
294
- "step": 19000
295
- },
296
- {
297
- "epoch": 3.08,
298
- "learning_rate": 4.724507917234451e-06,
299
- "loss": 1.8069,
300
- "step": 19500
301
- },
302
- {
303
- "epoch": 3.16,
304
- "learning_rate": 4.710108094225704e-06,
305
- "loss": 1.7776,
306
- "step": 20000
307
- },
308
- {
309
- "epoch": 3.24,
310
- "learning_rate": 4.695364568997228e-06,
311
- "loss": 1.8232,
312
- "step": 20500
313
- },
314
- {
315
- "epoch": 3.32,
316
- "learning_rate": 4.680279634370071e-06,
317
- "loss": 1.8125,
318
- "step": 21000
319
- },
320
- {
321
- "epoch": 3.4,
322
- "learning_rate": 4.664855636259134e-06,
323
- "loss": 1.841,
324
- "step": 21500
325
- },
326
- {
327
- "epoch": 3.48,
328
- "learning_rate": 4.649094973308358e-06,
329
- "loss": 1.8519,
330
- "step": 22000
331
- },
332
- {
333
- "epoch": 3.56,
334
- "learning_rate": 4.633000096517698e-06,
335
- "loss": 1.8293,
336
- "step": 22500
337
- },
338
- {
339
- "epoch": 3.64,
340
- "learning_rate": 4.61657350886196e-06,
341
- "loss": 1.8121,
342
- "step": 23000
343
- },
344
- {
345
- "epoch": 3.72,
346
- "learning_rate": 4.5998177649015565e-06,
347
- "loss": 1.7916,
348
- "step": 23500
349
- },
350
- {
351
- "epoch": 3.8,
352
- "learning_rate": 4.582735470385229e-06,
353
- "loss": 1.8352,
354
- "step": 24000
355
- },
356
- {
357
- "epoch": 3.87,
358
- "learning_rate": 4.56532928184483e-06,
359
- "loss": 1.7778,
360
- "step": 24500
361
- },
362
- {
363
- "epoch": 3.95,
364
- "learning_rate": 4.547601906182184e-06,
365
- "loss": 1.815,
366
- "step": 25000
367
- },
368
- {
369
- "epoch": 4.0,
370
- "eval_bleu": 1.0,
371
- "eval_brevity_penalty": 1.0,
372
- "eval_length_ratio": 1.0,
373
- "eval_loss": 1.5691736936569214,
374
- "eval_precisions": [
375
- 1.0,
376
- 1.0,
377
- 1.0,
378
- 1.0
379
- ],
380
- "eval_reference_length": 52469,
381
- "eval_runtime": 676.8094,
382
- "eval_samples_per_second": 4.153,
383
- "eval_steps_per_second": 1.039,
384
- "eval_translation_length": 52469,
385
- "step": 25292
386
- },
387
- {
388
- "epoch": 4.03,
389
- "learning_rate": 4.529556100248137e-06,
390
- "loss": 1.8259,
391
- "step": 25500
392
- },
393
- {
394
- "epoch": 4.11,
395
- "learning_rate": 4.511194670413822e-06,
396
- "loss": 1.8127,
397
- "step": 26000
398
- },
399
- {
400
- "epoch": 4.19,
401
- "learning_rate": 4.49252047213423e-06,
402
- "loss": 1.7847,
403
- "step": 26500
404
- },
405
- {
406
- "epoch": 4.27,
407
- "learning_rate": 4.473536409504151e-06,
408
- "loss": 1.8137,
409
- "step": 27000
410
- },
411
- {
412
- "epoch": 4.35,
413
- "learning_rate": 4.454245434806545e-06,
414
- "loss": 1.8443,
415
- "step": 27500
416
- },
417
- {
418
- "epoch": 4.43,
419
- "learning_rate": 4.4346505480534205e-06,
420
- "loss": 1.7972,
421
- "step": 28000
422
- },
423
- {
424
- "epoch": 4.51,
425
- "learning_rate": 4.4147547965192934e-06,
426
- "loss": 1.8035,
427
- "step": 28500
428
- },
429
- {
430
- "epoch": 4.59,
431
- "learning_rate": 4.394561274267293e-06,
432
- "loss": 1.7983,
433
- "step": 29000
434
- },
435
- {
436
- "epoch": 4.67,
437
- "learning_rate": 4.374073121667992e-06,
438
- "loss": 1.8227,
439
- "step": 29500
440
- },
441
- {
442
- "epoch": 4.74,
443
- "learning_rate": 4.3532935249110366e-06,
444
- "loss": 1.7775,
445
- "step": 30000
446
- },
447
- {
448
- "epoch": 4.82,
449
- "learning_rate": 4.3322257155096496e-06,
450
- "loss": 1.8181,
451
- "step": 30500
452
- },
453
- {
454
- "epoch": 4.9,
455
- "learning_rate": 4.310872969798085e-06,
456
- "loss": 1.8333,
457
- "step": 31000
458
- },
459
- {
460
- "epoch": 4.98,
461
- "learning_rate": 4.289238608422115e-06,
462
- "loss": 1.8452,
463
- "step": 31500
464
- },
465
- {
466
- "epoch": 5.0,
467
- "eval_bleu": 1.0,
468
- "eval_brevity_penalty": 1.0,
469
- "eval_length_ratio": 1.0,
470
- "eval_loss": 1.56978178024292,
471
- "eval_precisions": [
472
- 1.0,
473
- 1.0,
474
- 1.0,
475
- 1.0
476
- ],
477
- "eval_reference_length": 52449,
478
- "eval_runtime": 676.5149,
479
- "eval_samples_per_second": 4.155,
480
- "eval_steps_per_second": 1.039,
481
- "eval_translation_length": 52449,
482
- "step": 31615
483
- },
484
- {
485
- "epoch": 5.06,
486
- "learning_rate": 4.267325995822624e-06,
487
- "loss": 1.8142,
488
- "step": 32000
489
- },
490
- {
491
- "epoch": 5.14,
492
- "learning_rate": 4.2451385397123864e-06,
493
- "loss": 1.8047,
494
- "step": 32500
495
- },
496
- {
497
- "epoch": 5.22,
498
- "learning_rate": 4.222679690546128e-06,
499
- "loss": 1.8006,
500
- "step": 33000
501
- },
502
- {
503
- "epoch": 5.3,
504
- "learning_rate": 4.199952940983926e-06,
505
- "loss": 1.7971,
506
- "step": 33500
507
- },
508
- {
509
- "epoch": 5.38,
510
- "learning_rate": 4.176961825348059e-06,
511
- "loss": 1.825,
512
- "step": 34000
513
- },
514
- {
515
- "epoch": 5.46,
516
- "learning_rate": 4.1537099190733656e-06,
517
- "loss": 1.8121,
518
- "step": 34500
519
- },
520
- {
521
- "epoch": 5.54,
522
- "learning_rate": 4.130200838151217e-06,
523
- "loss": 1.8179,
524
- "step": 35000
525
- },
526
- {
527
- "epoch": 5.61,
528
- "learning_rate": 4.106438238567183e-06,
529
- "loss": 1.8005,
530
- "step": 35500
531
- },
532
- {
533
- "epoch": 5.69,
534
- "learning_rate": 4.08242581573247e-06,
535
- "loss": 1.8367,
536
- "step": 36000
537
- },
538
- {
539
- "epoch": 5.77,
540
- "learning_rate": 4.058167303909241e-06,
541
- "loss": 1.8062,
542
- "step": 36500
543
- },
544
- {
545
- "epoch": 5.85,
546
- "learning_rate": 4.033666475629881e-06,
547
- "loss": 1.8092,
548
- "step": 37000
549
- },
550
- {
551
- "epoch": 5.93,
552
- "learning_rate": 4.008927141110319e-06,
553
- "loss": 1.7638,
554
- "step": 37500
555
- },
556
- {
557
- "epoch": 6.0,
558
- "eval_bleu": 1.0,
559
- "eval_brevity_penalty": 1.0,
560
- "eval_length_ratio": 1.0,
561
- "eval_loss": 1.5709949731826782,
562
- "eval_precisions": [
563
- 1.0,
564
- 1.0,
565
- 1.0,
566
- 1.0
567
- ],
568
- "eval_reference_length": 52568,
569
- "eval_runtime": 677.5535,
570
- "eval_samples_per_second": 4.149,
571
- "eval_steps_per_second": 1.038,
572
- "eval_translation_length": 52568,
573
- "step": 37938
574
- },
575
- {
576
- "epoch": 6.01,
577
- "learning_rate": 3.9839531476574855e-06,
578
- "loss": 1.801,
579
- "step": 38000
580
- },
581
- {
582
- "epoch": 6.09,
583
- "learning_rate": 3.958748379071004e-06,
584
- "loss": 1.7813,
585
- "step": 38500
586
- },
587
- {
588
- "epoch": 6.17,
589
- "learning_rate": 3.933316755039209e-06,
590
- "loss": 1.7742,
591
- "step": 39000
592
- },
593
- {
594
- "epoch": 6.25,
595
- "learning_rate": 3.9076622305295755e-06,
596
- "loss": 1.7852,
597
- "step": 39500
598
- },
599
- {
600
- "epoch": 6.33,
601
- "learning_rate": 3.88178879517367e-06,
602
- "loss": 1.8312,
603
- "step": 40000
604
- },
605
- {
606
- "epoch": 6.41,
607
- "learning_rate": 3.855700472646708e-06,
608
- "loss": 1.8144,
609
- "step": 40500
610
- },
611
- {
612
- "epoch": 6.48,
613
- "learning_rate": 3.82940132004182e-06,
614
- "loss": 1.7856,
615
- "step": 41000
616
- },
617
- {
618
- "epoch": 6.56,
619
- "learning_rate": 3.8028954272391116e-06,
620
- "loss": 1.8139,
621
- "step": 41500
622
- },
623
- {
624
- "epoch": 6.64,
625
- "learning_rate": 3.7761869162696334e-06,
626
- "loss": 1.8018,
627
- "step": 42000
628
- },
629
- {
630
- "epoch": 6.72,
631
- "learning_rate": 3.7492799406743512e-06,
632
- "loss": 1.7771,
633
- "step": 42500
634
- },
635
- {
636
- "epoch": 6.8,
637
- "learning_rate": 3.722178684858209e-06,
638
- "loss": 1.8217,
639
- "step": 43000
640
- },
641
- {
642
- "epoch": 6.88,
643
- "learning_rate": 3.6948873634394e-06,
644
- "loss": 1.8276,
645
- "step": 43500
646
- },
647
- {
648
- "epoch": 6.96,
649
- "learning_rate": 3.667410220593933e-06,
650
- "loss": 1.8267,
651
- "step": 44000
652
- },
653
- {
654
- "epoch": 7.0,
655
- "eval_bleu": 1.0,
656
- "eval_brevity_penalty": 1.0,
657
- "eval_length_ratio": 1.0,
658
- "eval_loss": 1.571208119392395,
659
- "eval_precisions": [
660
- 1.0,
661
- 1.0,
662
- 1.0,
663
- 1.0
664
- ],
665
- "eval_reference_length": 52554,
666
- "eval_runtime": 677.3503,
667
- "eval_samples_per_second": 4.15,
668
- "eval_steps_per_second": 1.038,
669
- "eval_translation_length": 52554,
670
- "step": 44261
671
- },
672
- {
673
- "epoch": 7.04,
674
- "learning_rate": 3.639751529395606e-06,
675
- "loss": 1.7645,
676
- "step": 44500
677
- },
678
- {
679
- "epoch": 7.12,
680
- "learning_rate": 3.611915591151483e-06,
681
- "loss": 1.8167,
682
- "step": 45000
683
- },
684
- {
685
- "epoch": 7.2,
686
- "learning_rate": 3.5839067347329844e-06,
687
- "loss": 1.808,
688
- "step": 45500
689
- },
690
- {
691
- "epoch": 7.28,
692
- "learning_rate": 3.5557293159026845e-06,
693
- "loss": 1.7742,
694
- "step": 46000
695
- },
696
- {
697
- "epoch": 7.35,
698
- "learning_rate": 3.5273877166369326e-06,
699
- "loss": 1.784,
700
- "step": 46500
701
- },
702
- {
703
- "epoch": 7.43,
704
- "learning_rate": 3.4988863444443942e-06,
705
- "loss": 1.7732,
706
- "step": 47000
707
- },
708
- {
709
- "epoch": 7.51,
710
- "learning_rate": 3.4702296316806243e-06,
711
- "loss": 1.8029,
712
- "step": 47500
713
- },
714
- {
715
- "epoch": 7.59,
716
- "learning_rate": 3.4414220348587744e-06,
717
- "loss": 1.8167,
718
- "step": 48000
719
- },
720
- {
721
- "epoch": 7.67,
722
- "learning_rate": 3.412468033956543e-06,
723
- "loss": 1.8037,
724
- "step": 48500
725
- },
726
- {
727
- "epoch": 7.75,
728
- "learning_rate": 3.3833721317194756e-06,
729
- "loss": 1.7689,
730
- "step": 49000
731
- },
732
- {
733
- "epoch": 7.83,
734
- "learning_rate": 3.3541388529607303e-06,
735
- "loss": 1.8414,
736
- "step": 49500
737
- },
738
- {
739
- "epoch": 7.91,
740
- "learning_rate": 3.324772743857404e-06,
741
- "loss": 1.795,
742
- "step": 50000
743
- },
744
- {
745
- "epoch": 7.99,
746
- "learning_rate": 3.2952783712435406e-06,
747
- "loss": 1.8108,
748
- "step": 50500
749
- },
750
- {
751
- "epoch": 8.0,
752
- "eval_bleu": 1.0,
753
- "eval_brevity_penalty": 1.0,
754
- "eval_length_ratio": 1.0,
755
- "eval_loss": 1.5723803043365479,
756
- "eval_precisions": [
757
- 1.0,
758
- 1.0,
759
- 1.0,
760
- 1.0
761
- ],
762
- "eval_reference_length": 52356,
763
- "eval_runtime": 677.354,
764
- "eval_samples_per_second": 4.15,
765
- "eval_steps_per_second": 1.038,
766
- "eval_translation_length": 52356,
767
- "step": 50584
768
- },
769
- {
770
- "epoch": 8.07,
771
- "learning_rate": 3.265660321899923e-06,
772
- "loss": 1.8002,
773
- "step": 51000
774
- },
775
- {
776
- "epoch": 8.14,
777
- "learning_rate": 3.235923201840768e-06,
778
- "loss": 1.7785,
779
- "step": 51500
780
- },
781
- {
782
- "epoch": 8.22,
783
- "learning_rate": 3.2060716355974274e-06,
784
- "loss": 1.7734,
785
- "step": 52000
786
- },
787
- {
788
- "epoch": 8.3,
789
- "learning_rate": 3.1761102654992106e-06,
790
- "loss": 1.8028,
791
- "step": 52500
792
- },
793
- {
794
- "epoch": 8.38,
795
- "learning_rate": 3.1460437509514345e-06,
796
- "loss": 1.7929,
797
- "step": 53000
798
- },
799
- {
800
- "epoch": 8.46,
801
- "learning_rate": 3.115876767710828e-06,
802
- "loss": 1.8039,
803
- "step": 53500
804
- },
805
- {
806
- "epoch": 8.54,
807
- "learning_rate": 3.0856140071583806e-06,
808
- "loss": 1.8066,
809
- "step": 54000
810
- },
811
- {
812
- "epoch": 8.62,
813
- "learning_rate": 3.0552601755697765e-06,
814
- "loss": 1.7612,
815
- "step": 54500
816
- },
817
- {
818
- "epoch": 8.7,
819
- "learning_rate": 3.024819993383493e-06,
820
- "loss": 1.8281,
821
- "step": 55000
822
- },
823
- {
824
- "epoch": 8.78,
825
- "learning_rate": 2.9942981944667193e-06,
826
- "loss": 1.7766,
827
- "step": 55500
828
- },
829
- {
830
- "epoch": 8.86,
831
- "learning_rate": 2.963699525379166e-06,
832
- "loss": 1.8176,
833
- "step": 56000
834
- },
835
- {
836
- "epoch": 8.94,
837
- "learning_rate": 2.933028744634912e-06,
838
- "loss": 1.79,
839
- "step": 56500
840
- },
841
- {
842
- "epoch": 9.0,
843
- "eval_bleu": 1.0,
844
- "eval_brevity_penalty": 1.0,
845
- "eval_length_ratio": 1.0,
846
- "eval_loss": 1.5721025466918945,
847
- "eval_precisions": [
848
- 1.0,
849
- 1.0,
850
- 1.0,
851
- 1.0
852
- ],
853
- "eval_reference_length": 52640,
854
- "eval_runtime": 678.5345,
855
- "eval_samples_per_second": 4.143,
856
- "eval_steps_per_second": 1.036,
857
- "eval_translation_length": 52640,
858
- "step": 56907
859
- },
860
- {
861
- "epoch": 9.01,
862
- "learning_rate": 2.9022906219623958e-06,
863
- "loss": 1.8089,
864
- "step": 57000
865
- },
866
- {
867
- "epoch": 9.09,
868
- "learning_rate": 2.871489937562647e-06,
869
- "loss": 1.7695,
870
- "step": 57500
871
- },
872
- {
873
- "epoch": 9.17,
874
- "learning_rate": 2.8406314813659073e-06,
875
- "loss": 1.7845,
876
- "step": 58000
877
- },
878
- {
879
- "epoch": 9.25,
880
- "learning_rate": 2.8097200522867294e-06,
881
- "loss": 1.7954,
882
- "step": 58500
883
- },
884
- {
885
- "epoch": 9.33,
886
- "learning_rate": 2.7787604574776745e-06,
887
- "loss": 1.8204,
888
- "step": 59000
889
- },
890
- {
891
- "epoch": 9.41,
892
- "learning_rate": 2.747757511581739e-06,
893
- "loss": 1.7786,
894
- "step": 59500
895
- },
896
- {
897
- "epoch": 9.49,
898
- "learning_rate": 2.716716035983611e-06,
899
- "loss": 1.7995,
900
- "step": 60000
901
- },
902
- {
903
- "epoch": 9.57,
904
- "learning_rate": 2.685640858059876e-06,
905
- "loss": 1.8058,
906
- "step": 60500
907
- },
908
- {
909
- "epoch": 9.65,
910
- "learning_rate": 2.6545368104282955e-06,
911
- "loss": 1.7961,
912
- "step": 61000
913
- },
914
- {
915
- "epoch": 9.73,
916
- "learning_rate": 2.623408730196268e-06,
917
- "loss": 1.7866,
918
- "step": 61500
919
- },
920
- {
921
- "epoch": 9.81,
922
- "learning_rate": 2.592261458208591e-06,
923
- "loss": 1.7643,
924
- "step": 62000
925
- },
926
- {
927
- "epoch": 9.88,
928
- "learning_rate": 2.5610998382946463e-06,
929
- "loss": 1.7679,
930
- "step": 62500
931
- },
932
- {
933
- "epoch": 9.96,
934
- "learning_rate": 2.529928716515112e-06,
935
- "loss": 1.8195,
936
- "step": 63000
937
- },
938
- {
939
- "epoch": 10.0,
940
- "eval_bleu": 1.0,
941
- "eval_brevity_penalty": 1.0,
942
- "eval_length_ratio": 1.0,
943
- "eval_loss": 1.5732102394104004,
944
- "eval_precisions": [
945
- 1.0,
946
- 1.0,
947
- 1.0,
948
- 1.0
949
- ],
950
- "eval_reference_length": 52485,
951
- "eval_runtime": 677.793,
952
- "eval_samples_per_second": 4.147,
953
- "eval_steps_per_second": 1.037,
954
- "eval_translation_length": 52485,
955
- "step": 63230
956
- },
957
- {
958
- "epoch": 10.04,
959
- "learning_rate": 2.498752940408342e-06,
960
- "loss": 1.7938,
961
- "step": 63500
962
- },
963
- {
964
- "epoch": 10.12,
965
- "learning_rate": 2.4675773582364977e-06,
966
- "loss": 1.7688,
967
- "step": 64000
968
- },
969
- {
970
- "epoch": 10.2,
971
- "learning_rate": 2.436406818231583e-06,
972
- "loss": 1.7701,
973
- "step": 64500
974
- },
975
- {
976
- "epoch": 10.28,
977
- "learning_rate": 2.4052461678414753e-06,
978
- "loss": 1.7821,
979
- "step": 65000
980
- },
981
- {
982
- "epoch": 10.36,
983
- "learning_rate": 2.37410025297608e-06,
984
- "loss": 1.8251,
985
- "step": 65500
986
- },
987
- {
988
- "epoch": 10.44,
989
- "learning_rate": 2.342973917253726e-06,
990
- "loss": 1.7384,
991
- "step": 66000
992
- },
993
- {
994
- "epoch": 10.52,
995
- "learning_rate": 2.3118720012479183e-06,
996
- "loss": 1.8001,
997
- "step": 66500
998
- },
999
- {
1000
- "epoch": 10.6,
1001
- "learning_rate": 2.280799341734556e-06,
1002
- "loss": 1.8386,
1003
- "step": 67000
1004
- },
1005
- {
1006
- "epoch": 10.68,
1007
- "learning_rate": 2.249760770939754e-06,
1008
- "loss": 1.8098,
1009
- "step": 67500
1010
- },
1011
- {
1012
- "epoch": 10.75,
1013
- "learning_rate": 2.218761115788362e-06,
1014
- "loss": 1.8059,
1015
- "step": 68000
1016
- },
1017
- {
1018
- "epoch": 10.83,
1019
- "learning_rate": 2.1878051971533093e-06,
1020
- "loss": 1.757,
1021
- "step": 68500
1022
- },
1023
- {
1024
- "epoch": 10.91,
1025
- "learning_rate": 2.156897829105898e-06,
1026
- "loss": 1.8037,
1027
- "step": 69000
1028
- },
1029
- {
1030
- "epoch": 10.99,
1031
- "learning_rate": 2.1260438181671446e-06,
1032
- "loss": 1.7714,
1033
- "step": 69500
1034
- },
1035
- {
1036
- "epoch": 11.0,
1037
- "eval_bleu": 1.0,
1038
- "eval_brevity_penalty": 1.0,
1039
- "eval_length_ratio": 1.0,
1040
- "eval_loss": 1.5735211372375488,
1041
- "eval_precisions": [
1042
- 1.0,
1043
- 1.0,
1044
- 1.0,
1045
- 1.0
1046
- ],
1047
- "eval_reference_length": 52469,
1048
- "eval_runtime": 678.026,
1049
- "eval_samples_per_second": 4.146,
1050
- "eval_steps_per_second": 1.037,
1051
- "eval_translation_length": 52469,
1052
- "step": 69553
1053
- },
1054
- {
1055
- "epoch": 11.07,
1056
- "learning_rate": 2.0952479625603017e-06,
1057
- "loss": 1.7783,
1058
- "step": 70000
1059
- },
1060
- {
1061
- "epoch": 11.15,
1062
- "learning_rate": 2.0645150514646657e-06,
1063
- "loss": 1.7443,
1064
- "step": 70500
1065
- },
1066
- {
1067
- "epoch": 11.23,
1068
- "learning_rate": 2.0338498642707977e-06,
1069
- "loss": 1.7678,
1070
- "step": 71000
1071
- },
1072
- {
1073
- "epoch": 11.31,
1074
- "learning_rate": 2.0032571698372577e-06,
1075
- "loss": 1.7786,
1076
- "step": 71500
1077
- },
1078
- {
1079
- "epoch": 11.39,
1080
- "learning_rate": 1.9727417257489874e-06,
1081
- "loss": 1.7768,
1082
- "step": 72000
1083
- },
1084
- {
1085
- "epoch": 11.47,
1086
- "learning_rate": 1.9423082775774337e-06,
1087
- "loss": 1.7953,
1088
- "step": 72500
1089
- },
1090
- {
1091
- "epoch": 11.55,
1092
- "learning_rate": 1.9119615581425524e-06,
1093
- "loss": 1.7715,
1094
- "step": 73000
1095
- },
1096
- {
1097
- "epoch": 11.62,
1098
- "learning_rate": 1.881706286776785e-06,
1099
- "loss": 1.8047,
1100
- "step": 73500
1101
- },
1102
- {
1103
- "epoch": 11.7,
1104
- "learning_rate": 1.8515471685911402e-06,
1105
- "loss": 1.7781,
1106
- "step": 74000
1107
- },
1108
- {
1109
- "epoch": 11.78,
1110
- "learning_rate": 1.821488893743488e-06,
1111
- "loss": 1.8197,
1112
- "step": 74500
1113
- },
1114
- {
1115
- "epoch": 11.86,
1116
- "learning_rate": 1.7915361367091677e-06,
1117
- "loss": 1.8159,
1118
- "step": 75000
1119
- },
1120
- {
1121
- "epoch": 11.94,
1122
- "learning_rate": 1.7616935555540475e-06,
1123
- "loss": 1.8004,
1124
- "step": 75500
1125
- },
1126
- {
1127
- "epoch": 12.0,
1128
- "eval_bleu": 1.0,
1129
- "eval_brevity_penalty": 1.0,
1130
- "eval_length_ratio": 1.0,
1131
- "eval_loss": 1.5739296674728394,
1132
- "eval_precisions": [
1133
- 1.0,
1134
- 1.0,
1135
- 1.0,
1136
- 1.0
1137
- ],
1138
- "eval_reference_length": 52457,
1139
- "eval_runtime": 676.9983,
1140
- "eval_samples_per_second": 4.152,
1141
- "eval_steps_per_second": 1.038,
1142
- "eval_translation_length": 52457,
1143
- "step": 75876
1144
- },
1145
- {
1146
- "epoch": 12.02,
1147
- "learning_rate": 1.7319657912101309e-06,
1148
- "loss": 1.7871,
1149
- "step": 76000
1150
- },
1151
- {
1152
- "epoch": 12.1,
1153
- "learning_rate": 1.7023574667538268e-06,
1154
- "loss": 1.7728,
1155
- "step": 76500
1156
- },
1157
- {
1158
- "epoch": 12.18,
1159
- "learning_rate": 1.6728731866869999e-06,
1160
- "loss": 1.792,
1161
- "step": 77000
1162
- },
1163
- {
1164
- "epoch": 12.26,
1165
- "learning_rate": 1.6435175362209033e-06,
1166
- "loss": 1.8009,
1167
- "step": 77500
1168
- },
1169
- {
1170
- "epoch": 12.34,
1171
- "learning_rate": 1.6142950805631178e-06,
1172
- "loss": 1.751,
1173
- "step": 78000
1174
- },
1175
- {
1176
- "epoch": 12.41,
1177
- "learning_rate": 1.5852103642075995e-06,
1178
- "loss": 1.7877,
1179
- "step": 78500
1180
- },
1181
- {
1182
- "epoch": 12.49,
1183
- "learning_rate": 1.5562679102279453e-06,
1184
- "loss": 1.7936,
1185
- "step": 79000
1186
- },
1187
- {
1188
- "epoch": 12.57,
1189
- "learning_rate": 1.5274722195740005e-06,
1190
- "loss": 1.7884,
1191
- "step": 79500
1192
- },
1193
- {
1194
- "epoch": 12.65,
1195
- "learning_rate": 1.4988277703718882e-06,
1196
- "loss": 1.7617,
1197
- "step": 80000
1198
- },
1199
- {
1200
- "epoch": 12.73,
1201
- "learning_rate": 1.4703390172276072e-06,
1202
- "loss": 1.7916,
1203
- "step": 80500
1204
- },
1205
- {
1206
- "epoch": 12.81,
1207
- "learning_rate": 1.4420103905342767e-06,
1208
- "loss": 1.7773,
1209
- "step": 81000
1210
- },
1211
- {
1212
- "epoch": 12.89,
1213
- "learning_rate": 1.4138462957831472e-06,
1214
- "loss": 1.7798,
1215
- "step": 81500
1216
- },
1217
- {
1218
- "epoch": 12.97,
1219
- "learning_rate": 1.3858511128784937e-06,
1220
- "loss": 1.7658,
1221
- "step": 82000
1222
- },
1223
- {
1224
- "epoch": 13.0,
1225
- "eval_bleu": 1.0,
1226
- "eval_brevity_penalty": 1.0,
1227
- "eval_length_ratio": 1.0,
1228
- "eval_loss": 1.5740926265716553,
1229
  "eval_precisions": [
1230
  1.0,
1231
  1.0,
1232
  1.0,
1233
  1.0
1234
  ],
1235
- "eval_reference_length": 52468,
1236
- "eval_runtime": 677.6429,
1237
- "eval_samples_per_second": 4.148,
1238
- "eval_steps_per_second": 1.037,
1239
- "eval_translation_length": 52468,
1240
- "step": 82199
1241
  }
1242
  ],
1243
  "logging_steps": 500,
1244
- "max_steps": 126460,
1245
- "num_train_epochs": 20,
1246
  "save_steps": 500,
1247
- "total_flos": 2.251365766099108e+17,
1248
  "trial_name": null,
1249
  "trial_params": null
1250
  }
 
1
  {
2
+ "best_metric": 2.2135026454925537,
3
+ "best_model_checkpoint": "dq158/pingusPongus/checkpoint-1581",
4
+ "epoch": 1.0,
5
  "eval_steps": 500,
6
+ "global_step": 1581,
7
  "is_hyper_param_search": false,
8
  "is_local_process_zero": true,
9
  "is_world_process_zero": true,
10
  "log_history": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  {
12
  "epoch": 0.32,
13
+ "learning_rate": 0.0008497772211805019,
14
+ "loss": 2.4966,
15
+ "step": 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  },
17
  {
18
  "epoch": 0.63,
19
+ "learning_rate": 0.0009445572420019074,
20
+ "loss": 2.3574,
21
+ "step": 1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  },
23
  {
24
  "epoch": 0.95,
25
+ "learning_rate": 0.0009999999999999998,
26
+ "loss": 2.3094,
27
+ "step": 1500
28
  },
29
  {
30
  "epoch": 1.0,
31
  "eval_bleu": 1.0,
32
  "eval_brevity_penalty": 1.0,
33
  "eval_length_ratio": 1.0,
34
+ "eval_loss": 2.2135026454925537,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  "eval_precisions": [
36
  1.0,
37
  1.0,
38
  1.0,
39
  1.0
40
  ],
41
+ "eval_reference_length": 53394,
42
+ "eval_runtime": 2937.3471,
43
+ "eval_samples_per_second": 0.957,
44
+ "eval_steps_per_second": 0.12,
45
+ "eval_translation_length": 53394,
46
+ "step": 1581
47
  }
48
  ],
49
  "logging_steps": 500,
50
+ "max_steps": 15810,
51
+ "num_train_epochs": 10,
52
  "save_steps": 500,
53
+ "total_flos": 771945142419456.0,
54
  "trial_name": null,
55
  "trial_params": null
56
  }
last-checkpoint/training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ce6df3970ff39f84beab7b635dd3c941539643a1d32dedf54263683eef40519
3
- size 4664
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da68e390908bbf17cdb4f041431915c87c2f3f6b8699cebf8dcb70aae5296d20
3
+ size 6648
last-checkpoint/zero_to_fp32.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+
24
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
25
+ # DeepSpeed data structures it has to be available in the current python environment.
26
+ from deepspeed.utils import logger
27
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
28
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
29
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
30
+
31
+
32
+ @dataclass
33
+ class zero_model_state:
34
+ buffers: dict()
35
+ param_shapes: dict()
36
+ shared_params: list
37
+ ds_version: int
38
+ frozen_param_shapes: dict()
39
+ frozen_param_fragments: dict()
40
+
41
+
42
+ debug = 0
43
+
44
+ # load to cpu
45
+ device = torch.device('cpu')
46
+
47
+
48
+ def atoi(text):
49
+ return int(text) if text.isdigit() else text
50
+
51
+
52
+ def natural_keys(text):
53
+ '''
54
+ alist.sort(key=natural_keys) sorts in human order
55
+ http://nedbatchelder.com/blog/200712/human_sorting.html
56
+ (See Toothy's implementation in the comments)
57
+ '''
58
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
59
+
60
+
61
+ def get_model_state_file(checkpoint_dir, zero_stage):
62
+ if not os.path.isdir(checkpoint_dir):
63
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
64
+
65
+ # there should be only one file
66
+ if zero_stage <= 2:
67
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
68
+ elif zero_stage == 3:
69
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
70
+
71
+ if not os.path.exists(file):
72
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
73
+
74
+ return file
75
+
76
+
77
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
78
+ # XXX: need to test that this simple glob rule works for multi-node setup too
79
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
80
+
81
+ if len(ckpt_files) == 0:
82
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
83
+
84
+ return ckpt_files
85
+
86
+
87
+ def get_optim_files(checkpoint_dir):
88
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
89
+
90
+
91
+ def get_model_state_files(checkpoint_dir):
92
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
93
+
94
+
95
+ def parse_model_states(files):
96
+ zero_model_states = []
97
+ for file in files:
98
+ state_dict = torch.load(file, map_location=device)
99
+
100
+ if BUFFER_NAMES not in state_dict:
101
+ raise ValueError(f"{file} is not a model state checkpoint")
102
+ buffer_names = state_dict[BUFFER_NAMES]
103
+ if debug:
104
+ print("Found buffers:", buffer_names)
105
+
106
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
107
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
108
+ param_shapes = state_dict[PARAM_SHAPES]
109
+
110
+ # collect parameters that are included in param_shapes
111
+ param_names = []
112
+ for s in param_shapes:
113
+ for name in s.keys():
114
+ param_names.append(name)
115
+
116
+ # update with frozen parameters
117
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
118
+ if frozen_param_shapes is not None:
119
+ if debug:
120
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
121
+ param_names += list(frozen_param_shapes.keys())
122
+
123
+ # handle shared params
124
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
125
+
126
+ ds_version = state_dict.get(DS_VERSION, None)
127
+
128
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
129
+
130
+ z_model_state = zero_model_state(buffers=buffers,
131
+ param_shapes=param_shapes,
132
+ shared_params=shared_params,
133
+ ds_version=ds_version,
134
+ frozen_param_shapes=frozen_param_shapes,
135
+ frozen_param_fragments=frozen_param_fragments)
136
+ zero_model_states.append(z_model_state)
137
+
138
+ return zero_model_states
139
+
140
+
141
+ def parse_optim_states(files, ds_checkpoint_dir):
142
+
143
+ total_files = len(files)
144
+ state_dicts = []
145
+ for f in files:
146
+ state_dict = torch.load(f, map_location=device)
147
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
148
+ # and also handle the case where it was already removed by another helper script
149
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
150
+ state_dicts.append(state_dict)
151
+
152
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
153
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
154
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
155
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
156
+
157
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
158
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
159
+ # use the max of the partition_count to get the dp world_size.
160
+
161
+ if type(world_size) is list:
162
+ world_size = max(world_size)
163
+
164
+ if world_size != total_files:
165
+ raise ValueError(
166
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
167
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
168
+ )
169
+
170
+ # the groups are named differently in each stage
171
+ if zero_stage <= 2:
172
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
173
+ elif zero_stage == 3:
174
+ fp32_groups_key = FP32_FLAT_GROUPS
175
+ else:
176
+ raise ValueError(f"unknown zero stage {zero_stage}")
177
+
178
+ if zero_stage <= 2:
179
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
180
+ elif zero_stage == 3:
181
+ # if there is more than one param group, there will be multiple flattened tensors - one
182
+ # flattened tensor per group - for simplicity merge them into a single tensor
183
+ #
184
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
185
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
186
+
187
+ fp32_flat_groups = [
188
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
189
+ ]
190
+
191
+ return zero_stage, world_size, fp32_flat_groups
192
+
193
+
194
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
195
+ """
196
+ Returns fp32 state_dict reconstructed from ds checkpoint
197
+
198
+ Args:
199
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
200
+
201
+ """
202
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
203
+
204
+ optim_files = get_optim_files(ds_checkpoint_dir)
205
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
206
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
207
+
208
+ model_files = get_model_state_files(ds_checkpoint_dir)
209
+
210
+ zero_model_states = parse_model_states(model_files)
211
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
212
+
213
+ if zero_stage <= 2:
214
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
215
+ elif zero_stage == 3:
216
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
217
+
218
+
219
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
220
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
221
+ return
222
+
223
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
224
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
225
+
226
+ if debug:
227
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
228
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
229
+
230
+ wanted_params = len(frozen_param_shapes)
231
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
232
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
233
+ print(f'Frozen params: Have {avail_numel} numels to process.')
234
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
235
+
236
+ total_params = 0
237
+ total_numel = 0
238
+ for name, shape in frozen_param_shapes.items():
239
+ total_params += 1
240
+ unpartitioned_numel = shape.numel()
241
+ total_numel += unpartitioned_numel
242
+
243
+ state_dict[name] = frozen_param_fragments[name]
244
+
245
+ if debug:
246
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
247
+
248
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
249
+
250
+
251
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
252
+ param_shapes = zero_model_states[0].param_shapes
253
+
254
+ # Reconstruction protocol:
255
+ #
256
+ # XXX: document this
257
+
258
+ if debug:
259
+ for i in range(world_size):
260
+ for j in range(len(fp32_flat_groups[0])):
261
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
262
+
263
+ # XXX: memory usage doubles here (zero2)
264
+ num_param_groups = len(fp32_flat_groups[0])
265
+ merged_single_partition_of_fp32_groups = []
266
+ for i in range(num_param_groups):
267
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
268
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
269
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
270
+ avail_numel = sum(
271
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
272
+
273
+ if debug:
274
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
275
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
276
+ # not asserting if there is a mismatch due to possible padding
277
+ print(f"Have {avail_numel} numels to process.")
278
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
279
+
280
+ # params
281
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
282
+ # out-of-core computing solution
283
+ total_numel = 0
284
+ total_params = 0
285
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
286
+ offset = 0
287
+ avail_numel = full_single_fp32_vector.numel()
288
+ for name, shape in shapes.items():
289
+
290
+ unpartitioned_numel = shape.numel()
291
+ total_numel += unpartitioned_numel
292
+ total_params += 1
293
+
294
+ if debug:
295
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
296
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
297
+ offset += unpartitioned_numel
298
+
299
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
300
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
301
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
302
+ # live optimizer object, so we are checking that the numbers are within the right range
303
+ align_to = 2 * world_size
304
+
305
+ def zero2_align(x):
306
+ return align_to * math.ceil(x / align_to)
307
+
308
+ if debug:
309
+ print(f"original offset={offset}, avail_numel={avail_numel}")
310
+
311
+ offset = zero2_align(offset)
312
+ avail_numel = zero2_align(avail_numel)
313
+
314
+ if debug:
315
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
316
+
317
+ # Sanity check
318
+ if offset != avail_numel:
319
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
320
+
321
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
322
+
323
+
324
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
325
+ state_dict = OrderedDict()
326
+
327
+ # buffers
328
+ buffers = zero_model_states[0].buffers
329
+ state_dict.update(buffers)
330
+ if debug:
331
+ print(f"added {len(buffers)} buffers")
332
+
333
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
334
+
335
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
336
+
337
+ # recover shared parameters
338
+ for pair in zero_model_states[0].shared_params:
339
+ if pair[1] in state_dict:
340
+ state_dict[pair[0]] = state_dict[pair[1]]
341
+
342
+ return state_dict
343
+
344
+
345
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
346
+ remainder = unpartitioned_numel % world_size
347
+ padding_numel = (world_size - remainder) if remainder else 0
348
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
349
+ return partitioned_numel, padding_numel
350
+
351
+
352
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
353
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
354
+ return
355
+
356
+ if debug:
357
+ for i in range(world_size):
358
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
359
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
360
+
361
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
362
+ wanted_params = len(frozen_param_shapes)
363
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
364
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
365
+ print(f'Frozen params: Have {avail_numel} numels to process.')
366
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
367
+
368
+ total_params = 0
369
+ total_numel = 0
370
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
371
+ total_params += 1
372
+ unpartitioned_numel = shape.numel()
373
+ total_numel += unpartitioned_numel
374
+
375
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
376
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
377
+
378
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
379
+
380
+ if debug:
381
+ print(
382
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
383
+ )
384
+
385
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
386
+
387
+
388
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
389
+ param_shapes = zero_model_states[0].param_shapes
390
+ avail_numel = fp32_flat_groups[0].numel() * world_size
391
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
392
+ # param, re-consolidating each param, while dealing with padding if any
393
+
394
+ # merge list of dicts, preserving order
395
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
396
+
397
+ if debug:
398
+ for i in range(world_size):
399
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
400
+
401
+ wanted_params = len(param_shapes)
402
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
403
+ # not asserting if there is a mismatch due to possible padding
404
+ avail_numel = fp32_flat_groups[0].numel() * world_size
405
+ print(f"Trainable params: Have {avail_numel} numels to process.")
406
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
407
+
408
+ # params
409
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
410
+ # out-of-core computing solution
411
+ offset = 0
412
+ total_numel = 0
413
+ total_params = 0
414
+ for name, shape in param_shapes.items():
415
+
416
+ unpartitioned_numel = shape.numel()
417
+ total_numel += unpartitioned_numel
418
+ total_params += 1
419
+
420
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
421
+
422
+ if debug:
423
+ print(
424
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
425
+ )
426
+
427
+ # XXX: memory usage doubles here
428
+ state_dict[name] = torch.cat(
429
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
430
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
431
+ offset += partitioned_numel
432
+
433
+ offset *= world_size
434
+
435
+ # Sanity check
436
+ if offset != avail_numel:
437
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
438
+
439
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
440
+
441
+
442
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
443
+ state_dict = OrderedDict()
444
+
445
+ # buffers
446
+ buffers = zero_model_states[0].buffers
447
+ state_dict.update(buffers)
448
+ if debug:
449
+ print(f"added {len(buffers)} buffers")
450
+
451
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
452
+
453
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
454
+
455
+ # recover shared parameters
456
+ for pair in zero_model_states[0].shared_params:
457
+ if pair[1] in state_dict:
458
+ state_dict[pair[0]] = state_dict[pair[1]]
459
+
460
+ return state_dict
461
+
462
+
463
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
464
+ """
465
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
466
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
467
+ via a model hub.
468
+
469
+ Args:
470
+ - ``checkpoint_dir``: path to the desired checkpoint folder
471
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
472
+
473
+ Returns:
474
+ - pytorch ``state_dict``
475
+
476
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
477
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
478
+ the checkpoint.
479
+
480
+ A typical usage might be ::
481
+
482
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
483
+ # do the training and checkpoint saving
484
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
485
+ model = model.cpu() # move to cpu
486
+ model.load_state_dict(state_dict)
487
+ # submit to model hub or save the model to share with others
488
+
489
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
490
+ application. i.e. you will need to re-initialize the deepspeed engine, since
491
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
492
+
493
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
494
+
495
+ """
496
+ if tag is None:
497
+ latest_path = os.path.join(checkpoint_dir, 'latest')
498
+ if os.path.isfile(latest_path):
499
+ with open(latest_path, 'r') as fd:
500
+ tag = fd.read().strip()
501
+ else:
502
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
503
+
504
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
505
+
506
+ if not os.path.isdir(ds_checkpoint_dir):
507
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
508
+
509
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
510
+
511
+
512
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
513
+ """
514
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
515
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
516
+
517
+ Args:
518
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
519
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
520
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
521
+ """
522
+
523
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
524
+ print(f"Saving fp32 state dict to {output_file}")
525
+ torch.save(state_dict, output_file)
526
+
527
+
528
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
529
+ """
530
+ 1. Put the provided model to cpu
531
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
532
+ 3. Load it into the provided model
533
+
534
+ Args:
535
+ - ``model``: the model object to update
536
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
537
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
538
+
539
+ Returns:
540
+ - ``model`: modified model
541
+
542
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
543
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
544
+ conveniently placed for you in the checkpoint folder.
545
+
546
+ A typical usage might be ::
547
+
548
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
549
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
550
+ # submit to model hub or save the model to share with others
551
+
552
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
553
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
554
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
555
+
556
+ """
557
+ logger.info(f"Extracting fp32 weights")
558
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
559
+
560
+ logger.info(f"Overwriting model with fp32 weights")
561
+ model = model.cpu()
562
+ model.load_state_dict(state_dict, strict=False)
563
+
564
+ return model
565
+
566
+
567
+ if __name__ == "__main__":
568
+
569
+ parser = argparse.ArgumentParser()
570
+ parser.add_argument("checkpoint_dir",
571
+ type=str,
572
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
573
+ parser.add_argument(
574
+ "output_file",
575
+ type=str,
576
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
577
+ parser.add_argument("-t",
578
+ "--tag",
579
+ type=str,
580
+ default=None,
581
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
582
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
583
+ args = parser.parse_args()
584
+
585
+ debug = args.debug
586
+
587
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)