LeroyDyer committed on
Commit
18010a8
·
verified ·
1 Parent(s): f99fffe

Upload 3 files

Browse files
Files changed (2) hide show
  1. configuration_mistral.py +147 -6
  2. modeling_mistral.py +396 -366
configuration_mistral.py CHANGED
@@ -12,7 +12,7 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """ Mistral model configuration"""
16
 
17
  from ...configuration_utils import PretrainedConfig
18
  from ...utils import logging
@@ -20,11 +20,6 @@ from ...utils import logging
20
 
21
  logger = logging.get_logger(__name__)
22
 
23
- MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
- "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
25
- "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
26
- }
27
-
28
 
29
  class MistralConfig(PretrainedConfig):
30
  r"""
@@ -163,6 +158,152 @@ class MistralConfig(PretrainedConfig):
163
  self.use_complex_talk_head = use_complex_talk_head
164
  self.use_weighted_talk_head = use_weighted_talk_head
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  super().__init__(
167
  pad_token_id=pad_token_id,
168
  bos_token_id=bos_token_id,
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """Mistral model configuration"""
16
 
17
  from ...configuration_utils import PretrainedConfig
18
  from ...utils import logging
 
20
 
21
  logger = logging.get_logger(__name__)
22
 
 
 
 
 
 
23
 
24
  class MistralConfig(PretrainedConfig):
25
  r"""
 
158
  self.use_complex_talk_head = use_complex_talk_head
159
  self.use_weighted_talk_head = use_weighted_talk_head
160
 
