boris committed
Commit: 07a6f9a
Parent(s): 0199604

feat: scan layers + gradient checkpointing (#161)

* scan layers for faster compilation
* support gradient checkpointing
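
The two changes work together: each transformer block is wrapped in `nn.remat` (gradient checkpointing) and the stack of identical blocks is folded into a single `nn.scan`, so XLA traces and compiles one block instead of N. A minimal, self-contained Flax sketch of the pattern (illustrative `Block`/`Stack` names and shapes, not the repo's code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    """One residual block; every layer in the stack has the same structure."""

    d_model: int

    @nn.compact
    def __call__(self, x, _):
        # (carry, x) -> (carry, y) is the signature nn.scan expects
        x = x + nn.Dense(self.d_model)(nn.LayerNorm()(x))
        return x, None


class Stack(nn.Module):
    d_model: int
    n_layers: int

    @nn.compact
    def __call__(self, x):
        # gradient checkpointing: recompute block activations in the backward pass
        block = nn.remat(Block, prevent_cse=False)
        # scan over identical blocks: params gain a leading axis of size n_layers
        # and the block body is compiled only once
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=self.n_layers,
        )(self.d_model)(x, None)
        return x


x = jnp.ones((2, 16, 64))
model = Stack(d_model=64, n_layers=4)
params = model.init(jax.random.PRNGKey(0), x)
print(model.apply(params, x).shape)  # (2, 16, 64)
```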

src/dalle_mini/model/configuration.py CHANGED
@@ -51,7 +51,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         activation_dropout=0.0,
         init_std=0.02,
         scale_embedding=False,
-        gradient_checkpointing=False,
+        gradient_checkpointing=True,
+        use_scan=None,
         use_cache=True,
         is_encoder_decoder=True,
         forced_eos_token_id=None,
@@ -59,7 +60,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         do_sample=True,
         # transformer variants
         use_bias=False,  # use bias in attention and dense layers (except for lm_head)
-        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
+        ln_type="rmsnorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
         use_head_scale=False,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
@@ -67,7 +68,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         use_absolute_position_embeddings=True,  # default
         use_swin_position_embeddings=False,  # used in Swin v1/v2
         use_deepnet_scaling=False,  # used in Deepnet
-        use_glu=False,  # "GLU Variants Improve Transformer"
+        use_glu=True,  # "GLU Variants Improve Transformer"
         use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
         use_final_ln_encoder=True,  # final layer normalization in encoder
@@ -136,6 +137,11 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.init_std = init_std
         self.use_cache = use_cache
         self.gradient_checkpointing = gradient_checkpointing
+        # all layers are the same in most configurations
+        self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
+        assert not (
+            self.use_scan and ln_positions == "swinv2"
+        ), "scan cannot be used with 'swinv2'"
         self.scale_embedding = (
             scale_embedding  # scale factor will be sqrt(d_model) if True
         )
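
Besides enabling gradient checkpointing by default, the defaults for `ln_type` and `use_glu` also change, and `use_scan` is inferred from `ln_positions` when not set. A hedged restatement of that default logic as a standalone helper (`resolve_use_scan` is illustrative, not part of the repo):

```python
def resolve_use_scan(use_scan, ln_positions):
    # "swinv2" adds an extra norm every 6 layers, so its layers are not
    # identical and cannot be folded into a single scanned block
    if use_scan is None:
        use_scan = ln_positions != "swinv2"
    assert not (use_scan and ln_positions == "swinv2"), "scan cannot be used with 'swinv2'"
    return use_scan


assert resolve_use_scan(None, "normformer") is True
assert resolve_use_scan(None, "swinv2") is False
```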
src/dalle_mini/model/modeling.py CHANGED
@@ -619,6 +619,9 @@ class FlaxBartEncoderLayer(nn.Module):
         deterministic: bool = True,
     ) -> Tuple[jnp.ndarray]:

+        if self.config.use_scan:
+            hidden_states = hidden_states[0]
+
         res_gain = (
             deepnet_gain["encoder"]["alpha"](self.config)
             if self.config.use_deepnet_scaling
@@ -679,12 +682,8 @@ class FlaxBartEncoderLayer(nn.Module):
             )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["postln"]:
-            use_scale = (
-                self.use_scale
-                or self.config.ln_positions == "postln"
-                or self.config.force_ln_scale
-            )
+        if self.add_norm:
+            use_scale = self.use_scale or self.config.force_ln_scale
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -697,6 +696,9 @@ class FlaxBartEncoderLayer(nn.Module):
         if output_attentions:
             outputs += (attn_weights,)

+        if self.config.use_scan:
+            outputs = (outputs, None)
+
         return outputs


@@ -710,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
     config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32
     add_norm: bool = False
-    use_scale: bool = False
+    use_scale: bool = True

     @nn.compact
     def __call__(
@@ -724,6 +726,9 @@ class FlaxBartDecoderLayer(nn.Module):
         deterministic: bool = True,
     ) -> Tuple[jnp.ndarray]:

+        if self.config.use_scan:
+            hidden_states = hidden_states[0]
+
         res_gain = (
             deepnet_gain["decoder"]["alpha"](self.config)
             if self.config.use_deepnet_scaling
@@ -831,12 +836,8 @@ class FlaxBartDecoderLayer(nn.Module):
             )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["postln"]:
-            use_scale = (
-                self.use_scale
-                or self.config.ln_positions == "postln"
-                or self.config.force_ln_scale
-            )
+        if self.add_norm:
+            use_scale = self.use_scale or self.config.force_ln_scale
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -849,6 +850,9 @@ class FlaxBartDecoderLayer(nn.Module):
         if output_attentions:
             outputs += (attn_weights, cross_attn_weights)

+        if self.config.use_scan:
+            outputs = (outputs, None)
+
         return outputs


@@ -876,35 +880,80 @@ class FlaxBartEncoderLayerCollection(nn.Module):

         n_layers = self.config.encoder_layers
         layer = (
-            remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
+            remat(
+                FlaxBartEncoderLayer,
+                static_argnums=(2, 3),
+                prevent_cse=not self.config.use_scan,
+            )
             if self.config.gradient_checkpointing
             else FlaxBartEncoderLayer
         )
-        for i in range(n_layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            # final layernorm on the output of the last layer
-            # or every 6 layers for Swin v2
-            add_norm = (
-                self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
-            ) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
-            # we don't need to scale the norm for the last layer
-            use_scale = i != n_layers - 1
-            layer_outputs = layer(
-                self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
+
+        if self.config.use_scan:
+            # all blocks are the same so we use nn.scan
+            assert not output_attentions, "cannot scan with output_attentions"
+            assert not output_hidden_states, "cannot scan with output_hidden_states"
+            hidden_states = (hidden_states,)
+            # we use a scale on all norms (even last layer) to allow scanning
+            hidden_states, _ = nn.scan(
+                layer,
+                variable_axes={"params": 0},
+                split_rngs={"params": True, "dropout": True},
+                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
+                length=n_layers,
+            )(
+                self.config,
+                dtype=self.dtype,
+                add_norm=self.config.ln_positions == "postln",
+                name="FlaxBartEncoderLayers",
             )(
                 hidden_states,
                 attention_mask,
                 output_attentions,
                 deterministic,
             )
-            hidden_states = layer_outputs[0]
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
+            hidden_states = hidden_states[0]
+        else:
+            for i in range(n_layers):
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+                # final layernorm on the output of the last layer
+                # or every 6 layers for Swin v2
+                add_norm = self.config.ln_positions == "postln" or (
+                    self.config.ln_positions == "swinv2"
+                    and ((i + 1) % 6 == 0)
+                    and (i != n_layers - 1)
+                )
+                # we don't need to scale the norm for the last layer
+                use_scale = i != n_layers - 1
+                layer_outputs = layer(
+                    self.config,
+                    dtype=self.dtype,
+                    add_norm=add_norm,
+                    use_scale=use_scale,
+                    name=f"FlaxBartEncoderLayer_{i}",
+                )(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                    deterministic,
+                )
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)

-        # add hidden states from the last layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            # add hidden states from the last layer
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+        # postln is already applied in every layer
+        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(hidden_states)

         outputs = [
             hidden_states,
@@ -953,22 +1002,39 @@ class FlaxBartDecoderLayerCollection(nn.Module):

         n_layers = self.config.decoder_layers
         layer = (
-            remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
+            remat(
+                FlaxBartDecoderLayer,
+                static_argnums=(4, 5, 6),
+                prevent_cse=not self.config.use_scan,
+            )
             if self.config.gradient_checkpointing
             else FlaxBartDecoderLayer
         )
-        for i in range(n_layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            # final layernorm on the output of the last layer
-            # or every 6 layers for Swin v2
-            add_norm = (
-                self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
-            ) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
-            # we don't need to scale the norm for the last layer
-            use_scale = i != n_layers - 1
-            layer_outputs = layer(
-                self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
+
+        if self.config.use_scan:
+            # all blocks are the same so we use nn.scan
+            assert not output_attentions, "cannot scan with output_attentions"
+            assert not output_hidden_states, "cannot scan with output_hidden_states"
+            hidden_states = (hidden_states,)
+            # we use a scale on all norms (even last layer) to allow scanning
+            hidden_states, _ = nn.scan(
+                layer,
+                variable_axes={"params": 0},
+                split_rngs={"params": True, "dropout": True},
+                in_axes=(
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                    nn.broadcast,
+                ),
+                length=n_layers,
+            )(
+                self.config,
+                dtype=self.dtype,
+                add_norm=self.config.ln_positions == "postln",
+                name="FlaxBartEncoderLayers",
             )(
                 hidden_states,
                 attention_mask,
@@ -978,17 +1044,56 @@ class FlaxBartDecoderLayerCollection(nn.Module):
                 output_attentions,
                 deterministic,
             )
+            hidden_states = hidden_states[0]

-            hidden_states = layer_outputs[0]
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
+        else:
+            for i in range(n_layers):
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+                # final layernorm on the output of the last layer
+                # or every 6 layers for Swin v2
+                add_norm = self.config.ln_positions == "postln" or (
+                    self.config.ln_positions == "swinv2"
+                    and ((i + 1) % 6 == 0)
+                    and (i != n_layers - 1)
+                )
+                # we don't need to scale the norm for the last layer
+                use_scale = i != n_layers - 1
+                layer_outputs = layer(
+                    self.config,
+                    dtype=self.dtype,
+                    add_norm=add_norm,
+                    use_scale=use_scale,
+                    name=f"FlaxBartDecoderLayer_{i}",
+                )(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    init_cache,
+                    output_attentions,
+                    deterministic,
+                )
+
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)

-            if encoder_hidden_states is not None:
-                all_cross_attentions += (layer_outputs[2],)
+            # add hidden states from the last decoder layer
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
+        # postln is already applied in every layer
+        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )(hidden_states)

         outputs = [
             hidden_states,
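
The practical consequence of `nn.scan(..., variable_axes={"params": 0})` above is that the parameters of all scanned layers are stored stacked along a new leading axis (one slice per layer) under a single module name such as `FlaxBartEncoderLayers`; the partitioning and training changes below exist to handle that layout. A standalone sketch of the effect, assuming Flax's default parameter naming (illustrative modules, not the repo's):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    @nn.compact
    def __call__(self, x, _):
        return nn.Dense(8)(x), None


class Stack(nn.Module):
    n_layers: int = 4

    @nn.compact
    def __call__(self, x):
        x, _ = nn.scan(
            Block,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=self.n_layers,
        )(name="Layers")(x, None)
        return x


params = Stack().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
# every leaf under "Layers" carries a leading layer axis of size n_layers
print(params["params"]["Layers"]["Dense_0"]["kernel"].shape)  # (4, 8, 8)
```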
src/dalle_mini/model/partitions.py CHANGED
@@ -55,7 +55,7 @@ def _get_partition_rules():
     ]


-def set_partitions(in_dict):
+def set_partitions(in_dict, use_scan):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
@@ -63,5 +63,14 @@ def set_partitions(in_dict):
    for k, v in result.items():
        if v == _unmatched:
            print(f"Unmatched -> {k}")
+    l = list(result.keys())
+    if use_scan:
+        # add None dimension to scanned layers
+        result = {
+            k: (P(*(None,) + v) if v is not None else None)
+            if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
+            else v
+            for k, v in result.items()
+        }
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
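
The comprehension above prepends a `None` dimension to the spec of every parameter that lives under a scanned layer stack, so the new leading layer axis stays unpartitioned while the original rule still applies to the remaining axes. A small hedged example of that rewrite, assuming a model-parallel mesh axis named `"mp"` and using the repo's `jax.experimental.PartitionSpec` import (newer JAX moves this to `jax.sharding`):

```python
from jax.experimental import PartitionSpec as P

# rule written for a single layer's (in_features, out_features) kernel
per_layer_spec = P("mp", None)

# scanned parameters gain a leading layer axis, so prepend None to the spec
stacked_spec = P(*(None,) + per_layer_spec)
print(stacked_spec)  # PartitionSpec(None, 'mp', None)
```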
tools/train/config/mega/config.json CHANGED
@@ -7,14 +7,14 @@
   "decoder_attention_heads": 32,
   "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers": 25,
+  "decoder_layers": 26,
   "decoder_start_token_id": 16384,
   "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
   "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers": 25,
+  "encoder_layers": 26,
   "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
   "force_ln_scale": false,
tools/train/train.py CHANGED
@@ -42,6 +42,7 @@ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import onehot
+from jax import ShapeDtypeStruct
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.compilation_cache import compilation_cache as cc
 from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -531,6 +532,54 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen

+    def apply_gradients(self, *, grads, **kwargs):
+        params = self.unscan(self.params)
+        updates, new_opt_state = self.tx.update(
+            self.unscan(grads), self.opt_state, params
+        )
+        params = optax.apply_updates(params, updates)
+        return self.replace(
+            step=self.step + 1,
+            params=self.rescan(params),
+            opt_state=new_opt_state,
+            **kwargs,
+        )
+
+    @classmethod
+    def create(cls, *, apply_fn, params, tx, **kwargs):
+        opt_state = tx.init(cls.unscan(params))
+        return cls(
+            step=0,
+            apply_fn=apply_fn,
+            params=params,
+            tx=tx,
+            opt_state=opt_state,
+            **kwargs,
+        )
+
+    @staticmethod
+    def unscan(params):
+        params = unfreeze(params)
+        for l in ["encoder", "decoder"]:
+            params["model"][l]["layers"] = jax.tree_map(
+                lambda x: {f"{i}": x[i] for i in range(len(x))},
+                params["model"][l]["layers"],
+            )
+        params = freeze(params)
+        return params
+
+    @staticmethod
+    def rescan(params):
+        params = unfreeze(params)
+        for l in ["encoder", "decoder"]:
+            params["model"][l]["layers"] = jax.tree_map(
+                lambda x: jnp.stack([x[f"{i}"] for i in range(len(x))]),
+                params["model"][l]["layers"],
+                is_leaf=lambda x: "0" in x,
+            )
+        params = freeze(params)
+        return params
+

 def main():
     # See all possible arguments by passing the --help flag to this script.
@@ -618,7 +667,7 @@ def main():
     model_metadata = model_args.get_metadata()

     # get PartitionSpec for model params (required to be a dict)
-    param_spec = set_partitions(model.params)
+    param_spec = set_partitions(model.params, model.config.use_scan)

     # convert params to frozen dict
     model._params = freeze(model.params)
@@ -743,6 +792,23 @@ def main():

     learning_rate_fn = create_learning_rate_fn()

+    # reshape params to split scanned layers for optimizers
+    if model.config.use_scan:
+        params_struct = unfreeze(model.params)
+        for l in ["encoder", "decoder"]:
+            params_struct["model"][l]["layers"] = jax.tree_map(
+                lambda x: {
+                    f"{i}": ShapeDtypeStruct(shape=x.shape[1:], dtype=x.dtype)
+                    for i in range(len(x))
+                },
+                params_struct["model"][l]["layers"],
+            )
+        params_struct = freeze(params_struct)
+
+    else:
+        params_struct = model.params
+    opt_param_spec = set_partitions(params_struct, False)
+
     # create adam optimizer
     if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
@@ -795,10 +861,12 @@ def main():
         )
         # get the real optimizer and helper functions
         update_fn = optimizer.update
-        optimizer = optimizer.init(model.params)
+
+        optimizer = optimizer.init(params_struct)
         opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
             optimizer.pspec_fn, optimizer.shape_and_dtype_fn
         )
+
         optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)

     elif training_args.optim == "adam":
@@ -819,7 +887,7 @@ def main():
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape(param_spec):
         # get opt_state shape without actual init
-        opt_state_shape = jax.eval_shape(optimizer.init, model.params)
+        opt_state_shape = jax.eval_shape(optimizer.init, params_struct)

         if training_args.optim == "adam":

@@ -844,7 +912,7 @@ def main():

         elif training_args.optim == "distributed_shampoo":
             opt_state_spec = opt_fn.pspec_fn(
-                params=model.params,
+                params=params_struct,
                 params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "dp", None),
             )
@@ -852,7 +920,7 @@ def main():
             raise NotImplementedError
         return opt_state_spec, opt_state_shape

-    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
+    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(opt_param_spec)

     # create a mesh
     mesh_shape = (training_args.dp_devices, training_args.mp_devices)
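
The new `TrainState.unscan`/`rescan` helpers bridge two layouts: the scanned model keeps each layer stack as one array with a leading layer axis, while the optimizer state and `opt_param_spec` are built from a per-layer dict of unstacked arrays. A toy round trip of that idea with plain `jax.tree_map` (hypothetical shapes, not the repo's actual parameter tree):

```python
import jax
import jax.numpy as jnp

# stacked layout, as produced by nn.scan: leading axis = number of layers
stacked = {"kernel": jnp.zeros((26, 512, 512)), "bias": jnp.zeros((26, 512))}

# "unscan": split the leading axis into a dict of per-layer arrays
unscanned = jax.tree_map(lambda x: {f"{i}": x[i] for i in range(len(x))}, stacked)
print(unscanned["kernel"]["0"].shape)  # (512, 512)

# "rescan": stack the per-layer entries back along a new leading axis
rescanned = jax.tree_map(
    lambda d: jnp.stack([d[f"{i}"] for i in range(len(d))]),
    unscanned,
    is_leaf=lambda x: isinstance(x, dict) and "0" in x,
)
print(rescanned["kernel"].shape)  # (26, 512, 512)
```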