w11wo commited on
Commit
2436252
·
1 Parent(s): 60d5285

Saving weights and logs of epoch 1

Browse files
.gitattributes CHANGED
@@ -16,3 +16,4 @@
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
  nohup.out filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
  nohup.out filter=lfs diff=lfs merge=lfs -text
19
+ flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
events.out.tfevents.1626286603.t1v-n-b95d739e-w-0.590614.3.v2 → events.out.tfevents.1626318482.t1v-n-b95d739e-w-0.622701.3.v2 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:69d0b0c11510581415e3ad84919fcc5857dd72e276dfa98d90a601a31995e9d7
3
- size 4897718
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0c8833b4f4649f58ab0d01d47c772f8c05080f371d5f9d57e7134d997e944a1
3
+ size 157187
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91922c046bb159618da797c8c5076e9684aa45c6ded263ff8c60dab3cb008059
3
+ size 498796983
flax_to_torch.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaForMaskedLM, AutoTokenizer
2
+
3
+ model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
4
+ model.save_pretrained("./")
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained("./")
7
+ tokenizer.save_pretrained("./")
nohup.out CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f2645f6739234c77a54e6320ab13a6dcdd86e2decd09fe90e309506975ad0b0b
3
- size 4470375
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99c46650710b372548e97ab2d4a123983e2b495c3ceb094847f500b4ac3a64f7
3
+ size 193918
run.sh CHANGED
@@ -1,5 +1,6 @@
1
  #!/usr/bin/env bash
2
  python3 run_mlm_flax.py \
 
3
  --output_dir="./" \
4
  --model_type="roberta" \
5
  --config_name="./" \
 
1
  #!/usr/bin/env bash
2
  python3 run_mlm_flax.py \
3
+ --model_name_or_path="flax_model.msgpack" \
4
  --output_dir="./" \
5
  --model_type="roberta" \
6
  --config_name="./" \
run_mlm_flax.py CHANGED
@@ -56,6 +56,24 @@ from transformers import (
56
  )
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
60
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61
 
@@ -156,7 +174,7 @@ class DataTrainingArguments:
156
  metadata={"help": "Overwrite the cached training and evaluation sets"},
157
  )
158
  validation_split_percentage: Optional[int] = field(
159
- default=10,
160
  metadata={
161
  "help": "The percentage of the train set used as validation set in case there's no validation split"
162
  },
@@ -314,7 +332,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
314
  return batch_idx
315
 
316
 
317
- def write_train_metric(summary_writer, train_metrics, train_time, step):
318
  summary_writer.scalar("train_time", train_time, step)
319
 
320
  train_metrics = get_metrics(train_metrics)
@@ -323,8 +341,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
323
  for i, val in enumerate(vals):
324
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
325
 
326
-
327
- def write_eval_metric(summary_writer, eval_metrics, step):
328
  for metric_name, value in eval_metrics.items():
329
  summary_writer.scalar(f"eval_{metric_name}", value, step)
330
 
@@ -366,6 +382,10 @@ if __name__ == "__main__":
366
 
367
  # Log on each process the small summary:
368
  logger = logging.getLogger(__name__)
 
 
 
 
369
 
370
  # Set the verbosity to info of the Transformers logger (on main process only):
371
  logger.info(f"Training/evaluation parameters {training_args}")
@@ -557,22 +577,8 @@ if __name__ == "__main__":
557
  )
558
 
559
  # Enable tensorboard only on the master node
560
- has_tensorboard = is_tensorboard_available()
561
  if has_tensorboard and jax.process_index() == 0:
562
- try:
563
- from flax.metrics.tensorboard import SummaryWriter
564
-
565
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
566
- except ImportError as ie:
567
- has_tensorboard = False
568
- logger.warning(
569
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
570
- )
571
- else:
572
- logger.warning(
573
- "Unable to display metrics through TensorBoard because the package is not installed: "
574
- "Please run pip install tensorboard to enable."
575
- )
576
 
577
  # Data collator
578
  # This one will take care of randomly masking the tokens.
@@ -584,17 +590,9 @@ if __name__ == "__main__":
584
  rng = jax.random.PRNGKey(training_args.seed)
585
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
586
 
