Markus28 committed on
Commit
3cb3930
·
1 Parent(s): 9072f7f

feat: added separate BertForMaskedLM class

Browse files
Files changed (1) hide show
  1. modeling_bert.py +80 -0
modeling_bert.py CHANGED
@@ -689,4 +689,84 @@ class BertForPreTraining(BertPreTrainedModel):
689
  loss=total_loss,
690
  prediction_logits=prediction_scores,
691
  seq_relationship_logits=seq_relationship_score,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
  )
 
689
  loss=total_loss,
690
  prediction_logits=prediction_scores,
691
  seq_relationship_logits=seq_relationship_score,
692
+ )
693
+
694
+
695
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder plus pre-training heads, trained with the masked-LM loss only.

    The head module still produces a sequence-relationship score, which is
    returned in the output, but only the masked-LM term contributes to the loss.
    Label value 0 is the ignore index: only positions with ``labels > 0`` are
    treated as masked tokens.
    """

    def __init__(self, config: JinaBertConfig):
        """Build the encoder, the heads and the loss from ``config``.

        Optional config attributes read here (all default to ``False``):
        ``dense_seq_output``, ``last_layer_subset`` and ``use_xentropy``.

        Raises:
            AssertionError: if ``last_layer_subset`` is set without
                ``dense_seq_output``.
            ImportError: if ``use_xentropy`` is set but the fused
                ``CrossEntropyLoss`` is unavailable (the name is ``None``).
        """
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the
        # masked out tokens (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need to compute the last layer for a
        # subset of tokens (e.g., the tokens we need to compute the masked LM loss).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"

        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        # Label 0 marks positions that do not contribute to the loss.
        self.mlm_loss = loss_cls(ignore_index=0)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Share the word-embedding matrix with the prediction decoder."""
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def get_input_embeddings(self):
        """Return the encoder's word-embedding module."""
        return self.bert.embeddings.word_embeddings

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """Run the encoder and heads; compute the masked-LM loss if labels are given.

        Args:
            input_ids: token ids passed to the encoder.
            position_ids: optional, forwarded to the encoder unchanged.
            token_type_ids: optional, forwarded to the encoder unchanged.
            attention_mask: optional; converted with ``.bool()`` before use.
            labels: optional masked-LM targets; entries ``> 0`` are the masked
                positions and 0 is ignored by the loss.

        Returns:
            ``BertForPreTrainingOutput`` with ``loss`` (``None`` when ``labels``
            is ``None``), ``prediction_logits`` and ``seq_relationship_logits``.
            When ``dense_seq_output`` is enabled and labels are given,
            ``prediction_logits`` only covers the masked positions.
        """
        has_labels = labels is not None
        gather_masked = self.dense_seq_output and has_labels

        masked_tokens_mask = labels > 0 if (self.last_layer_subset and has_labels) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output

        masked_token_idx = None
        if gather_masked:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            # NOTE(review): with last_layer_subset the encoder presumably already
            # returns only the masked rows, so no gather is done here - confirm
            # against BertModel.
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The two loss computations are mutually exclusive. Previously the dense
        # result was always overwritten by the full-sequence computation, whose
        # flattened labels do not line up with the already-gathered scores.
        masked_lm_loss = None
        if gather_masked:
            # prediction_scores are already flattened
            masked_lm_loss = self.mlm_loss(
                prediction_scores, labels.flatten()[masked_token_idx]
            ).float()
        elif has_labels:
            masked_lm_loss = self.mlm_loss(
                rearrange(prediction_scores, "... v -> (...) v"),
                rearrange(labels, "... -> (...)"),
            ).float()

        return BertForPreTrainingOutput(
            loss=masked_lm_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )