damerajee committed on
Commit
ee7740f
·
verified ·
1 Parent(s): e34c61b

Update modeling_Llamoe.py

Browse files
Files changed (1) hide show
  1. modeling_Llamoe.py +4 -2
modeling_Llamoe.py CHANGED
@@ -575,9 +575,9 @@ class LlamoeSdpaAttention(LlamoeAttention):
575
 
576
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
577
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
578
-
579
  past_key_value = getattr(self, "past_key_value", past_key_value)
580
-
581
  if past_key_value is not None:
582
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
583
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -585,6 +585,8 @@ class LlamoeSdpaAttention(LlamoeAttention):
585
 
586
  key_states = repeat_kv(key_states, self.num_key_value_groups)
587
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
 
588
 
589
  causal_mask = attention_mask
590
  if attention_mask is not None and cache_position is not None:
 
575
 
576
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
577
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
578
+
579
  past_key_value = getattr(self, "past_key_value", past_key_value)
580
+
581
  if past_key_value is not None:
582
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
583
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 
585
 
586
  key_states = repeat_kv(key_states, self.num_key_value_groups)
587
  value_states = repeat_kv(value_states, self.num_key_value_groups)
588
+ print("after_rb_key_states:",key_states)
589
+ print("after_rb_value_states:",value_states)
590
 
591
  causal_mask = attention_mask
592
  if attention_mask is not None and cache_position is not None: