Update modeling_rwkv6qwen2.py
modeling_rwkv6qwen2.py  (+181 -143)  CHANGED
@@ -29,7 +29,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache, StaticCache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -209,7 +209,7 @@ try:
     from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 except ImportError:
     print("Required module is not installed. Please install it using the following commands:")
-    print("pip install -U git+https://github.com/
+    print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
     print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
     print("pip install triton>=2.2.0")

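Note: the except branch above only prints install instructions. A common pattern for optional dependencies (a sketch for illustration only; the helper name is hypothetical and not part of this file) records availability and raises a clear error at call time:

try:
    from fla.ops.gla.fused_recurrent import fused_recurrent_gla
except ImportError:
    fused_recurrent_gla = None

def _require_fla():
    # Hypothetical helper: fail with a clear message instead of a NameError later.
    if fused_recurrent_gla is None:
        raise ImportError(
            "flash-linear-attention is required: "
            "pip install -U git+https://github.com/fla-org/flash-linear-attention "
            "(and triton>=2.2.0)"
        )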
@@ -230,7 +230,6 @@ class RWKV6Attention(nn.Module):
         self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.is_causal = True
         self.attention_dropout = config.attention_dropout

         if self.hidden_size % self.num_heads != 0:
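For reference, the head bookkeeping above works out as follows with toy values (illustrative only, not read from RWKV6Qwen2Config):

hidden_size = 2048
num_heads = 16
num_key_value_heads = 2

head_dim = hidden_size // num_heads                       # 128
num_key_value_groups = num_heads // num_key_value_heads   # 8 query heads share each KV head
assert hidden_size % num_heads == 0                       # the condition guarded by the check above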
@@ -284,7 +283,7 @@ class RWKV6Attention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-
+        past_key_values: Optional[RWKV6State] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -297,8 +296,8 @@ class RWKV6Attention(nn.Module):

         x = hidden_states

-        if use_cache and
-        input_kv_state, input_shift_state =
+        if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
+            input_kv_state, input_shift_state = past_key_values[self.layer_idx]
             xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
         else:
             input_kv_state = None
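The branch above reads the per-layer recurrent state and the one-token "shift state" from the cache, then builds xprev so every position can see its predecessor's hidden vector even across a cache boundary. A shape-level sketch with standalone tensors (hypothetical sizes, not model code):

import torch

bsz, q_len, hidden = 2, 5, 8                      # hypothetical sizes
x = torch.randn(bsz, q_len, hidden)               # current chunk of hidden states
input_shift_state = torch.zeros(bsz, 1, hidden)   # last hidden vector of the previous chunk (zeros at step 0)

xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)   # same shift as in the hunk above
assert xprev.shape == x.shape
output_shift_state = x[:, -1:]                    # what would be cached for the next chunk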
@@ -334,9 +333,13 @@ class RWKV6Attention(nn.Module):
         dropout_rate = 0.0 if not self.training else self.attention_dropout

         decay_states_log = -decay_states.float().exp()
-
+        decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
         key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)

+        if attention_mask is not None:
+            if q_len > 1:
+                decay_states_log = decay_states_log - 100 * F.pad(1 - attention_mask, [1, -1]).view(bsz, 1, q_len, 1)
+
         query_states = query_states.to(value_states.dtype)
         key_states = key_states.to(value_states.dtype)

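The added attention-mask handling pushes the log-decay strongly negative at padding positions (shifted right by one slot, which appears to line up with the token shift above), so their contribution to the recurrent state decays away. A toy illustration of the indexing (hypothetical batch, not model code):

import torch
import torch.nn.functional as F

bsz, q_len = 1, 4
attention_mask = torch.tensor([[0., 0., 1., 1.]])      # two left-padded positions
shifted = F.pad(1 - attention_mask, [1, -1])           # prepend one slot, drop the last -> shift right by one
penalty = 100 * shifted.view(bsz, 1, q_len, 1)
print(penalty.flatten())                               # tensor([  0., 100., 100.,   0.])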
@@ -366,19 +369,19 @@ class RWKV6Attention(nn.Module):
         attn_weights = torch.empty(0, device=x.device)

         scale = query_states.shape[-1] ** -0.5
-        output_final_state = not self.training and use_cache and
+        output_final_state = not self.training and use_cache and past_key_values is not None
         #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
         #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
         attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)

         if output_final_state:
-
+            past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)

         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output * gate_states)

-        return attn_output, attn_weights
+        return attn_output, attn_weights

 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
     def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
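For readers without the fla package, the recurrence that fused_recurrent_gla evaluates in the hunk above can be written as a naive (slow) reference loop. The sketch below is an assumption based on the gated-linear-attention formulation, not the kernel's actual code; the argument order mirrors the call above (q, k, v, per-channel log-decay g, scale, initial state, flag for returning the final state):

import torch

def naive_recurrent_gla(q, k, v, g, scale, initial_state=None, output_final_state=False):
    # q, k, g: (B, H, T, K); v: (B, H, T, V); g is a per-channel log-decay (<= 0).
    B, H, T, K = q.shape
    V = v.shape[-1]
    S = torch.zeros(B, H, K, V) if initial_state is None else initial_state.clone()
    outputs = []
    for t in range(T):
        S = S * g[:, :, t].exp().unsqueeze(-1)                       # decay the running state per key channel
        S = S + k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)  # accumulate the outer product k_t v_t^T
        outputs.append((q[:, :, t] * scale).unsqueeze(-2) @ S)       # read out: (1, K) @ (K, V) -> (1, V)
    o = torch.cat(outputs, dim=-2)
    return o, (S if output_final_state else None)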
@@ -391,6 +394,48 @@ class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
 RWKV6QWEN2_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -581,6 +626,7 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         #return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, RWKV6State):
             #return_legacy_cache = True
+            print("creating past_key_values", past_key_values)
             past_key_values = RWKV6State()
             # if past_key_values is None:
             #     past_key_values = DynamicCache()
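RWKV6State is defined elsewhere in the file; from its use in this diff it needs per-layer indexing, len(), and an in-place update(). A minimal stand-in illustrating that interface (an assumption for illustration only, not the real class):

import torch
from typing import List, Tuple

class MinimalRWKVState:
    """Illustrative only: mirrors how the attention code indexes and updates the cache."""

    def __init__(self):
        self.layers: List[Tuple[torch.Tensor, torch.Tensor]] = []   # (kv_state, shift_state) per layer
        self.seen_tokens = 0

    def __len__(self) -> int:
        return len(self.layers)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.layers[layer_idx]

    def update(self, kv_state, shift_state, q_len, layer_idx):
        if layer_idx < len(self.layers):
            self.layers[layer_idx] = (kv_state, shift_state)
        else:
            self.layers.append((kv_state, shift_state))
        if layer_idx == 0:
            self.seen_tokens += q_len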
@@ -638,9 +684,9 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
-
+                    past_key_values=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
@@ -649,9 +695,6 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

@@ -661,15 +704,14 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = next_decoder_cache if use_cache else None
         #if return_legacy_cache:
         #    next_cache = next_cache.to_legacy_cache()

         if not return_dict:
-            return tuple(v for v in [hidden_states,
+            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
@@ -793,130 +835,126 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

-    def prepare_inputs_for_generation(
-        ...
-    ):
-        ...
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-        return model_inputs
+    # def prepare_inputs_for_generation(
+    #     self,
+    #     input_ids: torch.LongTensor,
+    #     past_key_values: Optional[Cache] = None,
+    #     attention_mask: Optional[torch.LongTensor] = None,
+    #     inputs_embeds: Optional[torch.FloatTensor] = None,
+    #     cache_position: Optional[torch.LongTensor] = None,
+    #     **kwargs,
+    # ):
+    #     """
+    #     Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
+    #     slicing inputs given the existing cache.
+
+    #     See the forward pass in the model documentation for expected arguments (different models might have different
+    #     requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
+    #     """
+
+    #     # 1. Handle BC:
+    #     model_inputs = {}
+    #     # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
+    #     if self._supports_cache_class:
+    #         model_inputs["cache_position"] = cache_position
+    #     # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
+    #     #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
+    #     #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
+    #     elif cache_position is None:
+    #         past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+    #         cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+
+    #     # 2. Generic cache-dependent input preparation
+    #     # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+    #     # Exception 1: when passing input_embeds, input_ids may be missing entries
+    #     # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+    #     # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+    #     if past_key_values is not None:
+    #         model_inputs["past_key_values"] = past_key_values
+    #         if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
+    #             input_ids = input_ids[:, -cache_position.shape[0] :]
+    #         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+    #             input_ids = input_ids[:, cache_position]
+
+    #     # 3. Prepare base model inputs
+    #     input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+    #     # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    #     if not self.config.is_encoder_decoder:
+    #         if inputs_embeds is not None and cache_position[0] == 0:
+    #             model_inputs[input_ids_key] = None
+    #             model_inputs["inputs_embeds"] = inputs_embeds
+    #         else:
+    #             # `clone` calls in this function ensure a consistent stride. See #32227
+    #             model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+    #             model_inputs["inputs_embeds"] = None
+    #     else:
+    #         model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
+
+    #     # 4. Create missing `position_ids` on the fly
+    #     if (attention_mask is not None and kwargs.get("position_ids") is None and "position_ids" in set(inspect.signature(self.forward).parameters.keys())):
+    #         position_ids = attention_mask.long().cumsum(-1) - 1
+    #         position_ids.masked_fill_(attention_mask == 0, 1)
+    #         kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
+
+    #     # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
+    #     for model_input_name in ["position_ids", "token_type_ids"]:
+    #         model_input = kwargs.get(model_input_name)
+    #         if model_input is not None:
+    #             if past_key_values:
+    #                 model_input = model_input[:, -input_ids.shape[1] :]
+    #                 model_input = model_input.clone(memory_format=torch.contiguous_format)
+    #             model_inputs[model_input_name] = model_input
+
+    #     # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
+    #     if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+    #         if model_inputs["inputs_embeds"] is not None:
+    #             batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+    #             device = model_inputs["inputs_embeds"].device
+    #         else:
+    #             batch_size, sequence_length = model_inputs[input_ids_key].shape
+    #             device = model_inputs[input_ids_key].device
+
+    #         # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
+    #         # the 4D causal mask exists, it should be present in the base model (XXXModel class).
+    #         base_model = getattr(self, self.base_model_prefix, None)
+    #         if base_model is None:
+    #             causal_mask_creation_function = getattr(
+    #                 self, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         else:
+    #             causal_mask_creation_function = getattr(
+    #                 base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+    #             )
+    #         if causal_mask_creation_function is None:
+    #             logger.warning_once(
+    #                 f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
+    #                 "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
+    #                 "writing code, see Llama for an example implementation. If you're a user, please report this "
+    #                 "issue on GitHub."
+    #             )
+    #         else:
+    #             attention_mask = causal_mask_creation_function(
+    #                 attention_mask,
+    #                 sequence_length=sequence_length,
+    #                 target_length=past_key_values.get_max_cache_shape(),
+    #                 dtype=self.dtype,
+    #                 device=device,
+    #                 cache_position=cache_position,
+    #                 batch_size=batch_size,
+    #                 config=self.config,
+    #                 past_key_values=past_key_values,
+    #             )
+    #         if attention_mask is not None:
+    #             model_inputs["attention_mask"] = attention_mask
+
+    #     # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+    #     for key, value in kwargs.items():
+    #         if key not in model_inputs:
+    #             model_inputs[key] = value
+
+    #     # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+    #     model_inputs.pop("labels", None)
+    #     return model_inputs

 @add_start_docstrings(
     """
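With the override above fully commented out, generation falls back to GenerationMixin.prepare_inputs_for_generation. The step that matters most for a recurrent cache is slicing the prompt down to the tokens the cache has not yet seen; a stripped-down sketch of that idea (an assumption for illustration, not this repo's code and not the exact transformers implementation):

import torch

def minimal_prepare_inputs(input_ids, past_key_values=None, cache_position=None, **kwargs):
    model_inputs = {"past_key_values": past_key_values, "cache_position": cache_position}
    if past_key_values is not None and cache_position is not None:
        # Keep only the not-yet-processed tokens (typically just the newest one during decoding).
        input_ids = input_ids[:, cache_position]
    model_inputs["input_ids"] = input_ids
    model_inputs.update(kwargs)
    return model_inputs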
@@ -1215,4 +1253,4 @@ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )