emanuelaboros committed on
Commit 5101b13 · verified · 1 Parent(s): 1c2f458

Create modelling_nar.py

Files changed (1)
  1. modelling_nar.py +187 -0
modelling_nar.py ADDED
@@ -0,0 +1,187 @@
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union
import logging


logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return the number of labels per task, e.g. {"ner": 9}."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ModelForSequenceAndTokenClassification(PreTrainedModel):
    """
    Encoder with a token-classification head and, optionally, a sequence-classification
    head. Inherits from `PreTrainedModel`, which handles weights initialization and
    provides a simple interface for downloading and loading pretrained models.
    """

    config_class = AutoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, num_sequence_labels, num_token_labels, do_classif=False):
        super().__init__(config)
        self.num_token_labels = num_token_labels
        self.num_sequence_labels = num_sequence_labels
        self.config = config
        self.do_classif = do_classif

        self.bert = AutoModel.from_config(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # For token classification
        self.token_classifier = nn.Linear(config.hidden_size, self.num_token_labels)

        if do_classif:
            # For the classification of the entire sequence
            self.sequence_classifier = nn.Linear(
                config.hidden_size, self.num_sequence_labels
            )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[torch.Tensor] = None,
        sequence_labels: Optional[torch.Tensor] = None,
        offset_mapping: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Union[Tuple[torch.Tensor], SequenceClassifierOutput],
        Union[Tuple[torch.Tensor], TokenClassifierOutput],
    ]:
        r"""
        token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., num_token_labels - 1]`.
        sequence_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss, only used when
            `do_classif=True`. Depending on `config.problem_type`, a regression loss (MSE),
            a single-label classification loss (Cross-Entropy) or a multi-label classification
            loss (BCE-with-logits) is computed.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # For token classification: per-token logits from the last hidden states
        token_output = outputs[0]
        token_output = self.dropout(token_output)
        token_logits = self.token_classifier(token_output)

        if self.do_classif:
            # For the classification of the entire sequence: logits from the pooled output
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            sequence_logits = self.sequence_classifier(pooled_output)

        # Compute the loss as the sum of the token loss and, optionally, the sequence loss
        loss = None
        if token_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss_tokens = loss_fct(
                token_logits.view(-1, self.num_token_labels), token_labels.view(-1)
            )

            if self.do_classif:
                if self.config.problem_type is None:
                    # Infer the problem type, following the usual transformers convention
                    if self.num_sequence_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_sequence_labels > 1 and sequence_labels.dtype in (
                        torch.long,
                        torch.int,
                    ):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_sequence_labels == 1:
                        loss_sequence = loss_fct(
                            sequence_logits.squeeze(), sequence_labels.squeeze()
                        )
                    else:
                        loss_sequence = loss_fct(sequence_logits, sequence_labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss_sequence = loss_fct(
                        sequence_logits.view(-1, self.num_sequence_labels),
                        sequence_labels.view(-1),
                    )
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss_sequence = loss_fct(sequence_logits, sequence_labels)

                loss = loss_tokens + loss_sequence
            else:
                loss = loss_tokens

        if not return_dict:
            if self.do_classif:
                output = (
                    sequence_logits,
                    token_logits,
                ) + outputs[2:]
                return ((loss,) + output) if loss is not None else output
            else:
                output = (token_logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

        if self.do_classif:
            return SequenceClassifierOutput(
                loss=loss,
                logits=sequence_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ), TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            return TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
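
For reference, below is a minimal usage sketch of the class this commit adds. The base checkpoint (`bert-base-cased`), the label map and the example sentence are illustrative assumptions rather than part of the commit; since the encoder is built with `AutoModel.from_config`, its weights are randomly initialized here unless a trained checkpoint is loaded separately.

# Minimal usage sketch; checkpoint name, label map and sentence are assumed examples.
import torch
from transformers import AutoConfig, AutoTokenizer

from modelling_nar import ModelForSequenceAndTokenClassification, get_info

# Assumed per-task label map; get_info() turns it into {task: number_of_labels}.
label_map = {"ner": ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]}
num_token_labels_dict = get_info(label_map)  # {"ner": 5}

config = AutoConfig.from_pretrained("bert-base-cased")  # assumed encoder config
model = ModelForSequenceAndTokenClassification(
    config,
    num_sequence_labels=2,  # only used when do_classif=True
    num_token_labels=num_token_labels_dict["ner"],
    do_classif=False,  # token classification only
)
model.eval()  # encoder weights are randomly initialized (built from the config only)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
inputs = tokenizer("A sample sentence.", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)  # a TokenClassifierOutput when do_classif=False

print(output.logits.shape)  # (batch_size, sequence_length, num_token_labels)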