yangwang825 committed (verified)
Commit c7e1ee9 · Parent(s): a1ef643

Upload model

Files changed (3):
  1. config.json +2 -1
  2. model.safetensors +2 -2
  3. modeling_pure_bert.py +176 -171
config.json CHANGED
@@ -6,7 +6,8 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_pure_bert.PureBertConfig"
+    "AutoConfig": "configuration_pure_bert.PureBertConfig",
+    "AutoModel": "modeling_pure_bert.BertModel"
   },
   "center": false,
   "classifier_dropout": null,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f5af4e46d93e437b595b0320570f60c9d593955977b465c9e1fca83129f2308d
-size 435593012
+oid sha256:ae797db87034daa398c5027948c4c3e0a1ce3e8a0a4e1c2743cfad3122603c18
+size 437951328
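
The LFS pointer records the new object's sha256 oid and byte size. A small sketch for checking a locally downloaded model.safetensors against this pointer (the local path is an assumption, not part of the commit):

import hashlib
import os

path = "model.safetensors"  # assumed local path of the downloaded weights
expected_oid = "ae797db87034daa398c5027948c4c3e0a1ce3e8a0a4e1c2743cfad3122603c18"
expected_size = 437951328

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)

assert os.path.getsize(path) == expected_size, "size does not match the LFS pointer"
assert sha256.hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
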
modeling_pure_bert.py CHANGED
@@ -2,22 +2,25 @@ import torch
 import torch.nn as nn
 import numpy as np
 from torch.autograd import Function
-from transformers import (
-    BertModel,
-    PreTrainedModel,
+from transformers import PreTrainedModel
+from transformers.models.bert.modeling_bert import (
+    BertEmbeddings, BertEncoder, BertPooler
 )
-from typing import Union, Tuple, Optional
+from typing import Union, Tuple, Optional, List
 from transformers.modeling_outputs import (
     SequenceClassifierOutput,
     MultipleChoiceModelOutput,
-    QuestionAnsweringModelOutput
+    QuestionAnsweringModelOutput,
+    BaseModelOutputWithPoolingAndCrossAttentions
+)
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.utils import ModelOutput

 from .configuration_pure_bert import PureBertConfig

-PureBertModel = BertModel
-

 class CovarianceFunction(Function):

@@ -302,23 +305,50 @@ class PureBertPreTrainedModel(PreTrainedModel):
             module.weight.data.fill_(1.0)


-class BertClsForSequenceClassification(PureBertPreTrainedModel):
+class BertModel(PureBertPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    _no_split_modules = ["BertEmbeddings", "BertLayer"]

     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
-        self.num_labels = config.num_labels
         self.config = config

-        self.bert = PureBertModel(config, add_pooling_layer=add_pooling_layer)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-        )
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type

         # Initialize weights and apply final processing
         self.post_init()

+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -327,183 +357,158 @@ class BertClsForSequenceClassification(PureBertPreTrainedModel):
         position_ids: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.bert(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        pooled_output = outputs.pooler_output
-        if pooled_output is None:
-            pooled_output = outputs.last_hidden_state[:, 0, :]
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")

-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device

-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

-            if self.config.problem_type == "regression":
-                loss_fct = nn.MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = nn.BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
         )

+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

-class BertMixupForSequenceClassification(PureBertPreTrainedModel):
-
-    def __init__(self, config, alpha=1.0, label_smoothing=0.0):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.alpha = alpha
-        self.label_smoothing = label_smoothing
-        self.config = config
-
-        self.bert = PureBertModel(config)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        use_sdpa_attention_masks = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type == "absolute"
+            and head_mask is None
+            and not output_attentions
         )
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()

-    def mixup_data(self, embeddings, labels, alpha=1.0):
-        """Compute the mixup data. Returns mixed inputs, pairs of targets, and lambda"""
-        if alpha > 0:
-            lam = np.random.beta(alpha, alpha)
+        # Expand the attention mask
+        if use_sdpa_attention_masks and attention_mask.dim() == 2:
+            # Expand the attention mask for SDPA.
+            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+            if self.config.is_decoder:
+                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    input_shape,
+                    embedding_output,
+                    past_key_values_length,
+                )
+            else:
+                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
         else:
-            lam = 1
-
-        batch_size = embeddings.size()[0]
-        index = torch.randperm(batch_size).to(embeddings.device)
-
-        mixed_x = lam * embeddings + (1 - lam) * embeddings[index, :]
-        y_a, y_b = labels, labels[index]
-        return mixed_x, y_a, y_b, lam
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.bert(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+
+            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
+                # Expand the attention mask for SDPA.
+                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
             head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

-        if self.training:
-            mixed_embeddings, targets_a, targets_b, lam = self.mixup_data(outputs.pooler_output, labels, self.alpha)
-            mixed_embeddings = self.dropout(mixed_embeddings)
-            logits = self.classifier(mixed_embeddings)
-        else:
-            pooler_output = self.dropout(outputs.pooler_output)
-            logits = self.classifier(pooler_output)
-
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = nn.MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
-                logits = logits.view(-1, self.num_labels)
-                if self.training:
-                    targets_a = targets_a.view(-1)
-                    targets_b = targets_b.view(-1)
-                    loss = lam * loss_fct(logits, targets_a) + (1 - lam) * loss_fct(logits, targets_b)
-                else:
-                    loss = loss_fct(logits, labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = nn.BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )


@@ -519,7 +524,7 @@ class PureBertForSequenceClassification(PureBertPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config

-        self.bert = PureBertModel(config, add_pooling_layer=False)
+        self.bert = BertModel(config, add_pooling_layer=False)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -664,7 +669,7 @@ class PureBertForMultipleChoice(PureBertPreTrainedModel):
         super().__init__(config)
         self.label_smoothing = label_smoothing

-        self.bert = PureBertModel(config)
+        self.bert = BertModel(config)
         self.pure = PURE(
             in_dim=config.hidden_size,
             svd_rank=config.svd_rank,
@@ -766,7 +771,7 @@ class PureBertForQuestionAnswering(PureBertPreTrainedModel):
         self.num_labels = config.num_labels
         self.label_smoothing = label_smoothing

-        self.bert = PureBertModel(config, add_pooling_layer=False)
+        self.bert = BertModel(config, add_pooling_layer=False)
         self.pure = PURE(
             in_dim=config.hidden_size,
             svd_rank=config.svd_rank,
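
The commit replaces the `PureBertModel = BertModel` alias with a standalone BertModel built from BertEmbeddings, BertEncoder and an optional BertPooler, with SDPA-aware attention-mask handling, and points the PURE heads at it. Its forward pass returns a BaseModelOutputWithPoolingAndCrossAttentions. A minimal forward-pass sketch through the Auto classes (the repo id is a placeholder and the random token ids only illustrate output shapes):

import torch
from transformers import AutoModel

repo_id = "yangwang825/<repo-name>"  # placeholder, substitute the actual Hub model id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))  # dummy batch, 8 tokens
with torch.no_grad():
    outputs = model(input_ids=input_ids)

print(outputs.last_hidden_state.shape)  # (1, 8, hidden_size)
print(outputs.pooler_output.shape)      # (1, hidden_size); None when add_pooling_layer=False
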