Upload modeling_qwen2_anotated.py with huggingface_hub
modeling_qwen2_anotated.py  CHANGED  +13 -8
@@ -542,12 +542,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        #KEVINDEBUG query_states: torch.Size([1, 5|1, 896]) keystates:torch.Size([1, 5|1, 128]) value_states: torch.Size([1, 5|1, 128])
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        #KEVINDEBUG query_states: torch.Size([1, 14, 5|1, 64]) key_states: torch.Size([1, 2, 5|1, 64]) value_states: torch.Size([1, 2, 5|1, 64])
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -555,40 +553,47 @@ class Qwen2SdpaAttention(Qwen2Attention):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            #past_key_value class info : $<class 'transformers.cache_utils.DynamicCache'>
-            #KEVINDEBUG past_key_value is not None key_states: torch.Size([1, 2, 5...19, 64]) value_states: torch.Size([1, 2, 5..19, 64])
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        #KEVINDEBUG key_states: torch.Size([1, 14, 5..19, 64]) value_states: torch.Size([1, 14, 5..19, 64])
 
         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
+        if is_causal:
+            L, S = query_states.size(-2), key_states.size(-2)
+            causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
+            is_causal=False,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
-        #KEVINDEBUG attn_output is $torch.Size([1, 5|1, 896])
 
         return attn_output, None, past_key_value
 
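A note on the removed KEVINDEBUG shape annotations: the `5|1` entries record q_len on the prefill pass (5 tokens) versus each single-token decode step, and `5..19` records the cached key/value length growing from 5 to 19 across those decode steps. The head counts involved (14 query heads, 2 KV heads, head_dim 64, hidden size 896) are consistent with a Qwen2-0.5B-sized config. Below is a minimal torch-only sketch of that cache growth; the plain concatenation stands in for what `DynamicCache.update` does, and the shapes are illustrative rather than taken from this file.

```python
import torch

bsz, num_kv_heads, head_dim = 1, 2, 64

# Prefill: q_len == 5, so the cache starts with 5 key/value positions per KV head.
cached_k = torch.randn(bsz, num_kv_heads, 5, head_dim)

# Decode: each step appends one new position (q_len == 1), so the cached
# kv_seq_len grows 5 -> 6 -> ... -> 19 over 14 steps.
for _ in range(14):
    new_k = torch.randn(bsz, num_kv_heads, 1, head_dim)
    cached_k = torch.cat([cached_k, new_k], dim=-2)

print(cached_k.shape)  # torch.Size([1, 2, 19, 64]), matching the "5..19" note
```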
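The `repeat_kv` calls kept as context in the diff are where grouped-query attention expands the 2 KV heads to the 14 query heads seen in the shape notes. The sketch below follows the usual expand-and-reshape pattern for this step; it is restated here for illustration and is not copied from this file.

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim),
    # so each KV head is shared by n_rep query heads.
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

key_states = torch.randn(1, 2, 19, 64)   # 2 KV heads, as in the debug shapes
print(repeat_kv(key_states, 7).shape)    # torch.Size([1, 14, 19, 64]): 14 heads for SDPA
```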
589 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
590 |
+
is_causal=False,
|
591 |
)
|
592 |
|
593 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
594 |
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
595 |
|
596 |
attn_output = self.o_proj(attn_output)
|
|
|
597 |
|
598 |
return attn_output, None, past_key_value
|
599 |
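The main behavioral change is in the `is_causal` handling: instead of forwarding `is_causal=is_causal` and letting SDPA derive the causal structure, the annotated version materializes an explicit boolean lower-triangular mask whenever `is_causal` would have been True and always calls SDPA with `is_causal=False`. The self-contained sketch below checks that the two paths agree in the case this branch covers (prefill with q_len == kv_seq_len and no attention mask); shapes are illustrative.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 1, 14, 5, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Path 1: let SDPA apply causal masking internally.
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Path 2: the explicit mask from the added lines (True = attend, False = masked out).
# (On CUDA the mask would also need to be created on q.device.)
L, S = q.size(-2), k.size(-2)
causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

print(torch.allclose(out_builtin, out_explicit, atol=1e-6))  # expected: True
```

One likely side effect worth noting: once an explicit `attn_mask` is passed, SDPA may no longer select the flash-attention kernel, so this change reads as being made for inspectability of the mask rather than for speed.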
|