Markus28 committed on
Commit
599c64e
·
1 Parent(s): 767b681

feat: added head_mask

Browse files
Files changed (1) hide show
  1. modeling_bert.py +4 -0
modeling_bert.py CHANGED
@@ -379,12 +379,16 @@ class BertModel(BertPreTrainedModel):
379
  task_type_ids=None,
380
  attention_mask=None,
381
  masked_tokens_mask=None,
 
382
  ):
383
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
384
  we only want the output for the masked tokens. This means that we only compute the last
385
  layer output for these tokens.
386
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
387
  """
 
 
 
388
  hidden_states = self.embeddings(
389
  input_ids, position_ids=position_ids, token_type_ids=token_type_ids
390
  )
 
379
  task_type_ids=None,
380
  attention_mask=None,
381
  masked_tokens_mask=None,
382
+ head_mask=None,
383
  ):
384
  """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
385
  we only want the output for the masked tokens. This means that we only compute the last
386
  layer output for these tokens.
387
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
388
  """
389
+ if head_mask is not None:
390
+ raise NotImplementedError('Masking heads is not supported')
391
+
392
  hidden_states = self.embeddings(
393
  input_ids, position_ids=position_ids, token_type_ids=token_type_ids
394
  )