kevin36524 committed (verified)
Commit 39ac59c · Parent: 257e5c9

Upload modeling_qwen2_anotated.py with huggingface_hub

Files changed (1)
  1. modeling_qwen2_anotated.py +13 -8
modeling_qwen2_anotated.py CHANGED
@@ -542,12 +542,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
-        #KEVINDEBUG query_states: torch.Size([1, 5|1, 896]) keystates:torch.Size([1, 5|1, 128]) value_states: torch.Size([1, 5|1, 128])

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        #KEVINDEBUG query_states: torch.Size([1, 14, 5|1, 64]) key_states:torch.Size([1, 2, 5|1, 64]) value_states: torch.Size([1, 2, 5|1, 64])

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
@@ -555,40 +553,47 @@ class Qwen2SdpaAttention(Qwen2Attention):
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        #KEVINDEBUG query_states: torch.Size([1, 14, 5|1, 64]) key_states:torch.Size([1, 2, 5|1, 64]) position_ids: torch.Size([1, 5|1])
+
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-            #past_key_value class info : $<class 'transformers.cache_utils.DynamicCache'>
-            #KEVINDEBUG past_key_value is not None key_states: torch.Size([1, 2, 5...19, 64]) value_states: torch.Size([1, 2, 5..19, 64])

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        #KEVINDEBUG key_states: torch.Size([1, 14, 5..19, 64]) value_states: torch.Size([1, 14, 5..19, 64])

        causal_mask = attention_mask
        if attention_mask is not None: # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

+        if is_causal:
+            L, S = query_states.size(-2), key_states.size(-2)
+            causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
+            is_causal=False,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)
-        #KEVINDEBUG attn_output is $torch.Size([1, 5|1, 896])

        return attn_output, None, past_key_value

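Note on the change: besides stripping the #KEVINDEBUG shape annotations, the commit replaces SDPA's built-in causal handling (is_causal=is_causal) with an explicitly materialized lower-triangular boolean mask passed via attn_mask, with is_causal pinned to False. For matching query/key lengths the two formulations give the same output. The following stand-alone sketch (not part of the commit; tensor sizes are illustrative; assumes only torch.nn.functional.scaled_dot_product_attention from PyTorch 2.x) checks that equivalence:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, q_len, head_dim = 1, 14, 5, 64  # illustrative sizes only
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Built-in causal handling, as in the pre-change code path.
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit boolean mask (True = attend, False = masked), as in the added code.
L, S = q.size(-2), k.size(-2)
causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

print(torch.allclose(out_builtin, out_explicit, atol=1e-6))  # expected: True

Materializing the mask keeps the scaled_dot_product_attention call itself free of the is_causal branch, which tends to matter when tracing or converting the model; the commit message does not state a motivation, so that reading is an inference.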
 
 
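The removed debug annotations also record the grouped-query-attention geometry of this module: q_proj produces 896 = 14 × 64 features while k_proj/v_proj produce 128 = 2 × 64, and repeat_kv then repeats each of the 2 key/value heads num_key_value_groups (= 7) times so they line up with the 14 query heads. A small stand-alone sketch reproducing those shapes (hypothetical code, not from the file; torch.repeat_interleave stands in for the library's repeat_kv helper and yields the same per-head ordering here):

import torch

bsz, q_len = 1, 5  # the "5|1" in the removed comments presumably covers prefill (5) vs. single-token decode (1)
hidden_size, num_heads, num_kv_heads, head_dim = 896, 14, 2, 64
n_rep = num_heads // num_kv_heads  # num_key_value_groups = 7

hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)     # output: [1, 5, 896]
k_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim)  # output: [1, 5, 128]

query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([1, 14, 5, 64])
print(key_states.shape)    # torch.Size([1, 2, 5, 64])

# Repeat each KV head n_rep times along the head axis so keys/values match the
# query head count before scaled_dot_product_attention.
key_states = torch.repeat_interleave(key_states, n_rep, dim=1)
print(key_states.shape)    # torch.Size([1, 14, 5, 64])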