Commit: feat: added head_mask
Files changed: modeling_bert.py (+4, −0)

modeling_bert.py — CHANGED
@@ -379,12 +379,16 @@ class BertModel(BertPreTrainedModel):
Before (old lines 379–390):

    379             task_type_ids=None,
    380             attention_mask=None,
    381             masked_tokens_mask=None,
    382         ):
    383             """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
    384             we only want the output for the masked tokens. This means that we only compute the last
    385             layer output for these tokens.
    386             masked_tokens_mask: (batch, seqlen), dtype=torch.bool
    387             """
    388             hidden_states = self.embeddings(
    389                 input_ids, position_ids=position_ids, token_type_ids=token_type_ids
    390             )
After (new lines 379–394; added lines marked "+"):

    379             task_type_ids=None,
    380             attention_mask=None,
    381             masked_tokens_mask=None,
    382 +           head_mask=None,
    383         ):
    384             """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
    385             we only want the output for the masked tokens. This means that we only compute the last
    386             layer output for these tokens.
    387             masked_tokens_mask: (batch, seqlen), dtype=torch.bool
    388             """
    389 +           if head_mask is not None:
    390 +               raise NotImplementedError('Masking heads is not supported')
    391 +
    392             hidden_states = self.embeddings(
    393                 input_ids, position_ids=position_ids, token_type_ids=token_type_ids
    394             )