damerajee committed
Commit c952105 · verified · 1 Parent(s): 7c59a8a

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +3 -9

modeling_Llamoe.py CHANGED
@@ -562,13 +562,10 @@ class LlamoeSdpaAttention(LlamoeAttention):
         bsz, q_len, _ = hidden_states.size()
 
 
-        print("hidden_states:",hidden_states.shape)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        print("query_states:",query_states.shape)
-        print("key_states:",key_states.shape)
-        print("value_states:",value_states.shape)
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
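Note: the `view(...).transpose(1, 2)` pair in the context above is the usual Llama-style reshape from `[bsz, q_len, heads * head_dim]` into per-head tensors of shape `[bsz, heads, q_len, head_dim]`. A minimal sketch with made-up sizes (the batch size, sequence length, hidden size and head counts below are illustrative assumptions, not values from this repo):

```python
import torch

# Assumed example sizes; the real values come from the model config.
bsz, q_len, hidden_size = 2, 8, 4096
num_heads, num_key_value_heads, head_dim = 32, 8, 128

hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)

# Same reshape as in the diff: [bsz, q_len, heads * head_dim] -> [bsz, heads, q_len, head_dim]
query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

print(query_states.shape)  # torch.Size([2, 32, 8, 128])
print(key_states.shape)    # torch.Size([2, 8, 8, 128])
```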
@@ -585,15 +582,12 @@ class LlamoeSdpaAttention(LlamoeAttention):
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        print("after_rb_key_states:",key_states)
-        print("after_rb_value_states:",value_states)
+
 
         causal_mask = attention_mask
-        print("causal_mask:",causal_mask)
         if attention_mask is not None and cache_position is not None:
             causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
 
-        print("after_causal_masks:",causal_mask)
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:
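`repeat_kv`, called in the context above, is the grouped-query-attention helper that tiles the `num_key_value_heads` K/V heads so they line up with the `num_heads` query heads. Its definition is not part of this diff; the sketch below mirrors the standard implementation used in Hugging Face Llama-style models and may differ in detail from this file:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand K/V heads from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_key_value_heads * n_rep, seqlen, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

key_states = torch.randn(2, 8, 8, 128)       # 8 KV heads (illustrative shape)
key_states = repeat_kv(key_states, n_rep=4)  # -> 32 heads, matching num_heads
print(key_states.shape)                      # torch.Size([2, 32, 8, 128])
```

The subsequent `causal_mask[:, :, cache_position, : key_states.shape[-2]]` slice then keeps only the mask rows for the current query positions and the columns covering the keys actually present.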
@@ -605,7 +599,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=None,
             dropout_p=self.attention_dropout if self.training else 0.0,
         )
 
 
 