Faisal AlKhateeb committed on
Commit
2f32550
·
1 Parent(s): 68be314

added support for position interpolation

Browse files
Files changed (3) hide show
  1. README.md +29 -0
  2. configuration_btlm.py +36 -0
  3. modeling_btlm.py +14 -3
README.md CHANGED
@@ -162,6 +162,35 @@ Ensure the following muP parameters are passed in your config, otherwise your mo
162
  - `mup_output_alpha: <float>`
163
  - `mup_scale_qk_dot_by_d: true`
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  ## Uses and Limitations
166
 
167
  ### Intended Use
 
162
  - `mup_output_alpha: <float>`
163
  - `mup_scale_qk_dot_by_d: true`
164
 
165
+ ## To extend the context length with Position Interpolation
166
+
167
+ ### During inference (without fine-tuning):
168
+ It's possible to extend the context length to 2x the training context length without degradation in performance using dynamic linear scaling. Dynamic linear scaling divides the ALiBi slopes by a factor of `input_seq_len/train_seq_len` whenever `input_seq_len` is larger than `train_seq_len`; shorter inputs are left unscaled. See our paper [Position Interpolation Improves ALiBi Extrapolation](https://arxiv.org/abs/2310.13017) for details. To enable dynamic linear scaling, update `config.json` as follows:
169
+ ```json
170
+ # update `n_positions` with the maximum context length that will be
171
+ # encountered during inference (e.g. 16384 tokens)
172
+ "n_positions": 16384,
173
+
174
+ # specify `train_seq_len` in `alibi_scaling` parameter
175
+ "alibi_scaling": {
176
+ "type": "linear",
177
+ "train_seq_len": 8192
178
+ }
179
+ ```
180
+
181
+ ### Using fine-tuning + position interpolation:
182
+ Performing fine-tuning with position interpolation can help achieve greater extrapolation lengths. The scaling factor should be fixed to `finetuning_seq_len/train_seq_len`. To enable fixed linear scaling, update `config.json` as follows:
183
+ ```json
184
+ # update `n_positions` with the fine-tuning context length (e.g. 32768 tokens)
185
+ "n_positions": 32768,
186
+
187
+ # specify the scaling `factor` in `alibi_scaling` parameter (e.g. 32768 / 8192 = 4.0)
188
+ "alibi_scaling": {
189
+ "type": "linear",
190
+ "factor": 4.0
191
+ }
192
+ ```
193
+
194
  ## Uses and Limitations
195
 
196
  ### Intended Use
configuration_btlm.py CHANGED
@@ -84,6 +84,12 @@ class BTLMConfig(PretrainedConfig):
84
  mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`):
85
  Scale attention weights by dividing by hidden_size instead of sqrt(hidden_size). Need to set
86
  scale_attn_weights to `True` as well.
 
 
 
 
 
 
87
 
88
  Example:
89
 
@@ -134,6 +140,7 @@ class BTLMConfig(PretrainedConfig):
134
  mup_embeddings_scale=1.0,
135
  mup_output_alpha=1.0,
136
  mup_scale_qk_dot_by_d=False,
 
137
  **kwargs,
138
  ):
139
  self.vocab_size = vocab_size
@@ -162,4 +169,33 @@ class BTLMConfig(PretrainedConfig):
162
  self.mup_output_alpha = mup_output_alpha
163
  self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d
164
 
 
 
 
165
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`):
85
  Scale attention weights by dividing by hidden_size instead of sqrt(hidden_size). Need to set
86
  scale_attn_weights to `True` as well.
87
+ alibi_scaling (`Dict`, *optional*):
88
+ Dictionary containing the scaling configuration for ALiBi embeddings. Currently only supports linear
89
+ scaling strategy. Can specify either the scaling `factor` (must be a float greater than 1) for fixed scaling
90
+ or `train_seq_len` for dynamic scaling on input samples with sequence length > `train_seq_len`. The expected
91
+ formats are `{"type": strategy name, "factor": scaling factor}` or
92
+ `{"type": strategy name, "train_seq_len": training sequence length}`.
93
 
94
  Example:
95
 
 
140
  mup_embeddings_scale=1.0,
141
  mup_output_alpha=1.0,
142
  mup_scale_qk_dot_by_d=False,
143
+ alibi_scaling=None,
144
  **kwargs,
145
  ):
146
  self.vocab_size = vocab_size
 
169
  self.mup_output_alpha = mup_output_alpha
170
  self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d
171
 
