# emanuelaboros's picture
# Create modelling_nar.py
# 5101b13 verified
# raw
# history blame
# 7.3 kB
import json
import logging
import os
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoConfig, AutoModel, BertConfig, PreTrainedModel
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_info(label_map):
    """Count the labels available for each task.

    Args:
        label_map: Mapping from a task name to the collection of labels
            defined for that task.

    Returns:
        A dict with the same keys as ``label_map`` whose values are the
        number of labels of the corresponding task.
    """
    label_counts = {}
    for task_name, task_labels in label_map.items():
        label_counts[task_name] = len(task_labels)
    return label_counts
class ModelForSequenceAndTokenClassification(PreTrainedModel):
    """Encoder with a token-classification head and an optional sequence head.

    The backbone is built with ``AutoModel.from_config(config)`` and stored as
    ``self.bert``. A linear layer maps every token representation to
    ``num_token_labels`` logits. When ``do_classif`` is true, a second linear
    layer maps the backbone's pooled output to ``num_sequence_labels`` logits.

    Args:
        config: Model configuration used to build the backbone. It must
            provide ``hidden_size`` and ``hidden_dropout_prob`` (the latter is
            only read when ``classifier_dropout`` is missing or ``None``).
        num_sequence_labels: Output size of the sequence-level head.
        num_token_labels: Output size of the token-level head.
        do_classif: Whether to create and use the sequence-level head.
    """

    config_class = AutoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, num_sequence_labels, num_token_labels, do_classif=False):
        super().__init__(config)
        self.num_token_labels = num_token_labels
        self.num_sequence_labels = num_sequence_labels
        self.config = config
        self.do_classif = do_classif

        self.bert = AutoModel.from_config(config)

        # Not every configuration class defines `classifier_dropout`; fall
        # back to the generic hidden dropout when it is absent or None.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        # Token-level head.
        self.token_classifier = nn.Linear(config.hidden_size, self.num_token_labels)

        if do_classif:
            # Sequence-level head.
            self.sequence_classifier = nn.Linear(
                config.hidden_size, self.num_sequence_labels
            )

        # Initialize weights and apply final processing.
        self.post_init()

    def do_classif(self):
        """Return the ``do_classif`` flag.

        NOTE: ``__init__`` stores a boolean under the same name on every
        instance, so ``model.do_classif`` evaluates to that boolean and this
        function is only reachable through the class object. It is kept for
        backward compatibility.
        """
        return self.do_classif

    def _init_weights(self, module):
        """Initialize the weights of ``module``.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; biases and the padding
        embedding row are zeroed; LayerNorm is reset to weight 1 and bias 0.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _sequence_loss(self, sequence_logits, sequence_labels):
        """Compute the sequence-level loss selected by ``config.problem_type``.

        When ``config.problem_type`` is unset, it is inferred locally (the
        configuration is not modified): one output label means regression,
        integer labels mean single-label classification, anything else means
        multi-label classification.

        Raises:
            ValueError: If ``config.problem_type`` holds an unsupported value.
        """
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if self.num_sequence_labels == 1:
                problem_type = "regression"
            elif sequence_labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_sequence_labels == 1:
                return loss_fct(sequence_logits.squeeze(), sequence_labels.squeeze())
            return loss_fct(sequence_logits, sequence_labels)
        if problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            return loss_fct(
                sequence_logits.view(-1, self.num_sequence_labels),
                sequence_labels.view(-1),
            )
        if problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            return loss_fct(sequence_logits, sequence_labels)
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[torch.Tensor] = None,
        sequence_labels: Optional[torch.Tensor] = None,
        offset_mapping: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[torch.Tensor, ...],
        TokenClassifierOutput,
        Tuple[SequenceClassifierOutput, TokenClassifierOutput],
    ]:
        r"""
        Run the backbone and the classification head(s).

        token_labels (`torch.Tensor`, *optional*):
            Targets for the token head. They are flattened and compared with the flattened token logits using
            cross-entropy. No loss at all is computed when this is `None`.
        sequence_labels (`torch.Tensor`, *optional*):
            Targets for the sequence head. Only used when the model was built with `do_classif=True` and
            `token_labels` is also given; the sequence loss is then added to the token loss. The loss function
            follows `config.problem_type` (`"regression"`, `"single_label_classification"` or
            `"multi_label_classification"`), which is inferred from the labels when unset.
        offset_mapping (`torch.Tensor`, *optional*):
            Accepted so tokenizer outputs can be passed through unchanged; not used by the model.

        Returns:
            With `return_dict=False`: a tuple `(loss?, sequence_logits?, token_logits, *extra backbone outputs)`,
            where `loss` is present only if it was computed and `sequence_logits` only if `do_classif` is set.
            Otherwise a `TokenClassifierOutput`, or the pair `(SequenceClassifierOutput, TokenClassifierOutput)`
            when `do_classif` is set; both outputs carry the same combined loss.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Token head: one logit vector per position of the first backbone output.
        token_logits = self.token_classifier(self.dropout(outputs[0]))

        sequence_logits = None
        if self.do_classif:
            # Sequence head on the second backbone output.
            # NOTE(review): assumes index 1 is a pooled representation, as for
            # BERT-style backbones with a pooling layer — confirm for others.
            sequence_logits = self.sequence_classifier(self.dropout(outputs[1]))

        # The total loss is the SUM of the token loss and, when available,
        # the sequence loss.
        loss = None
        if token_labels is not None:
            loss = CrossEntropyLoss()(
                token_logits.view(-1, self.num_token_labels), token_labels.view(-1)
            )
            if self.do_classif and sequence_labels is not None:
                loss = loss + self._sequence_loss(sequence_logits, sequence_labels)

        if not return_dict:
            if self.do_classif:
                output = (sequence_logits, token_logits) + outputs[2:]
            else:
                output = (token_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        token_result = TokenClassifierOutput(
            loss=loss,
            logits=token_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if not self.do_classif:
            return token_result

        sequence_result = SequenceClassifierOutput(
            loss=loss,
            logits=sequence_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return sequence_result, token_result