imvladikon committed
Commit 2bbd253 · 1 Parent(s): 61e4faf

Create modeling_enc_t5.py

Files changed (1)
  1. modeling_enc_t5.py +228 -0
modeling_enc_t5.py ADDED
import copy
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import T5TokenizerFast
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map


class EncT5Tokenizer(T5TokenizerFast):
    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        super().__init__(
            vocab_file=vocab_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=sp_model_kwargs,
            **kwargs,
        )

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # normal case: some special tokens
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of zeros.
        """
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(bos + token_ids_0 + eos) * [0]
        return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format:
        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        else:
            return (
                [self.bos_token_id]
                + token_ids_0
                + [self.eos_token_id]
                + token_ids_1
                + [self.eos_token_id]
            )


class EncT5ForTokenClassification(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config: T5Config, dropout=0.1):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


EncT5Tokenizer.register_for_auto_class("AutoTokenizer")
EncT5ForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
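
Since the file registers both classes for the auto API, a checkpoint that ships this module can be loaded with the standard Auto classes. A minimal usage sketch (not part of the commit): it assumes the hosting repo's config.json and tokenizer_config.json contain auto_map entries pointing at modeling_enc_t5.py, and the repo id below is a placeholder.

from transformers import AutoModelForTokenClassification, AutoTokenizer

repo_id = "imvladikon/enc-t5-model"  # placeholder; substitute the actual checkpoint

# trust_remote_code=True lets transformers import EncT5Tokenizer /
# EncT5ForTokenClassification from this file instead of the built-in T5 classes.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("EncT5 keeps only the T5 encoder for tagging tasks.", return_tensors="pt")
logits = model(**inputs).logits        # shape: (batch, sequence_length, num_labels)
predictions = logits.argmax(dim=-1)    # per-token label ids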