KoichiYasuoka committed
Commit 92bfe35 · 1 Parent(s): 04b3693

support inputs_embeds
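
A minimal usage sketch of the new code path (the checkpoint id below is a placeholder; it assumes a repo that ships this modeling_modernbert.py and is loaded with `trust_remote_code=True`):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id -- substitute a ModernBERT checkpoint that uses this modeling file.
model_name = "your-org/your-modernbert-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

enc = tokenizer("support inputs_embeds", return_tensors="pt")

# Existing path: token ids are embedded inside the model.
out_ids = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

# New path: embed the tokens yourself and pass the vectors directly.
embeds = model.get_input_embeddings()(enc["input_ids"])
out_emb = model(inputs_embeds=embeds, attention_mask=enc["attention_mask"])

# With dropout disabled (eval mode), the two paths should yield the same hidden states.
print(torch.allclose(out_ids.last_hidden_state, out_emb.last_hidden_state, atol=1e-5))
```

Passing both `input_ids` and `inputs_embeds` (or neither) raises `ValueError("You must specify exactly one of input_ids or inputs_embeds")`, mirroring the check added in `ModernBertModel.forward`.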

Files changed (1):
  1. modeling_modernbert.py +68 -29
modeling_modernbert.py CHANGED
@@ -206,12 +206,17 @@ class ModernBertEmbeddings(nn.Module):
     def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
         return self.drop(self.norm(self.tok_embeddings(input_ids)))

-    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
-        hidden_states = (
-            self.compiled_embeddings(input_ids)
-            if self.config.reference_compile
-            else self.drop(self.norm(self.tok_embeddings(input_ids)))
-        )
+    def forward(
+        self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = self.drop(self.norm(inputs_embeds))
+        else:
+            hidden_states = (
+                self.compiled_embeddings(input_ids)
+                if self.config.reference_compile
+                else self.drop(self.norm(self.tok_embeddings(input_ids)))
+            )
         return hidden_states


@@ -777,9 +782,6 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.

-            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
-            `past_key_values`).
-
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
             information on the default strategy.
@@ -795,6 +797,10 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
             config.n_positions - 1]`.

             [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
         indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
             Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
         cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
@@ -846,10 +852,11 @@ class ModernBertModel(ModernBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
@@ -865,35 +872,49 @@ class ModernBertModel(ModernBertPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

         self._maybe_set_compile()
-        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+
+        if input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

         if batch_size is None and seq_len is None:
-            batch_size, seq_len = input_ids.shape[:2]
+            if inputs_embeds is not None:
+                batch_size, seq_len = inputs_embeds.shape[:2]
+            else:
+                batch_size, seq_len = input_ids.shape[:2]
+        device = input_ids.device if input_ids is not None else inputs_embeds.device

         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

         repad = False
         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
                 repad = True
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask
                     )
         else:
             if position_ids is None:
-                position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
+                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

             attention_mask, sliding_window_mask = self._update_attention_mask(
                 attention_mask, output_attentions=output_attentions
             )

-        hidden_states = self.embeddings(input_ids)
+        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

         for encoder_layer in self.layers:
             if output_hidden_states:
@@ -1029,10 +1050,11 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1049,19 +1071,32 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):

         if self.config._attn_implementation == "flash_attention_2":
             if indices is None and cu_seqlens is None and max_seqlen is None:
-                batch_size, seq_len = input_ids.shape[:2]
+                if batch_size is None and seq_len is None:
+                    if inputs_embeds is not None:
+                        batch_size, seq_len = inputs_embeds.shape[:2]
+                    else:
+                        batch_size, seq_len = input_ids.shape[:2]
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+
                 if attention_mask is None:
-                    attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-                with torch.no_grad():
-                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
-                        inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
+
+                if inputs_embeds is None:
+                    with torch.no_grad():
+                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
+                        )
+                else:
+                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                     )

         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1134,10 +1169,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1159,10 +1195,11 @@ class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
         self._maybe_set_compile()

         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
@@ -1245,10 +1282,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         sliding_window_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         indices: Optional[torch.Tensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
@@ -1267,10 +1305,11 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         self._maybe_set_compile()

         outputs = self.model(
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             sliding_window_mask=sliding_window_mask,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
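
The same keyword is wired through the task heads changed above. A sketch of the masked-LM path with externally built embeddings (again with a placeholder checkpoint id; the [MASK] handling assumes the tokenizer defines a mask token, as ModernBERT tokenizers do):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Placeholder repo id -- assumes a checkpoint that bundles this modeling_modernbert.py.
model_name = "your-org/your-modernbert-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.eval()

enc = tokenizer("The capital of France is [MASK].", return_tensors="pt")

# Build the embeddings by hand (e.g. to perturb or mix them) and skip input_ids entirely.
inputs_embeds = model.get_input_embeddings()(enc["input_ids"])
with torch.no_grad():
    logits = model(inputs_embeds=inputs_embeds, attention_mask=enc["attention_mask"]).logits

# Read off the prediction at the [MASK] position.
mask_pos = (enc["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))
```

Under `flash_attention_2`, the unpadding branch now routes `inputs_embeds` through `_unpad_modernbert_input` just as it does for `input_ids`, so no extra handling is needed on the caller's side.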