161
+ super().__init__(
162
+ pad_token_id=pad_token_id,
163
+ bos_token_id=bos_token_id,
164
+ eos_token_id=eos_token_id,
165
+ tie_word_embeddings=tie_word_embeddings,
166
+ **kwargs,
167
+ )
168
+ class MistralStarConfig(PretrainedConfig):
169
+ r"""
170
+ This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
171
+ Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
172
+ with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
173
+
174
+ [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
175
+ [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
176
+
177
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
178
+ documentation from [`PretrainedConfig`] for more information.
179
+
180
+
181
+ Args:
182
+ vocab_size (`int`, *optional*, defaults to 32000):
183
+ Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
184
+ `input_ids` passed when calling [`MistralModel`]
185
+ hidden_size (`int`, *optional*, defaults to 4096):
186
+ Dimension of the hidden representations.
187
+ intermediate_size (`int`, *optional*, defaults to 14336):
188
+ Dimension of the MLP representations.
189
+ num_hidden_layers (`int`, *optional*, defaults to 32):
190
+ Number of hidden layers in the Transformer encoder.
191
+ num_attention_heads (`int`, *optional*, defaults to 32):
192
+ Number of attention heads for each attention layer in the Transformer encoder.
193
+ num_key_value_heads (`int`, *optional*, defaults to 8):
194
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
195
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
196
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
197
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
198
+ by mean-pooling all the original heads within that group. For more details, check out [this
199
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
200
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
201
+ The non-linear activation function (function or string) in the decoder.
202
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
203
+ The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
204
+ allows sequence of up to 4096*32 tokens.
205
+ initializer_range (`float`, *optional*, defaults to 0.02):
206
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
207
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
208
+ The epsilon used by the rms normalization layers.
209
+ use_cache (`bool`, *optional*, defaults to `True`):
210
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
211
+ relevant if `config.is_decoder=True`.
212
+ pad_token_id (`int`, *optional*):
213
+ The id of the padding token.
214
+ bos_token_id (`int`, *optional*, defaults to 1):
215
+ The id of the "beginning-of-sequence" token.
216
+ eos_token_id (`int`, *optional*, defaults to 2):
217
+ The id of the "end-of-sequence" token.
218
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
219
+ Whether the model's input and output word embeddings should be tied.
220
+ rope_theta (`float`, *optional*, defaults to 10000.0):
221
+ The base period of the RoPE embeddings.
222
+ sliding_window (`int`, *optional*, defaults to 4096):
223
+ Sliding window attention window size. If not specified, will default to `4096`.
224
+ attention_dropout (`float`, *optional*, defaults to 0.0):
225
+ The dropout ratio for the attention probabilities.
226
+
227
+ ```python
228
+ >>> from transformers import MistralModel, MistralConfig
229
+
230
+ >>> # Initializing a Mistral 7B style configuration
231
+ >>> configuration = MistralConfig()
232
+
233
+ >>> # Initializing a model from the Mistral 7B style configuration
234
+ >>> model = MistralModel(configuration)
235
+
236
+ >>> # Accessing the model configuration
237
+ >>> configuration = model.config
238
+ ```"""
239
+
240
+ model_type = "mistralstar"
241
+ keys_to_ignore_at_inference = ["past_key_values"]
242
+
243
+ def __init__(
244
+ self,
245
+ vocab_size=32000,
246
+ hidden_size=4096,
247
+ intermediate_size=14336,
248
+ num_hidden_layers=32,
249
+ num_attention_heads=32,
250
+ num_key_value_heads=8,
251
+ hidden_act="silu",
252
+ max_position_embeddings=4096 * 32,
253
+ initializer_range=0.02,
254
+ rms_norm_eps=1e-6,
255
+ use_cache=True,
256
+ pad_token_id=None,
257
+ bos_token_id=1,
258
+ eos_token_id=2,
259
+ tie_word_embeddings=False,
260
+ rope_theta=10000.0,
261
+ sliding_window=4096,
262
+ attention_dropout=0.0,
263
+ max_thoughts=16,
264
+ thought_length = 10,
265
+ merged_talk_heads=True,
266
+ merged_lm_and_talk_heads=False,
267
+ merged_lm_and_think_heads=True,
268
+ use_concat_talk_head=True,
269
+ use_shallow_think=True,
270
+ use_shallow_talk=False,
271
+ use_complex_think_head=False,
272
+ use_complex_talk_head=True,
273
+ use_weighted_talk_head=True,
274
+ **kwargs,
275
+ ):
276
+ self.vocab_size = vocab_size
277
+ self.max_position_embeddings = max_position_embeddings
278
+ self.hidden_size = hidden_size
279
+ self.intermediate_size = intermediate_size
280
+ self.num_hidden_layers = num_hidden_layers
281
+ self.num_attention_heads = num_attention_heads
282
+ self.sliding_window = sliding_window
283
+
284
+ # for backward compatibility
285
+ if num_key_value_heads is None:
286
+ num_key_value_heads = num_attention_heads
287
+
288
+ self.num_key_value_heads = num_key_value_heads
289
+ self.hidden_act = hidden_act
290
+ self.initializer_range = initializer_range
291
+ self.rms_norm_eps = rms_norm_eps
292
+ self.use_cache = use_cache
293
+ self.rope_theta = rope_theta
294
+ self.attention_dropout = attention_dropout
295
+ self.max_thoughts = max_thoughts
296
+ self.thought_length = thought_length
297
+ self.merged_talk_heads = merged_talk_heads
298
+ self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
299
+ self.merged_lm_and_think_heads = merged_lm_and_think_heads
300
+ self.use_concat_talk_head = use_concat_talk_head
301
+ self.use_shallow_think = use_shallow_think
302
+ self.use_shallow_talk = use_shallow_talk
303
+ self.use_complex_think_head = use_complex_think_head
304
+ self.use_complex_talk_head = use_complex_talk_head
305
+ self.use_weighted_talk_head = use_weighted_talk_head
306
+
307
  super().__init__(
308
  pad_token_id=pad_token_id,
309
  bos_token_id=bos_token_id,
modeling_mistral.py CHANGED
@@ -662,6 +662,267 @@ class MistralPreTrainedModel(PreTrainedModel):
662
  module.weight.data[module.padding_idx].zero_()
663
 
664
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  ############################## LM Heads #################################
667
 
@@ -2016,390 +2277,129 @@ class MixtralSparseMoeBlock(nn.Module):
2016
  """ """
2017
  batch_size, sequence_length, hidden_dim = hidden_states.shape
2018
  if self.training and self.jitter_noise > 0:
2019
- hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
2020
- hidden_states = hidden_states.view(-1, hidden_dim)
2021
- # router_logits: (batch * sequence_length, n_experts)
2022
- router_logits = self.gate(hidden_states)
2023
-
2024
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
2025
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
2026
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
2027
- # we cast back to the input dtype
2028
- routing_weights = routing_weights.to(hidden_states.dtype)
2029
-
2030
- final_hidden_states = torch.zeros(
2031
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
2032
- )
2033
-
2034
- # One hot encode the selected experts to create an expert mask
2035
- # this will be used to easily index which expert is going to be sollicitated
2036
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
2037
-
2038
- # Loop over all available experts in the model and perform the computation on each expert
2039
- for expert_idx in range(self.num_experts):
2040
- expert_layer = self.experts[expert_idx]
2041
- idx, top_x = torch.where(expert_mask[expert_idx])
2042
-
2043
- # Index the correct hidden states and compute the expert hidden state for
2044
- # the current expert. We need to make sure to multiply the output hidden
2045
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2046
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2047
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
2048
-
2049
- # However `index_add_` only support torch tensors for indexing so we'll use
2050
- # the `top_x` tensor here.
2051
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
2052
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
2053
- return final_hidden_states, router_logits
2054
- class MixtralDecoderLayer(nn.Module):
2055
- def __init__(self, config: MixtralConfig, layer_idx: int):
2056
- super().__init__()
2057
- self.hidden_size = config.hidden_size
2058
-
2059
- self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
2060
- self.mlp = MistralMLP(config)
2061
- self.block_sparse_moe = MixtralSparseMoeBlock(config)
2062
- self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2063
- self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2064
-
2065
- def forward(
2066
- self,
2067
- hidden_states: torch.Tensor,
2068
- attention_mask: Optional[torch.Tensor] = None,
2069
- position_ids: Optional[torch.LongTensor] = None,
2070
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
2071
- output_attentions: Optional[bool] = False,
2072
- output_router_logits: Optional[bool] = False,
2073
- use_cache: Optional[bool] = False,
2074
- cache_position: Optional[torch.LongTensor] = None,
2075
- **kwargs,
2076
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
2077
- """
2078
- Args:
2079
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
2080
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
2081
- `(batch, sequence_length)` where padding elements are indicated by 0.
2082
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
2083
- output_attentions (`bool`, *optional*):
2084
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2085
- returned tensors for more detail.
2086
- output_router_logits (`bool`, *optional*):
2087
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
2088
- should not be returned during inference.
2089
- use_cache (`bool`, *optional*):
2090
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
2091
- (see `past_key_values`).
2092
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
2093
- Indices depicting the position of the input sequence tokens in the sequence.
2094
- kwargs (`dict`, *optional*):
2095
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
2096
- into the model
2097
- """
2098
-
2099
- residual = hidden_states
2100
-
2101
- hidden_states = self.input_layernorm(hidden_states)
2102
-
2103
- # Self Attention
2104
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
2105
- hidden_states=hidden_states,
2106
- attention_mask=attention_mask,
2107
- position_ids=position_ids,
2108
- past_key_value=past_key_value,
2109
- output_attentions=output_attentions,
2110
- use_cache=use_cache,
2111
- cache_position=cache_position,
2112
- )
2113
- hidden_states = residual + hidden_states
2114
-
2115
- # Fully Connected
2116
- residual = hidden_states
2117
- hidden_states = self.post_attention_layernorm(hidden_states)
2118
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
2119
- hidden_states = residual + hidden_states
2120
-
2121
- # Fully Connected
2122
- residual = hidden_states
2123
- hidden_states = self.post_attention_layernorm(hidden_states)
2124
- hidden_states = self.mlp(hidden_states)
2125
- hidden_states = residual + hidden_states
2126
-
2127
- outputs = (hidden_states,)
2128
-
2129
- if output_attentions:
2130
- outputs += (self_attn_weights,)
2131
-
2132
- if use_cache:
2133
- outputs += (present_key_value,)
2134
-
2135
- if output_router_logits:
2136
- outputs += (router_logits,)
2137
-
2138
- return outputs
2139
-
2140
- ################################ closed COMPONENTS ################################
2141
-
2142
-
2143
- @add_start_docstrings(
2144
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
2145
- MISTRAL_START_DOCSTRING,
2146
- )
2147
- class MistralModel(MistralPreTrainedModel):
2148
- """
2149
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
2150
-
2151
- Args:
2152
- config: MistralConfig
2153
- """
2154
-
2155
- def __init__(self, config: MistralConfig):
2156
- super().__init__(config)
2157
- self.padding_idx = config.pad_token_id
2158
- self.vocab_size = config.vocab_size
2159
-
2160
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
2161
- self.layers = nn.ModuleList(
2162
- [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
2163
- )
2164
- self._attn_implementation = config._attn_implementation
2165
- self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2166
-
2167
- self.gradient_checkpointing = False
2168
- # Initialize weights and apply final processing
2169
- self.post_init()
2170
-
2171
- def get_input_embeddings(self):
2172
- return self.embed_tokens
2173
-
2174
- def set_input_embeddings(self, value):
2175
- self.embed_tokens = value
2176
-
2177
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
2178
- def forward(
2179
- self,
2180
- input_ids: torch.LongTensor = None,
2181
- attention_mask: Optional[torch.Tensor] = None,
2182
- position_ids: Optional[torch.LongTensor] = None,
2183
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
2184
- inputs_embeds: Optional[torch.FloatTensor] = None,
2185
- use_cache: Optional[bool] = None,
2186
- output_attentions: Optional[bool] = None,
2187
- output_hidden_states: Optional[bool] = None,
2188
- return_dict: Optional[bool] = None,
2189
- cache_position: Optional[torch.LongTensor] = None,
2190
- ) -> Union[Tuple, BaseModelOutputWithPast]:
2191
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2192
- output_hidden_states = (
2193
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2194
- )
2195
- use_cache = use_cache if use_cache is not None else self.config.use_cache
2196
-
2197
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2198
-
2199
- # retrieve input_ids and inputs_embeds
2200
- if (input_ids is None) ^ (inputs_embeds is not None):
2201
- raise ValueError(
2202
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
2203
- )
2204
-
2205
- if self.gradient_checkpointing and self.training and use_cache:
2206
- logger.warning_once(
2207
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
2208
- )
2209
- use_cache = False
2210
-
2211
- if inputs_embeds is None:
2212
- inputs_embeds = self.embed_tokens(input_ids)
2213
-
2214
- return_legacy_cache = False
2215
- if use_cache and not isinstance(past_key_values, Cache):
2216
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
2217
- return_legacy_cache = True
2218
- logger.warning_once(
2219
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
2220
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
2221
- )
2222
-
2223
- if cache_position is None:
2224
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
2225
- cache_position = torch.arange(
2226
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
2227
- )
2228
-
2229
- if position_ids is None:
2230
- position_ids = cache_position.unsqueeze(0)
2231
-
2232
- causal_mask = self._update_causal_mask(
2233
- attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
2234
- )
2235
 
2236
- hidden_states = inputs_embeds
 
 
 
 
2237
 
2238
- # decoder layers
2239
- all_hidden_states = () if output_hidden_states else None
2240
- all_self_attns = () if output_attentions else None
2241
- next_decoder_cache = None
2242
 
2243
- for decoder_layer in self.layers:
2244
- if output_hidden_states:
2245
- all_hidden_states += (hidden_states,)
2246
 
2247
- if self.gradient_checkpointing and self.training:
2248
- layer_outputs = self._gradient_checkpointing_func(
2249
- decoder_layer.__call__,
2250
- hidden_states,
2251
- causal_mask,
2252
- position_ids,
2253
- past_key_values,
2254
- output_attentions,
2255
- use_cache,
2256
- cache_position,
2257
- )
2258
- else:
2259
- layer_outputs = decoder_layer(
2260
- hidden_states,
2261
- attention_mask=causal_mask,
2262
- position_ids=position_ids,
2263
- past_key_value=past_key_values,
2264
- output_attentions=output_attentions,
2265
- use_cache=use_cache,
2266
- cache_position=cache_position,
2267
- )
2268
 
2269
- hidden_states = layer_outputs[0]
 
 
 
 
2270
 
2271
- if use_cache:
2272
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
 
 
 
 
 
 
2273
 
2274
- if output_attentions:
2275
- all_self_attns += (layer_outputs[1],)
 
 
 
2276
 
2277
- hidden_states = self.norm(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2278
 
2279
- # add hidden states from the last decoder layer
2280
- if output_hidden_states:
2281
- all_hidden_states += (hidden_states,)
2282
 
2283
- next_cache = next_decoder_cache if use_cache else None
2284
- if return_legacy_cache:
2285
- next_cache = next_cache.to_legacy_cache()
2286
 
2287
- if not return_dict:
2288
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2289
- return BaseModelOutputWithPast(
2290
- last_hidden_state=hidden_states,
2291
- past_key_values=next_cache,
2292
- hidden_states=all_hidden_states,
2293
- attentions=all_self_attns,
 
 
2294
  )
 
2295
 
2296
- def _update_causal_mask(
2297
- self,
2298
- attention_mask: torch.Tensor,
2299
- input_tensor: torch.Tensor,
2300
- cache_position: torch.Tensor,
2301
- past_key_values: Cache,
2302
- use_cache: bool,
2303
- output_attentions: bool,
2304
- ):
2305
-
2306
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
2307
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
2308
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
2309
 
2310
- if self._attn_implementation == "flash_attention_2":
2311
- if attention_mask is not None and use_cache:
2312
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
2313
- if is_padding_right:
2314
- raise ValueError(
2315
- "You are attempting to perform batched generation with padding_side='right'"
2316
- " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
2317
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
2318
- )
2319
- if attention_mask is not None and 0.0 in attention_mask:
2320
- return attention_mask
2321
- return None
2322
 
2323
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
2324
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
2325
- # to infer the attention mask.
2326
 
2327
- # cache_position must be valid here no matter which cache we use
2328
- past_seen_tokens = cache_position[0] if past_key_values is not None else 0
2329
- using_static_cache = isinstance(past_key_values, StaticCache)
2330
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2331
 
2332
- if (
2333
- self.config._attn_implementation == "sdpa"
2334
- and not (using_static_cache or using_sliding_window_cache)
2335
- and not output_attentions
2336
- ):
2337
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
2338
- attention_mask,
2339
- inputs_embeds=input_tensor,
2340
- past_key_values_length=past_seen_tokens,
2341
- sliding_window=self.config.sliding_window,
2342
- is_training=self.training,
2343
- ):
2344
- return None
2345
 
2346
- dtype, device = input_tensor.dtype, input_tensor.device
2347
- min_dtype = torch.finfo(dtype).min
2348
- sequence_length = input_tensor.shape[1]
2349
- # SlidingWindowCache
2350
- if using_sliding_window_cache:
2351
- target_length = max(sequence_length, self.config.sliding_window)
2352
- # StaticCache
2353
- elif using_static_cache:
2354
- target_length = past_key_values.get_max_length()
2355
- # DynamicCache or no cache
2356
- else:
2357
- target_length = (
2358
- attention_mask.shape[-1]
2359
- if isinstance(attention_mask, torch.Tensor)
2360
- else past_seen_tokens + sequence_length + 1
2361
- )
2362
 
2363
- if attention_mask is not None and attention_mask.dim() == 4:
2364
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
2365
- if attention_mask.max() != 0:
2366
- raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
2367
- causal_mask = attention_mask
2368
- else:
2369
- causal_mask = torch.full(
2370
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
2371
- )
2372
- exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
2373
- if self.config.sliding_window is not None:
2374
- if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
2375
- exclude_mask.bitwise_or_(
2376
- torch.arange(target_length, device=device)
2377
- <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
2378
- )
2379
- causal_mask *= exclude_mask
2380
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
2381
- if attention_mask is not None:
2382
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2383
- if attention_mask.dim() == 2:
2384
- mask_length = attention_mask.shape[-1]
2385
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
2386
- padding_mask = padding_mask == 0
2387
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2388
- padding_mask, min_dtype
2389
- )
2390
 
2391
- if (
2392
- self.config._attn_implementation == "sdpa"
2393
- and attention_mask is not None
2394
- and attention_mask.device.type == "cuda"
2395
- and not output_attentions
2396
- ):
2397
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2398
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2399
- # Details: https://github.com/pytorch/pytorch/issues/110213
2400
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
2401
 
2402
- return causal_mask
2403
 
2404
  ############# Causal LM #################
2405
  class MistralForCausalLM(MistralPreTrainedModel):
@@ -3421,10 +3421,40 @@ class MistralForCausalLM(MistralPreTrainedModel):
3421
  else:
3422
  cur_talk_loss = talk_loss_list[talk_idx]
3423
  log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3424
 
3425
-
3426
- self.n_ahead_talk = n_ahead_talk_to_restore
3427
- self.n_passes = n_passes_to_restore
3428
  return CausalLMOutputWithPast(
3429
  loss=loss if loss is not None else None,
3430
  logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
 
662
  module.weight.data[module.padding_idx].zero_()
663
 
664
 
665
@add_start_docstrings(
    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class MistralModel(MistralPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]

    Args:
        config: MistralConfig
    """

    def __init__(self, config: MistralConfig):
        """Build the token embedding, the stack of decoder layers and the final RMS norm from `config`."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Cached copy of the attention backend name; `_update_causal_mask` reads it for the flash-attention path.
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the embedding, every decoder layer and the final norm.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `use_cache`, `output_attentions`,
        `output_hidden_states` and `return_dict` fall back to the corresponding config values when `None`.
        When `use_cache` is set and `past_key_values` is not a `Cache`, it is converted with
        `DynamicCache.from_legacy_cache` and the returned cache is converted back with `to_legacy_cache()`.
        `cache_position` defaults to a range starting at the number of tokens already in the cache, and
        `position_ids` defaults to `cache_position.unsqueeze(0)`.

        Returns:
            A `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple holding the non-`None`
            values among the last hidden state, the cache, all hidden states and all attention weights.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # The XOR is true when both are missing or both are present.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            # Only emit the deprecation notice when the caller really supplied a legacy cache.
            # A plain first call with `past_key_values=None` must not be told it passed a tuple.
            if past_key_values is not None:
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
                )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            # The output format is kept as before: a non-`Cache` input yields a legacy cache on return.
            return_legacy_cache = True

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                # Recorded before the layer runs, so entry i is the input to layer i.
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # The checkpointing wrapper takes positional arguments only; the order mirrors
                # the keyword call in the other branch.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights in the layer output when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        use_cache: bool,
        output_attentions: bool,
    ):
        """
        Build the mask handed to the decoder layers as `attention_mask`.

        Depending on the attention backend this returns:
          * flash_attention_2: `attention_mask` itself if it contains a 0.0, otherwise `None`;
          * sdpa without a static / sliding-window cache and without `output_attentions`: `None` when
            `AttentionMaskConverter._ignore_causal_mask_sdpa` says the mask can be skipped;
          * otherwise: a 4D additive mask in which blocked positions hold `torch.finfo(dtype).min`
            and allowed positions hold 0.

        Raises:
            ValueError: for flash_attention_2 with `use_cache` when the last mask column does not sum to the
                batch size (treated as right padding), or when a 4D `attention_mask` has a max other than 0.
        """
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self._attn_implementation == "flash_attention_2":
            if attention_mask is not None and use_cache:
                # If any row ends with a 0, the last column sums to less than the batch size.
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.

        # cache_position must be valid here no matter which cache we use
        past_seen_tokens = cache_position[0] if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache
        if using_sliding_window_cache:
            target_length = max(sequence_length, self.config.sliding_window)
        # StaticCache
        elif using_static_cache:
            target_length = past_key_values.get_max_length()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            # Start fully blocked, then zero out every position that may be attended to.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # True where the key index lies after the query's cache position (future tokens).
            exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if self.config.sliding_window is not None:
                if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
                    # Also block keys that are `sliding_window` or more positions behind the query.
                    exclude_mask.bitwise_or_(
                        torch.arange(target_length, device=device)
                        <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
                    )
            # Multiplying by the boolean mask keeps `min_dtype` where excluded and yields 0 elsewhere.
            causal_mask *= exclude_mask
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.dim() == 2:
                    mask_length = attention_mask.shape[-1]
                    # The sum is 0 only where the causal mask allows attention (0) and the 2D mask is 0.
                    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                    padding_mask = padding_mask == 0
                    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                        padding_mask, min_dtype
                    )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
925
+
926
 
927
  ############################## LM Heads #################################
928
 
 
2277
  """ """
2278
  batch_size, sequence_length, hidden_dim = hidden_states.shape
2279
  if self.training and self.jitter_noise > 0:
2280
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
2281
+ hidden_states = hidden_states.view(-1, hidden_dim)
2282
+ # router_logits: (batch * sequence_length, n_experts)
2283
+ router_logits = self.gate(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2284
 
2285
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
2286
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
2287
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
2288
+ # we cast back to the input dtype
2289
+ routing_weights = routing_weights.to(hidden_states.dtype)
2290
 
2291
+ final_hidden_states = torch.zeros(
2292
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
2293
+ )
 
2294
 
2295
+ # One hot encode the selected experts to create an expert mask
2296
+ # this will be used to easily index which expert is going to be sollicitated
2297
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
2298
 
2299
+ # Loop over all available experts in the model and perform the computation on each expert
2300
+ for expert_idx in range(self.num_experts):
2301
+ expert_layer = self.experts[expert_idx]
2302
+ idx, top_x = torch.where(expert_mask[expert_idx])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2303
 
2304
+ # Index the correct hidden states and compute the expert hidden state for
2305
+ # the current expert. We need to make sure to multiply the output hidden
2306
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2307
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2308
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
2309
 
2310
+ # However `index_add_` only support torch tensors for indexing so we'll use
2311
+ # the `top_x` tensor here.
2312
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
2313
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
2314
+ return final_hidden_states, router_logits
2315
class MixtralDecoderLayer(nn.Module):
    """
    Decoder layer made of self-attention, a sparse mixture-of-experts block and a dense MLP,
    each wrapped in a residual connection.

    NOTE(review): both feed-forward stages are normalised by the same `post_attention_layernorm`
    and run one after the other (MoE first, dense MLP second) - confirm this stacking is intended.
    """

    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps
        attention_cls = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]

        self.hidden_size = hidden_size

        # Submodules are created in the same order as before so parameter registration is unchanged.
        self.self_attn = attention_cls(config, layer_idx)
        self.mlp = MistralMLP(config)
        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MistralRMSNorm(hidden_size, eps=norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply attention, the sparse MoE block and the dense MLP to `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): layer input.
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded unchanged to the attention module.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (*optional*): cache object forwarded unchanged to the attention module.
            output_attentions (`bool`, *optional*): append the attention weights to the returned tuple.
            output_router_logits (`bool`, *optional*): append the MoE router logits to the returned tuple.
            use_cache (`bool`, *optional*): append the present key/value cache to the returned tuple.
            cache_position (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            kwargs (`dict`, *optional*): accepted and ignored.

        Returns:
            A tuple starting with the new hidden states, followed - in this order and only when the
            matching flag is truthy - by the attention weights, the present cache and the router logits.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        skip = hidden_states
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = skip + attn_output

        # --- sparse mixture-of-experts sub-block (pre-norm + residual) ---
        skip = hidden_states
        moe_output, router_logits = self.block_sparse_moe(self.post_attention_layernorm(hidden_states))
        hidden_states = skip + moe_output

        # --- dense MLP sub-block (pre-norm + residual) ---
        skip = hidden_states
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = skip + mlp_output

        # Optional extras are appended in a fixed order: attentions, cache, router logits.
        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)
        if output_router_logits:
            extras.append(router_logits)

        return (hidden_states, *extras)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2400
 
2401
+ ################################ closed COMPONENTS ################################
 
 
 
 
 
 
 
 
 
2402
 
 
2403
 
2404
  ############# Causal LM #################
2405
  class MistralForCausalLM(MistralPreTrainedModel):
 
3421
  else:
3422
  cur_talk_loss = talk_loss_list[talk_idx]
3423
  log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
3424
+ if self.training:
3425
+ self.training_steps += 1
3426
+ try:
3427
+ # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
3428
+ if self.wandb_enabled:
3429
+ if self.training_steps % (self.n_tokens_print) == 0 or not self.training:# and "0" in str(loss.device):
3430
+ if not self.training:
3431
+ new_log_dict = {}
3432
+ for key in list(log_dict.keys()):
3433
+ new_log_dict["eval_" + key] = log_dict[key]
3434
+ log_dict = new_log_dict
3435
+ log_dict["training_steps"] = self.training_steps
3436
+ log_dict["batch_size"] = batch_size
3437
+ log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
3438
+ if self.n_ahead > 1:
3439
+ log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
3440
+ else: # There's no overhead for talk tokens if there's no thinking
3441
+ log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
3442
+ # remove all nans
3443
+ for key in list(log_dict.keys()):
3444
+ if log_dict[key] != log_dict[key]:
3445
+ del log_dict[key]
3446
+ if self.training:
3447
+ wandb.log(log_dict)
3448
+ if self.training:
3449
+ self.log_dict = defaultdict(int)
3450
+ else:
3451
+ self.eval_log_dict = defaultdict(int)
3452
+ except Exception as e:
3453
+ pass
3454
 
3455
+ if not self.training:
3456
+ self.n_ahead_talk = n_ahead_talk_to_restore
3457
+ self.n_passes = n_passes_to_restore
3458
  return CausalLMOutputWithPast(
3459
  loss=loss if loss is not None else None,
3460
  logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,