587
- if model_args.model_name_or_path:
588
- model = FlaxAutoModelForMaskedLM.from_pretrained(
589
- model_args.model_name_or_path,
590
- config=config,
591
- seed=training_args.seed,
592
- dtype=getattr(jnp, model_args.dtype),
593
- )
594
- else:
595
- model = FlaxAutoModelForMaskedLM.from_config(
596
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
597
- )
598
 
599
  # Store some constant
600
  num_epochs = int(training_args.num_train_epochs)
@@ -636,23 +634,18 @@ if __name__ == "__main__":
636
  return traverse_util.unflatten_dict(flat_mask)
637
 
638
  # create adam optimizer
639
- if training_args.adafactor:
640
- # We use the default parameters here to initialize adafactor,
641
- # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
642
- optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn,)
643
- else:
644
- optimizer = optax.adamw(
645
- learning_rate=linear_decay_lr_schedule_fn,
646
- b1=training_args.adam_beta1,
647
- b2=training_args.adam_beta2,
648
- eps=training_args.adam_epsilon,
649
- weight_decay=training_args.weight_decay,
650
- mask=decay_mask_fn,
651
- )
652
 
653
  # Setup train state
654
  state = train_state.TrainState.create(
655
- apply_fn=model.__call__, params=model.params, tx=optimizer
656
  )
657
 
658
  # Define gradient update step fn
@@ -742,7 +735,7 @@ if __name__ == "__main__":
742
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
743
 
744
  # Gather the indexes for creating the batch and do a training step
745
- for step, batch_idx in enumerate(
746
  tqdm(train_batch_idx, desc="Training...", position=1)
747
  ):
748
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
@@ -755,69 +748,52 @@ if __name__ == "__main__":
755
  )
756
  train_metrics.append(train_metric)
757
 
758
- cur_step = epoch * (num_train_samples // train_batch_size) + step
759
-
760
- if cur_step % training_args.logging_steps == 0 and cur_step > 0:
761
- # Save metrics
762
- train_metric = jax_utils.unreplicate(train_metric)
763
- train_time += time.time() - train_start
764
- if has_tensorboard and jax.process_index() == 0:
765
- write_train_metric(
766
- summary_writer, train_metrics, train_time, cur_step
767
- )
768
-
769
- epochs.write(
770
- f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
771
- )
772
-
773
- train_metrics = []
774
-
775
- if cur_step % training_args.eval_steps == 0 and cur_step > 0:
776
- # ======================== Evaluating ==============================
777
- num_eval_samples = len(tokenized_datasets["test"])
778
- eval_samples_idx = jnp.arange(num_eval_samples)
779
- eval_batch_idx = generate_batch_splits(
780
- eval_samples_idx, eval_batch_size
781
- )
782
-
783
- eval_metrics = []
784
- for i, batch_idx in enumerate(
785
- tqdm(eval_batch_idx, desc="Evaluating ...", position=2)
786
- ):
787
- samples = [
788
- tokenized_datasets["test"][int(idx)] for idx in batch_idx
789
- ]
790
- model_inputs = data_collator(samples, pad_to_multiple_of=16)
791
-
792
- # Model forward
793
- model_inputs = shard(model_inputs.data)
794
- metrics = p_eval_step(state.params, model_inputs)
795
- eval_metrics.append(metrics)
796
-
797
- # normalize eval metrics
798
- eval_metrics = get_metrics(eval_metrics)
799
- eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
800
- eval_normalizer = eval_metrics.pop("normalizer")
801
- eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
802
-
803
- # Update progress bar
804
- epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
805
-
806
- # Save metrics
807
- if has_tensorboard and jax.process_index() == 0:
808
- cur_step = epoch * (
809
- len(tokenized_datasets["train"]) // train_batch_size
810
- )
811
- write_eval_metric(summary_writer, eval_metrics, cur_step)
812
-
813
- if cur_step % training_args.save_steps == 0 and cur_step > 0:
814
- # save checkpoint after each epoch and push checkpoint to the hub
815
- if jax.process_index() == 0:
816
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
817
- model.save_pretrained(
818
- training_args.output_dir,
819
- params=params,
820
- push_to_hub=training_args.push_to_hub,
821
- commit_message=f"Saving weights and logs of step {cur_step}",
822
- )
823
 
 
56
  )
57
 
58
 
