yangwang825 committed
Commit f7c9580 · verified · 1 Parent(s): 4bb0200

Upload PureQwen2ForSequenceClassification

Files changed (3):
  1. config.json +5 -3
  2. model.safetensors +3 -0
  3. modeling_pure_qwen2.py +605 -0
config.json CHANGED
@@ -1,11 +1,13 @@
 {
+  "_name_or_path": "Qwen/Qwen2.5-0.5B",
   "alpha": 1,
   "architectures": [
-    "Qwen2ForCausalLM"
+    "PureQwen2ForSequenceClassification"
   ],
   "attention_dropout": 0.0,
   "auto_map": {
-    "AutoConfig": "configuration_pure_qwen2.PureQwen2Config"
+    "AutoConfig": "configuration_pure_qwen2.PureQwen2Config",
+    "AutoModelForSequenceClassification": "modeling_pure_qwen2.PureQwen2ForSequenceClassification"
   },
   "bos_token_id": 151643,
   "center": false,
@@ -31,7 +33,7 @@
   "sliding_window": null,
   "svd_rank": 5,
   "tie_word_embeddings": true,
-  "torch_dtype": "bfloat16",
+  "torch_dtype": "float32",
   "transformers_version": "4.46.2",
   "use_cache": true,
   "use_mrope": false,
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baab20dc9577db688c4ada7822066b4be311fb9d765f1f9c4cfee7c9df6f8486
+size 1976170728
modeling_pure_qwen2.py ADDED
@@ -0,0 +1,605 @@
from typing import Union, Tuple, Optional, List

import torch
import torch.nn as nn
from torch.autograd import Function
from transformers import PreTrainedModel
# Qwen2Model is imported so its causal-mask helpers can be reused below.
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer, Qwen2Model, Qwen2RMSNorm, Qwen2RotaryEmbedding
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    BaseModelOutputWithPast
)
from transformers.utils import ModelOutput, logging

from .configuration_pure_qwen2 import PureQwen2Config

logger = logging.get_logger(__name__)


class CovarianceFunction(Function):

    @staticmethod
    def forward(ctx, inputs):
        x = inputs
        b, c, h, w = x.shape
        m = h * w
        x = x.view(b, c, m)
        # Centering matrix I_hat = (1/m) * (I - (1/m) * 11^T), so that
        # x @ I_hat @ x^T is the (biased) covariance over the m samples.
        I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
            1.0 / m
        ) * torch.eye(m, m, device=x.device)
        I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
        y = x @ I_hat @ x.transpose(-1, -2)
        ctx.save_for_backward(inputs, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        inputs, I_hat = ctx.saved_tensors
        x = inputs
        b, c, h, w = x.shape
        m = h * w
        x = x.view(b, c, m)
        # d/dx of (x I x^T) with symmetric I: (G + G^T) x I
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input @ x @ I_hat
        grad_input = grad_input.reshape(b, c, h, w)
        return grad_input


class Covariance(nn.Module):

    def __init__(self):
        super(Covariance, self).__init__()

    def _covariance(self, x):
        return CovarianceFunction.apply(x)

    def forward(self, x):
        # x: [seq_len, embed_dim] for a single unpadded sequence; the
        # transpose turns it into [embed_dim, seq_len], so the result is a
        # feature-by-feature covariance matrix.
        if x.dim() == 2:
            x = x.transpose(-1, -2)
        C = self._covariance(x[None, :, :, None])
        C = C.squeeze(dim=0)
        return C
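

# Shape sketch (illustrative only, not executed on import): for a single
# sequence of 10 tokens with hidden size 896,
#   x = torch.randn(10, 896)      # [T, F]
#   C = Covariance()(x)           # [F, F] covariance across the T tokens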


class PFSA(torch.nn.Module):
    """
    Parameter-free self-attention; see https://openreview.net/pdf?id=isodM5jTA7h
    """
    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        # x: [1, T, F] -> [1, F, T, 1]
        x = x.transpose(1, 2)[..., None]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)  # [B, F, 1, 1]
        kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        # Correlation (over the feature axis) between every token and the
        # mean token; highly correlated tokens are attenuated.
        C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        out = out.squeeze(dim=-1).transpose(1, 2)
        return out

    def forward(self, input_values, attention_mask=None):
        """
        input_values: [B, T, F]
        """
        b, t, f = input_values.shape
        if attention_mask is None:
            # Without a mask, treat every position as valid.
            attention_mask = torch.ones(b, t, dtype=torch.long, device=input_values.device)
        out = []
        for x, mask in zip(input_values, attention_mask):
            x = x.view(1, t, f)
            # Only the unpadded prefix goes through the attention; the padded
            # tail stays zero.
            x_in = x[:, :int(mask.sum().item()), :]
            x_out = self.forward_one_sample(x_in)
            x_expanded = torch.zeros_like(x, device=x.device)
            x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
            out.append(x_expanded)
        out = torch.vstack(out)
        out = out.view(b, t, f)
        return out
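

# Sketch of the PFSA re-weighting (illustrative shapes, not executed here):
#   pfsa = PFSA(input_dim=896, alpha=1)
#   x = torch.randn(2, 16, 896)                 # [B, T, F]
#   mask = torch.ones(2, 16, dtype=torch.long)  # all positions valid
#   y = pfsa(x, mask)                           # [B, T, F]; each token scaled
#   by (1 - sigmoid(corr(token, mean token)))**alpha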


class PURE(torch.nn.Module):

    def __init__(
        self,
        in_dim,
        svd_rank=16,
        num_pc_to_remove=1,
        center=False,
        num_iters=2,
        alpha=1,
        disable_pcr=False,
        disable_pfsa=False,
        disable_covariance=True,
        *args, **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.svd_rank = svd_rank
        self.num_pc_to_remove = num_pc_to_remove
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = not disable_pcr
        self.do_pfsa = not disable_pfsa
        self.do_covariance = not disable_covariance
        self.attention = PFSA(in_dim, alpha=alpha)

    def _compute_pc(self, X, attention_mask):
        """
        X: (B, T, F). For each sample, returns the top `num_pc_to_remove`
        principal directions of the unpadded token matrix as a [K, F] tensor.
        """
        pcs = []
        bs, seqlen, dim = X.shape
        for x, mask in zip(X, attention_mask):
            rank = int(mask.sum().item())
            x = x[:rank, :]
            if self.do_covariance:
                x = Covariance()(x)
                q = self.svd_rank
            else:
                # torch.pca_lowrank requires q <= min(num_rows, num_cols).
                q = min(self.svd_rank, rank)
            _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
            pc = V.transpose(0, 1)[:self.num_pc_to_remove, :]  # pc: [K, F]
            pcs.append(pc)
        return pcs

    def _remove_pc(self, X, pcs):
        """
        X: [B, T, F]; pcs: per-sample [K, F]. Projects each token embedding
        onto the orthogonal complement of the top principal directions.
        """
        b, t, f = X.shape
        out = []
        for x, pc in zip(X, pcs):
            v = x - x @ pc.transpose(0, 1) @ pc
            out.append(v[None, ...])
        out = torch.vstack(out)
        return out

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """
        PCR -> Attention
        input_values: (B, T, F)
        """
        x = input_values
        if attention_mask is None:
            attention_mask = torch.ones(x.shape[:2], dtype=torch.long, device=x.device)
        if self.do_pcr:
            pc = self._compute_pc(x, attention_mask)  # list of [K, F]
            xx = self._remove_pc(x, pc)
        else:
            xx = x
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx
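

# The PURE block chains the two steps above. A sketch, assuming the values
# recorded in config.json (svd_rank=5, num_pc_to_remove=1):
#   pure = PURE(in_dim=896, svd_rank=5, num_pc_to_remove=1)
#   x = torch.randn(2, 16, 896)                 # [B, T, F] token embeddings
#   mask = torch.ones(2, 16, dtype=torch.long)
#   y = pure(x, mask)                           # [B, T, F]: each sequence's
#   top principal component is projected out, then PFSA re-weights tokens.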


class StatisticsPooling(torch.nn.Module):

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both return_mean and return_std are False; "
                "enable at least one of the two statistics"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculates mean and/or std over the time axis of a batch.

        Arguments
        ---------
        input_values : torch.Tensor
            A [batch, time, features] mini-batch.
        attention_mask : torch.Tensor, optional
            Binary mask used to exclude padded time steps.
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            # Avoiding padded time steps
            lengths = torch.sum(attention_mask, dim=1)
            relative_lengths = lengths / torch.max(lengths)
            for snt_id in range(x.shape[0]):
                actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()

                # computing statistics
                if self.return_mean:
                    mean.append(
                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
                    )
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon-scaled Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            The size of the noise tensor to generate.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise
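

# Pooling sketch: with return_mean=True and return_std=False (as used by the
# classifier below), the module reduces [B, T, F] -> [B, 1, F]:
#   pool = StatisticsPooling(return_mean=True, return_std=False)
#   pooled = pool(torch.randn(2, 16, 896))      # [2, 1, 896]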


class PureQwen2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PureQwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class PureQwen2Model(PureQwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: PureQwen2Config
    """

    # `forward` below calls `_update_causal_mask`, which this file never
    # defined; borrow the helpers from the stock `Qwen2Model` (present in
    # transformers 4.46.x, the version recorded in config.json).
    _update_causal_mask = Qwen2Model._update_causal_mask
    _prepare_4d_causal_attention_mask_with_cache_position = staticmethod(
        Qwen2Model._prepare_4d_causal_attention_mask_with_cache_position
    )

    def __init__(self, config: PureQwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
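

# The classification head below composes the pieces defined above:
#   input_ids -> PureQwen2Model -> PURE (PCR + PFSA) -> StatisticsPooling
#   (mean only) -> nn.Linear score head -> logits / loss.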


class PureQwen2ForSequenceClassification(PureQwen2PreTrainedModel):

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.label_smoothing = label_smoothing
        self.num_labels = config.num_labels
        self.config = config

        self.model = PureQwen2Model(config)
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward_pure_embeddings(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutput]:
        r"""
        Runs the backbone and the PURE block and returns the processed token
        embeddings without pooling or classification. `labels` is accepted
        for signature compatibility but unused.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = transformer_outputs[0]

        token_embeddings = self.pure(token_embeddings, attention_mask)

        return ModelOutput(
            last_hidden_state=token_embeddings,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = outputs[0]

        # PCR + PFSA, then mean pooling over time, then the linear score head.
        token_embeddings = self.pure(token_embeddings, attention_mask)
        pooled_output = self.mean(token_embeddings).squeeze(1)
        logits = self.score(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
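
Besides the standard classification forward, the uploaded class exposes forward_pure_embeddings, which stops after the PURE stage and returns token embeddings. A sketch of using it for embedding-style outputs (the repository id is again a placeholder):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "..."  # the Hub repository this commit belongs to

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)

batch = tokenizer(["an example sentence"], return_tensors="pt")
with torch.no_grad():
    out = model.forward_pure_embeddings(**batch)
token_embeddings = out["last_hidden_state"]  # [batch, seq_len, hidden_size] after PCR + PFSA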