yangwang825
commited on
Upload model
Browse files- config.json +2 -1
- model.safetensors +2 -2
- modeling_pure_bert.py +176 -171
config.json
CHANGED
@@ -6,7 +6,8 @@
|
|
6 |
],
|
7 |
"attention_probs_dropout_prob": 0.1,
|
8 |
"auto_map": {
|
9 |
-
"AutoConfig": "configuration_pure_bert.PureBertConfig"
|
|
|
10 |
},
|
11 |
"center": false,
|
12 |
"classifier_dropout": null,
|
|
|
6 |
],
|
7 |
"attention_probs_dropout_prob": 0.1,
|
8 |
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_pure_bert.PureBertConfig",
|
10 |
+
"AutoModel": "modeling_pure_bert.BertModel"
|
11 |
},
|
12 |
"center": false,
|
13 |
"classifier_dropout": null,
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae797db87034daa398c5027948c4c3e0a1ce3e8a0a4e1c2743cfad3122603c18
|
3 |
+
size 437951328
|
modeling_pure_bert.py
CHANGED
@@ -2,22 +2,25 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import numpy as np
|
4 |
from torch.autograd import Function
|
5 |
-
from transformers import
|
6 |
-
|
7 |
-
|
8 |
)
|
9 |
-
from typing import Union, Tuple, Optional
|
10 |
from transformers.modeling_outputs import (
|
11 |
SequenceClassifierOutput,
|
12 |
MultipleChoiceModelOutput,
|
13 |
-
QuestionAnsweringModelOutput
|
|
|
|
|
|
|
|
|
|
|
14 |
)
|
15 |
from transformers.utils import ModelOutput
|
16 |
|
17 |
from .configuration_pure_bert import PureBertConfig
|
18 |
|
19 |
-
PureBertModel = BertModel
|
20 |
-
|
21 |
|
22 |
class CovarianceFunction(Function):
|
23 |
|
@@ -302,23 +305,50 @@ class PureBertPreTrainedModel(PreTrainedModel):
|
|
302 |
module.weight.data.fill_(1.0)
|
303 |
|
304 |
|
305 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
def __init__(self, config, add_pooling_layer=True):
|
308 |
super().__init__(config)
|
309 |
-
self.num_labels = config.num_labels
|
310 |
self.config = config
|
311 |
|
312 |
-
self.
|
313 |
-
|
314 |
-
|
315 |
-
)
|
316 |
-
|
317 |
-
self.
|
|
|
318 |
|
319 |
# Initialize weights and apply final processing
|
320 |
self.post_init()
|
321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
def forward(
|
323 |
self,
|
324 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -327,183 +357,158 @@ class BertClsForSequenceClassification(PureBertPreTrainedModel):
|
|
327 |
position_ids: Optional[torch.Tensor] = None,
|
328 |
head_mask: Optional[torch.Tensor] = None,
|
329 |
inputs_embeds: Optional[torch.Tensor] = None,
|
330 |
-
|
|
|
|
|
|
|
331 |
output_attentions: Optional[bool] = None,
|
332 |
output_hidden_states: Optional[bool] = None,
|
333 |
return_dict: Optional[bool] = None,
|
334 |
-
) -> Union[Tuple[torch.Tensor],
|
335 |
r"""
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
"""
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
input_ids,
|
345 |
-
attention_mask=attention_mask,
|
346 |
-
token_type_ids=token_type_ids,
|
347 |
-
position_ids=position_ids,
|
348 |
-
head_mask=head_mask,
|
349 |
-
inputs_embeds=inputs_embeds,
|
350 |
-
output_attentions=output_attentions,
|
351 |
-
output_hidden_states=output_hidden_states,
|
352 |
-
return_dict=return_dict,
|
353 |
)
|
|
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
359 |
-
|
360 |
-
|
361 |
|
362 |
-
|
363 |
-
if
|
364 |
-
if self.config.problem_type is None:
|
365 |
-
if self.num_labels == 1:
|
366 |
-
self.config.problem_type = "regression"
|
367 |
-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
368 |
-
self.config.problem_type = "single_label_classification"
|
369 |
-
else:
|
370 |
-
self.config.problem_type = "multi_label_classification"
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
loss_fct = nn.CrossEntropyLoss()
|
380 |
-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
381 |
-
elif self.config.problem_type == "multi_label_classification":
|
382 |
-
loss_fct = nn.BCEWithLogitsLoss()
|
383 |
-
loss = loss_fct(logits, labels)
|
384 |
-
if not return_dict:
|
385 |
-
output = (logits,) + outputs[2:]
|
386 |
-
return ((loss,) + output) if loss is not None else output
|
387 |
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
|
|
393 |
)
|
394 |
|
|
|
|
|
395 |
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
self.alpha = alpha
|
402 |
-
self.label_smoothing = label_smoothing
|
403 |
-
self.config = config
|
404 |
-
|
405 |
-
self.bert = PureBertModel(config)
|
406 |
-
classifier_dropout = (
|
407 |
-
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
408 |
)
|
409 |
-
self.dropout = nn.Dropout(classifier_dropout)
|
410 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
411 |
-
|
412 |
-
# Initialize weights and apply final processing
|
413 |
-
self.post_init()
|
414 |
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
else:
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
attention_mask=
|
453 |
-
token_type_ids=token_type_ids,
|
454 |
-
position_ids=position_ids,
|
455 |
head_mask=head_mask,
|
456 |
-
|
|
|
|
|
|
|
457 |
output_attentions=output_attentions,
|
458 |
output_hidden_states=output_hidden_states,
|
459 |
return_dict=return_dict,
|
460 |
)
|
|
|
|
|
461 |
|
462 |
-
if self.training:
|
463 |
-
mixed_embeddings, targets_a, targets_b, lam = self.mixup_data(outputs.pooler_output, labels, self.alpha)
|
464 |
-
mixed_embeddings = self.dropout(mixed_embeddings)
|
465 |
-
logits = self.classifier(mixed_embeddings)
|
466 |
-
else:
|
467 |
-
pooler_output = self.dropout(outputs.pooler_output)
|
468 |
-
logits = self.classifier(pooler_output)
|
469 |
-
|
470 |
-
loss = None
|
471 |
-
if labels is not None:
|
472 |
-
if self.config.problem_type is None:
|
473 |
-
if self.num_labels == 1:
|
474 |
-
self.config.problem_type = "regression"
|
475 |
-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
476 |
-
self.config.problem_type = "single_label_classification"
|
477 |
-
else:
|
478 |
-
self.config.problem_type = "multi_label_classification"
|
479 |
-
|
480 |
-
if self.config.problem_type == "regression":
|
481 |
-
loss_fct = nn.MSELoss()
|
482 |
-
if self.num_labels == 1:
|
483 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
484 |
-
else:
|
485 |
-
loss = loss_fct(logits, labels)
|
486 |
-
elif self.config.problem_type == "single_label_classification":
|
487 |
-
loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
|
488 |
-
logits = logits.view(-1, self.num_labels)
|
489 |
-
if self.training:
|
490 |
-
targets_a = targets_a.view(-1)
|
491 |
-
targets_b = targets_b.view(-1)
|
492 |
-
loss = lam * loss_fct(logits, targets_a) + (1 - lam) * loss_fct(logits, targets_b)
|
493 |
-
else:
|
494 |
-
loss = loss_fct(logits, labels.view(-1))
|
495 |
-
elif self.config.problem_type == "multi_label_classification":
|
496 |
-
loss_fct = nn.BCEWithLogitsLoss()
|
497 |
-
loss = loss_fct(logits, labels)
|
498 |
if not return_dict:
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
hidden_states=
|
506 |
-
attentions=
|
|
|
507 |
)
|
508 |
|
509 |
|
@@ -519,7 +524,7 @@ class PureBertForSequenceClassification(PureBertPreTrainedModel):
|
|
519 |
self.num_labels = config.num_labels
|
520 |
self.config = config
|
521 |
|
522 |
-
self.bert =
|
523 |
classifier_dropout = (
|
524 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
525 |
)
|
@@ -664,7 +669,7 @@ class PureBertForMultipleChoice(PureBertPreTrainedModel):
|
|
664 |
super().__init__(config)
|
665 |
self.label_smoothing = label_smoothing
|
666 |
|
667 |
-
self.bert =
|
668 |
self.pure = PURE(
|
669 |
in_dim=config.hidden_size,
|
670 |
svd_rank=config.svd_rank,
|
@@ -766,7 +771,7 @@ class PureBertForQuestionAnswering(PureBertPreTrainedModel):
|
|
766 |
self.num_labels = config.num_labels
|
767 |
self.label_smoothing = label_smoothing
|
768 |
|
769 |
-
self.bert =
|
770 |
self.pure = PURE(
|
771 |
in_dim=config.hidden_size,
|
772 |
svd_rank=config.svd_rank,
|
|
|
2 |
import torch.nn as nn
|
3 |
import numpy as np
|
4 |
from torch.autograd import Function
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from transformers.models.bert.modeling_bert import (
|
7 |
+
BertEmbeddings, BertEncoder, BertPooler
|
8 |
)
|
9 |
+
from typing import Union, Tuple, Optional, List
|
10 |
from transformers.modeling_outputs import (
|
11 |
SequenceClassifierOutput,
|
12 |
MultipleChoiceModelOutput,
|
13 |
+
QuestionAnsweringModelOutput,
|
14 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
15 |
+
)
|
16 |
+
from transformers.modeling_attn_mask_utils import (
|
17 |
+
_prepare_4d_attention_mask_for_sdpa,
|
18 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
19 |
)
|
20 |
from transformers.utils import ModelOutput
|
21 |
|
22 |
from .configuration_pure_bert import PureBertConfig
|
23 |
|
|
|
|
|
24 |
|
25 |
class CovarianceFunction(Function):
|
26 |
|
|
|
305 |
module.weight.data.fill_(1.0)
|
306 |
|
307 |
|
308 |
+
class BertModel(PureBertPreTrainedModel):
|
309 |
+
"""
|
310 |
+
|
311 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
312 |
+
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
313 |
+
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
314 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
315 |
+
|
316 |
+
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
317 |
+
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
318 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
319 |
+
"""
|
320 |
+
|
321 |
+
_no_split_modules = ["BertEmbeddings", "BertLayer"]
|
322 |
|
323 |
def __init__(self, config, add_pooling_layer=True):
|
324 |
super().__init__(config)
|
|
|
325 |
self.config = config
|
326 |
|
327 |
+
self.embeddings = BertEmbeddings(config)
|
328 |
+
self.encoder = BertEncoder(config)
|
329 |
+
|
330 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
331 |
+
|
332 |
+
self.attn_implementation = config._attn_implementation
|
333 |
+
self.position_embedding_type = config.position_embedding_type
|
334 |
|
335 |
# Initialize weights and apply final processing
|
336 |
self.post_init()
|
337 |
|
338 |
+
def get_input_embeddings(self):
|
339 |
+
return self.embeddings.word_embeddings
|
340 |
+
|
341 |
+
def set_input_embeddings(self, value):
|
342 |
+
self.embeddings.word_embeddings = value
|
343 |
+
|
344 |
+
def _prune_heads(self, heads_to_prune):
|
345 |
+
"""
|
346 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
347 |
+
class PreTrainedModel
|
348 |
+
"""
|
349 |
+
for layer, heads in heads_to_prune.items():
|
350 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
351 |
+
|
352 |
def forward(
|
353 |
self,
|
354 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
357 |
position_ids: Optional[torch.Tensor] = None,
|
358 |
head_mask: Optional[torch.Tensor] = None,
|
359 |
inputs_embeds: Optional[torch.Tensor] = None,
|
360 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
361 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
362 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
363 |
+
use_cache: Optional[bool] = None,
|
364 |
output_attentions: Optional[bool] = None,
|
365 |
output_hidden_states: Optional[bool] = None,
|
366 |
return_dict: Optional[bool] = None,
|
367 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
368 |
r"""
|
369 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
370 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
371 |
+
the model is configured as a decoder.
|
372 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
|
373 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
374 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
375 |
+
|
376 |
+
- 1 for tokens that are **not masked**,
|
377 |
+
- 0 for tokens that are **masked**.
|
378 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
379 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
380 |
+
|
381 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
382 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
383 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
384 |
+
use_cache (`bool`, *optional*):
|
385 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
386 |
+
`past_key_values`).
|
387 |
"""
|
388 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
389 |
+
output_hidden_states = (
|
390 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
)
|
392 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
393 |
|
394 |
+
if self.config.is_decoder:
|
395 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
396 |
+
else:
|
397 |
+
use_cache = False
|
398 |
+
|
399 |
+
if input_ids is not None and inputs_embeds is not None:
|
400 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
401 |
+
elif input_ids is not None:
|
402 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
403 |
+
input_shape = input_ids.size()
|
404 |
+
elif inputs_embeds is not None:
|
405 |
+
input_shape = inputs_embeds.size()[:-1]
|
406 |
+
else:
|
407 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
408 |
|
409 |
+
batch_size, seq_length = input_shape
|
410 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
411 |
|
412 |
+
# past_key_values_length
|
413 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
|
415 |
+
if token_type_ids is None:
|
416 |
+
if hasattr(self.embeddings, "token_type_ids"):
|
417 |
+
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
418 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
419 |
+
token_type_ids = buffered_token_type_ids_expanded
|
420 |
+
else:
|
421 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
|
423 |
+
embedding_output = self.embeddings(
|
424 |
+
input_ids=input_ids,
|
425 |
+
position_ids=position_ids,
|
426 |
+
token_type_ids=token_type_ids,
|
427 |
+
inputs_embeds=inputs_embeds,
|
428 |
+
past_key_values_length=past_key_values_length,
|
429 |
)
|
430 |
|
431 |
+
if attention_mask is None:
|
432 |
+
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
|
433 |
|
434 |
+
use_sdpa_attention_masks = (
|
435 |
+
self.attn_implementation == "sdpa"
|
436 |
+
and self.position_embedding_type == "absolute"
|
437 |
+
and head_mask is None
|
438 |
+
and not output_attentions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
)
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
+
# Expand the attention mask
|
442 |
+
if use_sdpa_attention_masks and attention_mask.dim() == 2:
|
443 |
+
# Expand the attention mask for SDPA.
|
444 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
445 |
+
if self.config.is_decoder:
|
446 |
+
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
447 |
+
attention_mask,
|
448 |
+
input_shape,
|
449 |
+
embedding_output,
|
450 |
+
past_key_values_length,
|
451 |
+
)
|
452 |
+
else:
|
453 |
+
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
454 |
+
attention_mask, embedding_output.dtype, tgt_len=seq_length
|
455 |
+
)
|
456 |
else:
|
457 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
458 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
459 |
+
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
460 |
+
|
461 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
462 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
463 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
464 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
465 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
466 |
+
if encoder_attention_mask is None:
|
467 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
468 |
+
|
469 |
+
if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
|
470 |
+
# Expand the attention mask for SDPA.
|
471 |
+
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
|
472 |
+
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
473 |
+
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
477 |
+
else:
|
478 |
+
encoder_extended_attention_mask = None
|
479 |
+
|
480 |
+
# Prepare head mask if needed
|
481 |
+
# 1.0 in head_mask indicate we keep the head
|
482 |
+
# attention_probs has shape bsz x n_heads x N x N
|
483 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
484 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
485 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
486 |
+
|
487 |
+
encoder_outputs = self.encoder(
|
488 |
+
embedding_output,
|
489 |
+
attention_mask=extended_attention_mask,
|
|
|
|
|
490 |
head_mask=head_mask,
|
491 |
+
encoder_hidden_states=encoder_hidden_states,
|
492 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
493 |
+
past_key_values=past_key_values,
|
494 |
+
use_cache=use_cache,
|
495 |
output_attentions=output_attentions,
|
496 |
output_hidden_states=output_hidden_states,
|
497 |
return_dict=return_dict,
|
498 |
)
|
499 |
+
sequence_output = encoder_outputs[0]
|
500 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
502 |
if not return_dict:
|
503 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
504 |
+
|
505 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
506 |
+
last_hidden_state=sequence_output,
|
507 |
+
pooler_output=pooled_output,
|
508 |
+
past_key_values=encoder_outputs.past_key_values,
|
509 |
+
hidden_states=encoder_outputs.hidden_states,
|
510 |
+
attentions=encoder_outputs.attentions,
|
511 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
512 |
)
|
513 |
|
514 |
|
|
|
524 |
self.num_labels = config.num_labels
|
525 |
self.config = config
|
526 |
|
527 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
528 |
classifier_dropout = (
|
529 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
530 |
)
|
|
|
669 |
super().__init__(config)
|
670 |
self.label_smoothing = label_smoothing
|
671 |
|
672 |
+
self.bert = BertModel(config)
|
673 |
self.pure = PURE(
|
674 |
in_dim=config.hidden_size,
|
675 |
svd_rank=config.svd_rank,
|
|
|
771 |
self.num_labels = config.num_labels
|
772 |
self.label_smoothing = label_smoothing
|
773 |
|
774 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
775 |
self.pure = PURE(
|
776 |
in_dim=config.hidden_size,
|
777 |
svd_rank=config.svd_rank,
|