59
+ # Cache the result
60
+ has_tensorboard = is_tensorboard_available()
61
+ if has_tensorboard:
62
+ try:
63
+ from flax.metrics.tensorboard import SummaryWriter
64
+ except ImportError as ie:
65
+ has_tensorboard = False
66
+ print(
67
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
68
+ )
69
+
70
+ else:
71
+ print(
72
+ "Unable to display metrics through TensorBoard because the package is not installed: "
73
+ "Please run pip install tensorboard to enable."
74
+ )
75
+
76
+
77
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
78
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
79
 
 
174
  metadata={"help": "Overwrite the cached training and evaluation sets"},
175
  )
176
  validation_split_percentage: Optional[int] = field(
177
+ default=5,
178
  metadata={
179
  "help": "The percentage of the train set used as validation set in case there's no validation split"
180
  },
 
332
  return batch_idx
333
 
334
 
335
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
336
  summary_writer.scalar("train_time", train_time, step)
337
 
338
  train_metrics = get_metrics(train_metrics)
 
341
  for i, val in enumerate(vals):
342
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
343
 
 
 
344
  for metric_name, value in eval_metrics.items():
345
  summary_writer.scalar(f"eval_{metric_name}", value, step)
346
 
 
382
 
383
  # Log on each process the small summary:
384
  logger = logging.getLogger(__name__)
385
+ logger.warning(
386
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
387
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
388
+ )
389
 
390
  # Set the verbosity to info of the Transformers logger (on main process only):
391
  logger.info(f"Training/evaluation parameters {training_args}")
 
577
  )
578
 
579
  # Enable tensorboard only on the master node
 
580
  if has_tensorboard and jax.process_index() == 0:
581
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
  # Data collator
584
  # This one will take care of randomly masking the tokens.
 
590
  rng = jax.random.PRNGKey(training_args.seed)
591
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
592
 
593
+ model = FlaxAutoModelForMaskedLM.from_config(
594
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
595
+ )
 
 
 
 
 
 
 
 
596
 
597
  # Store some constant
598
  num_epochs = int(training_args.num_train_epochs)
 
634
  return traverse_util.unflatten_dict(flat_mask)
635
 
636
  # create adam optimizer
637
+ adamw = optax.adamw(
638
+ learning_rate=linear_decay_lr_schedule_fn,
639
+ b1=training_args.adam_beta1,
640
+ b2=training_args.adam_beta2,
641
+ eps=1e-8,
642
+ weight_decay=training_args.weight_decay,
643
+ mask=decay_mask_fn,
644
+ )
 
 
 
 
 
645
 
646
  # Setup train state
647
  state = train_state.TrainState.create(
648
+ apply_fn=model.__call__, params=model.params, tx=adamw
649
  )
650
 
651
  # Define gradient update step fn
 
735
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
736
 
737
  # Gather the indexes for creating the batch and do a training step
738
+ for i, batch_idx in enumerate(
739
  tqdm(train_batch_idx, desc="Training...", position=1)
740
  ):
741
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
 
748
  )
749
  train_metrics.append(train_metric)
750
 
751
+ train_time += time.time() - train_start
752
+
753
+ epochs.write(
754
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
755
+ )
756
+
757
+ # ======================== Evaluating ==============================
758
+ num_eval_samples = len(tokenized_datasets["test"])
759
+ eval_samples_idx = jnp.arange(num_eval_samples)
760
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
761
+
762
+ eval_metrics = []
763
+ for i, batch_idx in enumerate(
764
+ tqdm(eval_batch_idx, desc="Evaluating ...", position=2)
765
+ ):
766
+ samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx]
767
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
768
+
769
+ # Model forward
770
+ model_inputs = shard(model_inputs.data)
771
+ metrics = p_eval_step(state.params, model_inputs)
772
+ eval_metrics.append(metrics)
773
+
774
+ # normalize eval metrics
775
+ eval_metrics = get_metrics(eval_metrics)
776
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
777
+ eval_normalizer = eval_metrics.pop("normalizer")
778
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
779
+
780
+ # Update progress bar
781
+ epochs.desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
782
+
783
+ # Save metrics
784
+ if has_tensorboard and jax.process_index() == 0:
785
+ cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
786
+ write_metric(
787
+ summary_writer, train_metrics, eval_metrics, train_time, cur_step
788
+ )
789
+
790
+ # save checkpoint after each epoch and push checkpoint to the hub
791
+ if jax.process_index() == 0:
792
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
793
+ model.save_pretrained(
794
+ training_args.output_dir,
795
+ params=params,
796
+ push_to_hub=training_args.push_to_hub,
797
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
798
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799