172
+ self.alibi_scaling = alibi_scaling
173
+ self._alibi_scaling_validation()
174
+
175
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
176
+
177
def _alibi_scaling_validation(self):
    """
    Validate the `alibi_scaling` configuration.

    Accepted values are `None` (no scaling) or a dictionary with exactly two
    entries: `"type"` (must be `"linear"`) plus either `"factor"` (a float
    > 1.0, fixed scaling) or `"train_seq_len"` (an integer > 1, dynamic
    scaling).

    Raises:
        ValueError: if `alibi_scaling` is not a two-entry dictionary, if its
            `type` is not `"linear"`, if neither `factor` nor `train_seq_len`
            is provided, or if the provided value has the wrong type or range.
    """
    if self.alibi_scaling is None:
        return

    if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2:
        raise ValueError(
            "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, "
            f"got {self.alibi_scaling}"
        )
    alibi_scaling_type = self.alibi_scaling.get("type", None)
    alibi_scaling_factor = self.alibi_scaling.get("factor", None)
    alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None)
    if alibi_scaling_type is None or alibi_scaling_type != "linear":
        raise ValueError(
            f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}"
        )
    # A two-entry dict such as {"type": "linear", "scale": 2.0} (or one whose
    # second value is None) would otherwise pass validation and only fail
    # later, when the consumer looks up the missing scaling key. Reject it here.
    if alibi_scaling_factor is None and alibi_dynamic_scaling is None:
        raise ValueError(
            "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, "
            f"got {self.alibi_scaling}"
        )
    if alibi_scaling_factor is not None:
        if not isinstance(alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0:
            raise ValueError(f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}")
    if alibi_dynamic_scaling is not None:
        if not isinstance(alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1:
            raise ValueError(f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}")
modeling_btlm.py CHANGED
@@ -63,10 +63,11 @@ class SwiGLUActivation(nn.Module):
63
 
64
 
65
  class AlibiPositionEmbeddingLayer(nn.Module):
66
- def __init__(self, num_heads):
67
  super(AlibiPositionEmbeddingLayer, self).__init__()
68
 
69
  self.num_heads = num_heads
 
70
  slopes = torch.tensor(AlibiPositionEmbeddingLayer._get_alibi_slopes(num_heads)).unsqueeze(-1)
71
  self.slopes = nn.parameter.Parameter(slopes, requires_grad=False)
72
 
@@ -84,7 +85,17 @@ class AlibiPositionEmbeddingLayer(nn.Module):
84
  )[None, :]
85
  relative_position = memory_position - context_position
86
  relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
87
- alibi = (self.slopes * -1.0).unsqueeze(1) * relative_position
 
 
 
 
 
 
 
 
 
 
88
  return alibi
89
 
90
  @staticmethod
@@ -766,7 +777,7 @@ class BTLMModel(BTLMPreTrainedModel):
766
  self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
767
 
768
  self.relative_pe = (
769
- AlibiPositionEmbeddingLayer(config.num_attention_heads)
770
  if config.position_embedding_type == "alibi"
771
  else None
772
  )
 
63
 
64
 
65
  class AlibiPositionEmbeddingLayer(nn.Module):
66
+ def __init__(self, num_heads, alibi_scaling=None):
67
  super(AlibiPositionEmbeddingLayer, self).__init__()
68
 
69
  self.num_heads = num_heads
70
+ self.alibi_scaling = alibi_scaling
71
  slopes = torch.tensor(AlibiPositionEmbeddingLayer._get_alibi_slopes(num_heads)).unsqueeze(-1)
72
  self.slopes = nn.parameter.Parameter(slopes, requires_grad=False)
73
 
 
85
  )[None, :]
86
  relative_position = memory_position - context_position
87
  relative_position = torch.abs(relative_position).unsqueeze(0).expand(self.num_heads, -1, -1)
88
+
89
+ if self.alibi_scaling is None:
90
+ scale = 1.0
91
+ elif self.alibi_scaling.get("factor") is not None:
92
+ scale = self.alibi_scaling["factor"]
93
+ elif relative_position.shape[-1] > self.alibi_scaling["train_seq_len"]:
94
+ scale = relative_position.shape[-1] / self.alibi_scaling["train_seq_len"]
95
+ else:
96
+ scale = 1.0
97
+
98
+ alibi = (self.slopes / -scale).unsqueeze(1) * relative_position
99
  return alibi
100
 
101
  @staticmethod
 
777
  self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
778
 
779
  self.relative_pe = (
780
+ AlibiPositionEmbeddingLayer(config.num_attention_heads, config.alibi_scaling)
781
  if config.position_embedding_type == "alibi"
782
  else None
783
  )