Update modeling_phi3.py
modeling_phi3.py  +18 -28  CHANGED
@@ -25,6 +25,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -43,9 +44,9 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-
 from .configuration_phi3 import Phi3Config
 
+
 logger = logging.get_logger(__name__)
 
 # Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
@@ -86,7 +87,7 @@ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
 class Phi3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-
+    def __init__(self, hidden_size, eps=1e-6):
         """
         Phi3RMSNorm is equivalent to T5LayerNorm
         """
@@ -120,7 +121,7 @@ def _get_unpad_data(attention_mask):
 
 # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi3
 class Phi3RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
@@ -228,7 +229,6 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -608,7 +608,7 @@ class Phi3FlashAttention2(Phi3Attention):
 
         return attn_output, attn_weights, past_key_value
 
-    # Copied from transformers.models.
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
     def _flash_attention_forward(
         self,
         query_states,
@@ -650,14 +650,9 @@ class Phi3FlashAttention2(Phi3Attention):
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
 
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -687,10 +682,7 @@ class Phi3FlashAttention2(Phi3Attention):
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
                 causal=causal,
-                window_size=(
-                    self.config.sliding_window,
-                    self.config.sliding_window,
-                ),
+                window_size=(self.config.sliding_window, self.config.sliding_window),
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
@@ -712,15 +704,12 @@ class Phi3FlashAttention2(Phi3Attention):
                 dropout,
                 softmax_scale=softmax_scale,
                 causal=causal,
-                window_size=(
-                    self.config.sliding_window,
-                    self.config.sliding_window,
-                ),
+                window_size=(self.config.sliding_window, self.config.sliding_window),
             )
 
         return attn_output
 
-    # Copied from transformers.models.
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
@@ -737,8 +726,7 @@ class Phi3FlashAttention2(Phi3Attention):
 
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
-                indices_k,
+                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1233,7 +1221,7 @@ class Phi3Model(Phi3PreTrainedModel):
 class Phi3ForCausalLM(Phi3PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
+    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
     def __init__(self, config):
         super().__init__(config)
         self.model = Phi3Model(config)
@@ -1439,7 +1427,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
     """,
     PHI3_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
 class Phi3ForSequenceClassification(Phi3PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1555,7 +1543,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
     """,
     PHI3_START_DOCSTRING,
 )
-# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,self.transformer->self.model,transformer_outputs->model_outputs
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
 class Phi3ForTokenClassification(Phi3PreTrainedModel):
     def __init__(self, config: Phi3Config):
         super().__init__(config)
@@ -1622,7 +1610,9 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel):
             labels = labels.to(logits.device)
             batch_size, seq_length = labels.shape
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
+            loss = loss_fct(
+                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+            )
 
         if not return_dict:
             output = (logits,) + model_outputs[2:]