stefan-it commited on
Commit
7b7eb08
·
1 Parent(s): 83fd560

modeling: sync xLSTMForSequenceClassification with Patrick's codebase from https://github.com/HallerPatrick/helibrunna/blob/a1b377271867d5f23201ccacb55e017749aba487/model/modeling_xlstm.py

Browse files
Files changed (1) hide show
  1. modeling_xlstm.py +83 -1
modeling_xlstm.py CHANGED
@@ -2,8 +2,9 @@ from typing import Optional, Sequence, Tuple, Union
2
 
3
  import torch
4
  from torch import nn
 
5
  from transformers import PreTrainedModel
6
- from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
7
  from xlstm.components.init import small_init_init_
8
  from xlstm.utils import WeightDecayOptimGroupMixin
9
  from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
@@ -212,3 +213,84 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
212
  "input_ids": input_ids.to(self.device),
213
  }
214
  return model_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
  from torch import nn
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
  from transformers import PreTrainedModel
7
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
8
  from xlstm.components.init import small_init_init_
9
  from xlstm.utils import WeightDecayOptimGroupMixin
10
  from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
 
213
  "input_ids": input_ids.to(self.device),
214
  }
215
  return model_inputs
216
+
217
+
218
+ class xLSTMForSequenceClassification(xLSTMPreTrainedModel):
219
+
220
+ def __init__(self, config: xLSTMConfig, **kwargs):
221
+ super().__init__(config)
222
+ self.num_labels = config.num_labels
223
+ self.config = config
224
+ self.model = xLSTMModel(config)
225
+ self.classifier = nn.Linear(config.embedding_dim, config.num_labels, bias=False)
226
+
227
+ self.init_weights()
228
+
229
+ def forward(
230
+ self,
231
+ input_ids: torch.Tensor,
232
+ labels: Optional[torch.LongTensor] = None,
233
+ output_hidden_states: Optional[bool] = None,
234
+ return_dict: Optional[bool] = None,
235
+ ):
236
+ output = self.model(
237
+ input_ids,
238
+ output_hidden_states=output_hidden_states,
239
+ )
240
+
241
+ hidden_state = output[0]
242
+
243
+ logits = self.classifier(hidden_state)
244
+ batch_size = input_ids.shape[0]
245
+
246
+ if self.config.pad_token_id is None and batch_size != 1:
247
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
248
+ if self.config.pad_token_id is None:
249
+ sequence_lengths = -1
250
+ else:
251
+ if input_ids is not None:
252
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
253
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
254
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
255
+ sequence_lengths = sequence_lengths.to(logits.device)
256
+ else:
257
+ sequence_lengths = -1
258
+
259
+
260
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
261
+
262
+ loss = None
263
+
264
+ if labels is not None:
265
+ labels = labels.to(logits.device)
266
+ if self.config.problem_type is None:
267
+ if self.num_labels == 1:
268
+ self.config.problem_type = "regression"
269
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
270
+ self.config.problem_type = "single_label_classification"
271
+ else:
272
+ self.config.problem_type = "multi_label_classification"
273
+
274
+ if self.config.problem_type == "regression":
275
+ loss_fct = MSELoss()
276
+ if self.num_labels == 1:
277
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
278
+ else:
279
+ loss = loss_fct(pooled_logits, labels)
280
+ elif self.config.problem_type == "single_label_classification":
281
+ loss_fct = CrossEntropyLoss()
282
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
283
+ elif self.config.problem_type == "multi_label_classification":
284
+ loss_fct = BCEWithLogitsLoss()
285
+ loss = loss_fct(pooled_logits, labels)
286
+
287
+ if not return_dict:
288
+ output = (pooled_logits,) + output[1:]
289
+ return ((loss,) + output) if loss is not None else output
290
+
291
+
292
+ return SequenceClassifierOutputWithPast(
293
+ loss=loss,
294
+ logits=pooled_logits,
295
+ hidden_states=output.hidden_states,
296
+ )