Update modeling_Llamoe.py
Browse files- modeling_Llamoe.py +4 -2
modeling_Llamoe.py
CHANGED
@@ -575,9 +575,9 @@ class LlamoeSdpaAttention(LlamoeAttention):
|
|
575 |
|
576 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
577 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
578 |
-
|
579 |
past_key_value = getattr(self, "past_key_value", past_key_value)
|
580 |
-
|
581 |
if past_key_value is not None:
|
582 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
583 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
@@ -585,6 +585,8 @@ class LlamoeSdpaAttention(LlamoeAttention):
|
|
585 |
|
586 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
587 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
|
|
|
588 |
|
589 |
causal_mask = attention_mask
|
590 |
if attention_mask is not None and cache_position is not None:
|
|
|
575 |
|
576 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
577 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
578 |
+
|
579 |
past_key_value = getattr(self, "past_key_value", past_key_value)
|
580 |
+
|
581 |
if past_key_value is not None:
|
582 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
583 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
|
|
585 |
|
586 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
587 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
588 |
+
print("after_rb_key_states:",key_states)
|
589 |
+
print("after_rb_value_states:",value_states)
|
590 |
|
591 |
causal_mask = attention_mask
|
592 |
if attention_mask is not None and cache_position is not None:
|