KoichiYasuoka
commited on
Commit
·
c85eb0e
1
Parent(s):
2fe2dbf
support inputs_embeds
Browse files- modeling_modernbert.py +68 -26
modeling_modernbert.py
CHANGED
@@ -206,12 +206,17 @@ class ModernBertEmbeddings(nn.Module):
|
|
206 |
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
207 |
return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
208 |
|
209 |
-
def forward(
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
215 |
return hidden_states
|
216 |
|
217 |
|
@@ -792,6 +797,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
|
|
792 |
config.n_positions - 1]`.
|
793 |
|
794 |
[What are position IDs?](../glossary#position-ids)
|
|
|
|
|
|
|
|
|
795 |
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
796 |
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
797 |
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
@@ -843,10 +852,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
|
843 |
)
|
844 |
def forward(
|
845 |
self,
|
846 |
-
input_ids: torch.LongTensor = None,
|
847 |
attention_mask: Optional[torch.Tensor] = None,
|
848 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
849 |
position_ids: Optional[torch.LongTensor] = None,
|
|
|
850 |
indices: Optional[torch.Tensor] = None,
|
851 |
cu_seqlens: Optional[torch.Tensor] = None,
|
852 |
max_seqlen: Optional[int] = None,
|
@@ -862,35 +872,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
|
|
862 |
)
|
863 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
864 |
|
|
|
|
|
|
|
865 |
all_hidden_states = () if output_hidden_states else None
|
866 |
all_self_attentions = () if output_attentions else None
|
867 |
|
868 |
self._maybe_set_compile()
|
869 |
-
|
|
|
|
|
870 |
|
871 |
if batch_size is None and seq_len is None:
|
872 |
-
|
|
|
|
|
|
|
|
|
873 |
|
874 |
if attention_mask is None:
|
875 |
-
attention_mask = torch.ones((batch_size, seq_len), device=
|
876 |
|
877 |
repad = False
|
878 |
if self.config._attn_implementation == "flash_attention_2":
|
879 |
if indices is None and cu_seqlens is None and max_seqlen is None:
|
880 |
repad = True
|
881 |
-
|
882 |
-
|
883 |
-
|
|
|
|
|
|
|
|
|
|
|
884 |
)
|
885 |
else:
|
886 |
if position_ids is None:
|
887 |
-
position_ids = torch.arange(seq_len, device=
|
888 |
|
889 |
attention_mask, sliding_window_mask = self._update_attention_mask(
|
890 |
attention_mask, output_attentions=output_attentions
|
891 |
)
|
892 |
|
893 |
-
hidden_states = self.embeddings(input_ids)
|
894 |
|
895 |
for encoder_layer in self.layers:
|
896 |
if output_hidden_states:
|
@@ -1026,10 +1050,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
1026 |
)
|
1027 |
def forward(
|
1028 |
self,
|
1029 |
-
input_ids: Optional[torch.
|
1030 |
attention_mask: Optional[torch.Tensor] = None,
|
1031 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1032 |
position_ids: Optional[torch.Tensor] = None,
|
|
|
1033 |
labels: Optional[torch.Tensor] = None,
|
1034 |
indices: Optional[torch.Tensor] = None,
|
1035 |
cu_seqlens: Optional[torch.Tensor] = None,
|
@@ -1046,19 +1071,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
|
1046 |
|
1047 |
if self.config._attn_implementation == "flash_attention_2":
|
1048 |
if indices is None and cu_seqlens is None and max_seqlen is None:
|
1049 |
-
batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
1050 |
if attention_mask is None:
|
1051 |
-
attention_mask = torch.ones((batch_size, seq_len), device=
|
1052 |
-
|
1053 |
-
|
1054 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1055 |
)
|
1056 |
|
1057 |
outputs = self.model(
|
1058 |
-
input_ids,
|
1059 |
attention_mask=attention_mask,
|
1060 |
sliding_window_mask=sliding_window_mask,
|
1061 |
position_ids=position_ids,
|
|
|
1062 |
indices=indices,
|
1063 |
cu_seqlens=cu_seqlens,
|
1064 |
max_seqlen=max_seqlen,
|
@@ -1131,10 +1169,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
|
1131 |
)
|
1132 |
def forward(
|
1133 |
self,
|
1134 |
-
input_ids: Optional[torch.
|
1135 |
attention_mask: Optional[torch.Tensor] = None,
|
1136 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1137 |
position_ids: Optional[torch.Tensor] = None,
|
|
|
1138 |
labels: Optional[torch.Tensor] = None,
|
1139 |
indices: Optional[torch.Tensor] = None,
|
1140 |
cu_seqlens: Optional[torch.Tensor] = None,
|
@@ -1156,10 +1195,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
|
|
1156 |
self._maybe_set_compile()
|
1157 |
|
1158 |
outputs = self.model(
|
1159 |
-
input_ids,
|
1160 |
attention_mask=attention_mask,
|
1161 |
sliding_window_mask=sliding_window_mask,
|
1162 |
position_ids=position_ids,
|
|
|
1163 |
indices=indices,
|
1164 |
cu_seqlens=cu_seqlens,
|
1165 |
max_seqlen=max_seqlen,
|
@@ -1242,10 +1282,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|
1242 |
)
|
1243 |
def forward(
|
1244 |
self,
|
1245 |
-
input_ids: Optional[torch.
|
1246 |
attention_mask: Optional[torch.Tensor] = None,
|
1247 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1248 |
position_ids: Optional[torch.Tensor] = None,
|
|
|
1249 |
labels: Optional[torch.Tensor] = None,
|
1250 |
indices: Optional[torch.Tensor] = None,
|
1251 |
cu_seqlens: Optional[torch.Tensor] = None,
|
@@ -1264,10 +1305,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|
1264 |
self._maybe_set_compile()
|
1265 |
|
1266 |
outputs = self.model(
|
1267 |
-
input_ids,
|
1268 |
attention_mask=attention_mask,
|
1269 |
sliding_window_mask=sliding_window_mask,
|
1270 |
position_ids=position_ids,
|
|
|
1271 |
indices=indices,
|
1272 |
cu_seqlens=cu_seqlens,
|
1273 |
max_seqlen=max_seqlen,
|
|
|
206 |
def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
207 |
return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
208 |
|
209 |
+
def forward(
|
210 |
+
self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
|
211 |
+
) -> torch.Tensor:
|
212 |
+
if inputs_embeds is not None:
|
213 |
+
hidden_states = self.drop(self.norm(inputs_embeds))
|
214 |
+
else:
|
215 |
+
hidden_states = (
|
216 |
+
self.compiled_embeddings(input_ids)
|
217 |
+
if self.config.reference_compile
|
218 |
+
else self.drop(self.norm(self.tok_embeddings(input_ids)))
|
219 |
+
)
|
220 |
return hidden_states
|
221 |
|
222 |
|
|
|
797 |
config.n_positions - 1]`.
|
798 |
|
799 |
[What are position IDs?](../glossary#position-ids)
|
800 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
801 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
802 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
803 |
+
model's internal embedding lookup matrix.
|
804 |
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
805 |
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
806 |
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
|
|
852 |
)
|
853 |
def forward(
|
854 |
self,
|
855 |
+
input_ids: Optional[torch.LongTensor] = None,
|
856 |
attention_mask: Optional[torch.Tensor] = None,
|
857 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
858 |
position_ids: Optional[torch.LongTensor] = None,
|
859 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
860 |
indices: Optional[torch.Tensor] = None,
|
861 |
cu_seqlens: Optional[torch.Tensor] = None,
|
862 |
max_seqlen: Optional[int] = None,
|
|
|
872 |
)
|
873 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
874 |
|
875 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
876 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
877 |
+
|
878 |
all_hidden_states = () if output_hidden_states else None
|
879 |
all_self_attentions = () if output_attentions else None
|
880 |
|
881 |
self._maybe_set_compile()
|
882 |
+
|
883 |
+
if input_ids is not None:
|
884 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
885 |
|
886 |
if batch_size is None and seq_len is None:
|
887 |
+
if inputs_embeds is not None:
|
888 |
+
batch_size, seq_len = inputs_embeds.shape[:2]
|
889 |
+
else:
|
890 |
+
batch_size, seq_len = input_ids.shape[:2]
|
891 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
892 |
|
893 |
if attention_mask is None:
|
894 |
+
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
895 |
|
896 |
repad = False
|
897 |
if self.config._attn_implementation == "flash_attention_2":
|
898 |
if indices is None and cu_seqlens is None and max_seqlen is None:
|
899 |
repad = True
|
900 |
+
if inputs_embeds is None:
|
901 |
+
with torch.no_grad():
|
902 |
+
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
903 |
+
inputs=input_ids, attention_mask=attention_mask
|
904 |
+
)
|
905 |
+
else:
|
906 |
+
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
|
907 |
+
inputs=inputs_embeds, attention_mask=attention_mask
|
908 |
)
|
909 |
else:
|
910 |
if position_ids is None:
|
911 |
+
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
912 |
|
913 |
attention_mask, sliding_window_mask = self._update_attention_mask(
|
914 |
attention_mask, output_attentions=output_attentions
|
915 |
)
|
916 |
|
917 |
+
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
918 |
|
919 |
for encoder_layer in self.layers:
|
920 |
if output_hidden_states:
|
|
|
1050 |
)
|
1051 |
def forward(
|
1052 |
self,
|
1053 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1054 |
attention_mask: Optional[torch.Tensor] = None,
|
1055 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1056 |
position_ids: Optional[torch.Tensor] = None,
|
1057 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1058 |
labels: Optional[torch.Tensor] = None,
|
1059 |
indices: Optional[torch.Tensor] = None,
|
1060 |
cu_seqlens: Optional[torch.Tensor] = None,
|
|
|
1071 |
|
1072 |
if self.config._attn_implementation == "flash_attention_2":
|
1073 |
if indices is None and cu_seqlens is None and max_seqlen is None:
|
1074 |
+
if batch_size is None and seq_len is None:
|
1075 |
+
if inputs_embeds is not None:
|
1076 |
+
batch_size, seq_len = inputs_embeds.shape[:2]
|
1077 |
+
else:
|
1078 |
+
batch_size, seq_len = input_ids.shape[:2]
|
1079 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1080 |
+
|
1081 |
if attention_mask is None:
|
1082 |
+
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
|
1083 |
+
|
1084 |
+
if inputs_embeds is None:
|
1085 |
+
with torch.no_grad():
|
1086 |
+
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
1087 |
+
inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
1088 |
+
)
|
1089 |
+
else:
|
1090 |
+
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
|
1091 |
+
inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
|
1092 |
)
|
1093 |
|
1094 |
outputs = self.model(
|
1095 |
+
input_ids=input_ids,
|
1096 |
attention_mask=attention_mask,
|
1097 |
sliding_window_mask=sliding_window_mask,
|
1098 |
position_ids=position_ids,
|
1099 |
+
inputs_embeds=inputs_embeds,
|
1100 |
indices=indices,
|
1101 |
cu_seqlens=cu_seqlens,
|
1102 |
max_seqlen=max_seqlen,
|
|
|
1169 |
)
|
1170 |
def forward(
|
1171 |
self,
|
1172 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1173 |
attention_mask: Optional[torch.Tensor] = None,
|
1174 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1175 |
position_ids: Optional[torch.Tensor] = None,
|
1176 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1177 |
labels: Optional[torch.Tensor] = None,
|
1178 |
indices: Optional[torch.Tensor] = None,
|
1179 |
cu_seqlens: Optional[torch.Tensor] = None,
|
|
|
1195 |
self._maybe_set_compile()
|
1196 |
|
1197 |
outputs = self.model(
|
1198 |
+
input_ids=input_ids,
|
1199 |
attention_mask=attention_mask,
|
1200 |
sliding_window_mask=sliding_window_mask,
|
1201 |
position_ids=position_ids,
|
1202 |
+
inputs_embeds=inputs_embeds,
|
1203 |
indices=indices,
|
1204 |
cu_seqlens=cu_seqlens,
|
1205 |
max_seqlen=max_seqlen,
|
|
|
1282 |
)
|
1283 |
def forward(
|
1284 |
self,
|
1285 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1286 |
attention_mask: Optional[torch.Tensor] = None,
|
1287 |
sliding_window_mask: Optional[torch.Tensor] = None,
|
1288 |
position_ids: Optional[torch.Tensor] = None,
|
1289 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1290 |
labels: Optional[torch.Tensor] = None,
|
1291 |
indices: Optional[torch.Tensor] = None,
|
1292 |
cu_seqlens: Optional[torch.Tensor] = None,
|
|
|
1305 |
self._maybe_set_compile()
|
1306 |
|
1307 |
outputs = self.model(
|
1308 |
+
input_ids=input_ids,
|
1309 |
attention_mask=attention_mask,
|
1310 |
sliding_window_mask=sliding_window_mask,
|
1311 |
position_ids=position_ids,
|
1312 |
+
inputs_embeds=inputs_embeds,
|
1313 |
indices=indices,
|
1314 |
cu_seqlens=cu_seqlens,
|
1315 |
max_seqlen=max_seqlen,
|