SmerkyG committed
Commit e6f93c3 · verified · 1 Parent(s): 07f9d2e

Update modeling_rwkv6qwen2.py

Files changed (1):
  1. modeling_rwkv6qwen2.py (+181 -143)
modeling_rwkv6qwen2.py CHANGED
@@ -29,7 +29,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, StaticCache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -209,7 +209,7 @@ try:
     from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 except ImportError:
     print("Required module is not installed. Please install it using the following commands:")
-    print("pip install -U git+https://github.com/sustcsonglin/flash-linear-attention")
+    print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
     print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
     print("pip install triton>=2.2.0")

@@ -230,7 +230,6 @@ class RWKV6Attention(nn.Module):
         self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.is_causal = True
         self.attention_dropout = config.attention_dropout

         if self.hidden_size % self.num_heads != 0:
@@ -284,7 +283,7 @@ class RWKV6Attention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RWKV6State] = None,
+        past_key_values: Optional[RWKV6State] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -297,8 +296,8 @@ class RWKV6Attention(nn.Module):

         x = hidden_states

-        if use_cache and past_key_value is not None and len(past_key_value) > self.layer_idx:
-            input_kv_state, input_shift_state = past_key_value[self.layer_idx]
+        if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
+            input_kv_state, input_shift_state = past_key_values[self.layer_idx]
             xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
         else:
             input_kv_state = None
@@ -334,9 +333,13 @@ class RWKV6Attention(nn.Module):
         dropout_rate = 0.0 if not self.training else self.attention_dropout

         decay_states_log = -decay_states.float().exp()
-        #decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
+        decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
         key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)

+        if attention_mask is not None:
+            if q_len > 1:
+                decay_states_log = decay_states_log - 100 * F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
+
         query_states = query_states.to(value_states.dtype)
         key_states = key_states.to(value_states.dtype)

@@ -366,19 +369,19 @@ class RWKV6Attention(nn.Module):
         attn_weights = torch.empty(0, device=x.device)

         scale = query_states.shape[-1] ** -0.5
-        output_final_state = not self.training and use_cache and past_key_value is not None
+        output_final_state = not self.training and use_cache and past_key_values is not None
         #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
         #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
         attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)

         if output_final_state:
-            past_key_value.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
+            past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)

         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output * gate_states)

-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights

 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
     def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
@@ -391,6 +394,48 @@ class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
 RWKV6QWEN2_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -581,6 +626,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         #return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, RWKV6State):
             #return_legacy_cache = True
+            print("creating past_key_values", past_key_values)
             past_key_values = RWKV6State()
         # if past_key_values is None:
         #     past_key_values = DynamicCache()
@@ -638,9 +684,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_values,
+                    past_key_values=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
@@ -649,9 +695,6 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -661,15 +704,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = next_decoder_cache if use_cache else None
         #if return_legacy_cache:
         #    next_cache = next_cache.to_legacy_cache()

         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
@@ -793,130 +835,126 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Cache] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        """
-        Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
-        slicing inputs given the existing cache.
-
-        See the forward pass in the model documentation for expected arguments (different models might have different
-        requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
-        """
-
-        # 1. Handle BC:
-        model_inputs = {}
-        # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
-        if self._supports_cache_class:
-            model_inputs["cache_position"] = cache_position
-        # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
-        #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
-        #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
-        elif cache_position is None:
-            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-            cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-
-        # 2. Generic cache-dependent input preparation
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
-        if past_key_values is not None:
-            model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        # 3. Prepare base model inputs
-        input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if not self.config.is_encoder_decoder:
-            if inputs_embeds is not None and cache_position[0] == 0:
-                model_inputs[input_ids_key] = None
-                model_inputs["inputs_embeds"] = inputs_embeds
-            else:
-                # `clone` calls in this function ensure a consistent stride. See #32227
-                model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-                model_inputs["inputs_embeds"] = None
-        else:
-            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-
-        # 4. Create missing `position_ids` on the fly
-        if (
-            attention_mask is not None
-            and kwargs.get("position_ids") is None
-            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
-        ):
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
-
-        # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
-        for model_input_name in ["position_ids", "token_type_ids"]:
-            model_input = kwargs.get(model_input_name)
-            if model_input is not None:
-                if past_key_values:
-                    model_input = model_input[:, -input_ids.shape[1] :]
-                    model_input = model_input.clone(memory_format=torch.contiguous_format)
-                model_inputs[model_input_name] = model_input
-
-        # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if model_inputs["inputs_embeds"] is not None:
-                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-                device = model_inputs["inputs_embeds"].device
-            else:
-                batch_size, sequence_length = model_inputs[input_ids_key].shape
-                device = model_inputs[input_ids_key].device
-
-            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
-            # the 4D causal mask exists, it should be present in the base model (XXXModel class).
-            base_model = getattr(self, self.base_model_prefix, None)
-            if base_model is None:
-                causal_mask_creation_function = getattr(
-                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
-                )
-            else:
-                causal_mask_creation_function = getattr(
-                    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
-                )
-            if causal_mask_creation_function is None:
-                logger.warning_once(
-                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
-                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
-                    "writing code, see Llama for an example implementation. If you're a user, please report this "
-                    "issue on GitHub."
-                )
-            else:
-                attention_mask = causal_mask_creation_function(
-                    attention_mask,
-                    sequence_length=sequence_length,
-                    target_length=past_key_values.get_max_cache_shape(),
-                    dtype=self.dtype,
-                    device=device,
-                    cache_position=cache_position,
-                    batch_size=batch_size,
-                    config=self.config,
-                    past_key_values=past_key_values,
-                )
-        if attention_mask is not None:
-            model_inputs["attention_mask"] = attention_mask
-
-        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        for key, value in kwargs.items():
-            if key not in model_inputs:
-                model_inputs[key] = value
-
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-        return model_inputs
+    # def prepare_inputs_for_generation(
+    #     self,
+    #     input_ids: torch.LongTensor,
+    #     past_key_values: Optional[Cache] = None,
+    #     attention_mask: Optional[torch.LongTensor] = None,
+    #     inputs_embeds: Optional[torch.FloatTensor] = None,
+    #     cache_position: Optional[torch.LongTensor] = None,
+    #     **kwargs,
+    # ):
+    #     """
+    #     Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
+    #     slicing inputs given the existing cache.
+
+    #     See the forward pass in the model documentation for expected arguments (different models might have different
+    #     requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
+    #     """
+
+    #     # 1. Handle BC:
+    #     model_inputs = {}
+    #     # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
+    #     if self._supports_cache_class:
+    #         model_inputs["cache_position"] = cache_position
+    #     # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
+    #     #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
+    #     #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
+    #     elif cache_position is None:
+    #         past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+    #         cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+
+    #     # 2. Generic cache-dependent input preparation
+    #     # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+    #     # Exception 1: when passing input_embeds, input_ids may be missing entries
+    #     # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+    #     # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+    #     if past_key_values is not None:
+    #         model_inputs["past_key_values"] = past_key_values
+    #         if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
+    #             input_ids = input_ids[:, -cache_position.shape[0] :]
+    #         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+    #             input_ids = input_ids[:, cache_position]
+
+    #     # 3. Prepare base model inputs
+    #     input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+    #     # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    #     if not self.config.is_encoder_decoder:
+    #         if inputs_embeds is not None and cache_position[0] == 0:
+    #             model_inputs[input_ids_key] = None
+    #             model_inputs["inputs_embeds"] = inputs_embeds
+    #         else:
+    #             # `clone` calls in this function ensure a consistent stride. See #32227
+    #             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+    #             model_inputs["inputs_embeds"] = None
+    #     else:
+    #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+
+    #     # 4. Create missing `position_ids` on the fly
+    #     if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
+    #         position_ids = attention_mask.long().cumsum(-1) - 1
+    #         position_ids.masked_fill_(attention_mask == 0, 1)
+    #         kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
+
+    #     # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
+    #     for model_input_name in ["position_ids", "token_type_ids"]:
+    #         model_input = kwargs.get(model_input_name)
+    #         if model_input is not None:
+    #             if past_key_values:
+    #                 model_input = model_input[:, -input_ids.shape[1] :]
+    #                 model_input = model_input.clone(memory_format=torch.contiguous_format)
+    #             model_inputs[model_input_name] = model_input
+
+    #     # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
+    #     if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+    #         if model_inputs["inputs_embeds"] is not None:
+    #             batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+    #             device = model_inputs["inputs_embeds"].device
+    #         else:
+    #             batch_size, sequence_length = model_inputs[input_ids_key].shape
+    #             device = model_inputs[input_ids_key].device
+
+    #         # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
+    #         # the 4D causal mask exists, it should be present in the base model (XXXModel class).
+    #         base_model = getattr(self, self.base_model_prefix, None)
+    #         if base_model is None:
+    #             causal_mask_creation_function = getattr(
+    #                 self, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         else:
+    #             causal_mask_creation_function = getattr(
+    #                 base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         if causal_mask_creation_function is None:
+    #             logger.warning_once(
+    #                 f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
+    #                 "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
+    #                 "writing code, see Llama for an example implementation. If you're a user, please report this "
+    #                 "issue on GitHub."
+    #             )
+    #         else:
+    #             attention_mask = causal_mask_creation_function(
+    #                 attention_mask,
+    #                 sequence_length=sequence_length,
+    #                 target_length=past_key_values.get_max_cache_shape(),
+    #                 dtype=self.dtype,
+    #                 device=device,
+    #                 cache_position=cache_position,
+    #                 batch_size=batch_size,
+    #                 config=self.config,
+    #                 past_key_values=past_key_values,
+    #             )
+    #     if attention_mask is not None:
+    #         model_inputs["attention_mask"] = attention_mask
+
+    #     # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+    #     for key, value in kwargs.items():
+    #         if key not in model_inputs:
+    #             model_inputs[key] = value
+
+    #     # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+    #     model_inputs.pop("labels", None)
+    #     return model_inputs

 @add_start_docstrings(
     """
@@ -1215,4 +1253,4 @@ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
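Usage sketch (not part of the commit): a minimal way to exercise the reworked RWKV6State cache path that this diff wires through RWKV6Attention, RWKV6Qwen2DecoderLayer, and RWKV6Qwen2Model. The repo id below is a placeholder, not a real checkpoint name; `trust_remote_code=True` is what loads modeling_rwkv6qwen2.py from the repo, and the flash-linear-attention and triton>=2.2.0 requirements are the ones printed by the import guard in the diff.

# Hypothetical usage sketch -- the checkpoint name is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/rwkv6qwen2-checkpoint"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # pulls in modeling_rwkv6qwen2.py from the repo
)

prompt = "The RWKV6 state carries the context forward because"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# With prepare_inputs_for_generation no longer overridden, generate() follows the
# default GenerationMixin path; RWKV6Qwen2Model.forward creates an RWKV6State
# whenever use_cache=True and past_key_values is not already one.
out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))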