Shaltiel commited on
Commit
2618264
1 Parent(s): b84d9be

Upload 4 files

Browse files
BertForJointParsing.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import math
3
+ from operator import itemgetter
4
+ import torch
5
+ from torch import nn
6
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
7
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
+ from transformers.models.bert.modeling_bert import BertOnlyMLMHead
9
+ from transformers.utils import ModelOutput
10
+ try:
11
+ from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
12
+ from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
13
+ from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
14
+ except ImportError:
15
+ from BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
16
+ from BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
17
+ from BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
18
+
19
+ import warnings
20
+
21
+ @dataclass
22
+ class JointParsingOutput(ModelOutput):
23
+ loss: Optional[torch.FloatTensor] = None
24
+ # logits will contain the optional predictions for the given labels
25
+ logits: Optional[Union[SyntaxLogitsOutput, None]] = None
26
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
27
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
28
+ # if no labels are given, we will always include the syntax logits separately
29
+ syntax_logits: Optional[SyntaxLogitsOutput] = None
30
+ ner_logits: Optional[torch.FloatTensor] = None
31
+ prefix_logits: Optional[torch.FloatTensor] = None
32
+ lex_logits: Optional[torch.FloatTensor] = None
33
+ morph_logits: Optional[MorphLogitsOutput] = None
34
+
35
+ # wrapper class to wrap a torch.nn.Module so that you can store a module in multiple linked
36
+ # properties without registering the parameter multiple times
37
+ class ModuleRef:
38
+ def __init__(self, module: torch.nn.Module):
39
+ self.module = module
40
+
41
+ def forward(self, *args, **kwargs):
42
+ return self.module.forward(*args, **kwargs)
43
+
44
+ def __call__(self, *args, **kwargs):
45
+ return self.module(*args, **kwargs)
46
+
47
+ class BertForJointParsing(BertPreTrainedModel):
48
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
49
+
50
+ def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
51
+ super().__init__(config)
52
+
53
+ self.bert = BertModel(config, add_pooling_layer=False)
54
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
55
+ # create all the heads as None, and then populate them as defined
56
+ self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,)*5
57
+
58
+ if do_syntax is not None:
59
+ config.do_syntax = do_syntax
60
+ config.syntax_head_size = syntax_head_size
61
+ if do_ner is not None: config.do_ner = do_ner
62
+ if do_prefix is not None: config.do_prefix = do_prefix
63
+ if do_lex is not None: config.do_lex = do_lex
64
+ if do_morph is not None: config.do_morph = do_morph
65
+
66
+ # add all the individual heads
67
+ if config.do_syntax:
68
+ self.syntax = BertSyntaxParsingHead(config)
69
+ if config.do_ner:
70
+ self.num_labels = config.num_labels
71
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) # name it same as in BertForTokenClassification
72
+ self.ner = ModuleRef(self.classifier)
73
+ if config.do_prefix:
74
+ self.prefix = BertPrefixMarkingHead(config)
75
+ if config.do_lex:
76
+ self.cls = BertOnlyMLMHead(config) # name it the same as in BertForMaskedLM
77
+ self.lex = ModuleRef(self.cls)
78
+ if config.do_morph:
79
+ self.morph = BertMorphTaggingHead(config)
80
+
81
+ # Initialize weights and apply final processing
82
+ self.post_init()
83
+
84
+ def get_output_embeddings(self):
85
+ return self.cls.predictions.decoder if self.lex is not None else None
86
+
87
+ def set_output_embeddings(self, new_embeddings):
88
+ if self.lex is not None:
89
+ self.cls.predictions.decoder = new_embeddings
90
+
91
+ def forward(
92
+ self,
93
+ input_ids: Optional[torch.Tensor] = None,
94
+ attention_mask: Optional[torch.Tensor] = None,
95
+ token_type_ids: Optional[torch.Tensor] = None,
96
+ position_ids: Optional[torch.Tensor] = None,
97
+ prefix_class_id_options: Optional[torch.Tensor] = None,
98
+ labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
99
+ labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
100
+ head_mask: Optional[torch.Tensor] = None,
101
+ inputs_embeds: Optional[torch.Tensor] = None,
102
+ output_attentions: Optional[bool] = None,
103
+ output_hidden_states: Optional[bool] = None,
104
+ return_dict: Optional[bool] = None,
105
+ compute_syntax_mst: Optional[bool] = None
106
+ ):
107
+ if return_dict is False:
108
+ warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")
109
+
110
+ if labels is not None and labels_type is None:
111
+ raise ValueError("Cannot specify labels without labels_type")
112
+
113
+ if labels_type == 'seg' and prefix_class_id_options is None:
114
+ raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')
115
+
116
+ if compute_syntax_mst is not None and self.syntax is None:
117
+ raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")
118
+
119
+
120
+ bert_outputs = self.bert(
121
+ input_ids,
122
+ attention_mask=attention_mask,
123
+ token_type_ids=token_type_ids,
124
+ position_ids=position_ids,
125
+ head_mask=head_mask,
126
+ inputs_embeds=inputs_embeds,
127
+ output_attentions=output_attentions,
128
+ output_hidden_states=output_hidden_states,
129
+ return_dict=True,
130
+ )
131
+
132
+ # calculate the extended attention mask for any child that might need it
133
+ extended_attention_mask = None
134
+ if attention_mask is not None:
135
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
136
+
137
+ # extract the hidden states, and apply the dropout
138
+ hidden_states = self.dropout(bert_outputs[0])
139
+
140
+ logits = None
141
+ syntax_logits = None
142
+ ner_logits = None
143
+ prefix_logits = None
144
+ lex_logits = None
145
+ morph_logits = None
146
+
147
+ # Calculate the syntax
148
+ if self.syntax is not None and (labels is None or labels_type == 'syntax'):
149
+ # apply the syntax head
150
+ loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
151
+ logits = syntax_logits
152
+
153
+ # Calculate the NER
154
+ if self.ner is not None and (labels is None or labels_type == 'ner'):
155
+ ner_logits = self.ner(hidden_states)
156
+ logits = ner_logits
157
+ if labels is not None:
158
+ loss_fct = nn.CrossEntropyLoss()
159
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
160
+
161
+ # Calculate the segmentation
162
+ if self.prefix is not None and (labels is None or labels_type == 'prefix'):
163
+ loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
164
+ logits = prefix_logits
165
+
166
+ # Calculate the lexeme
167
+ if self.lex is not None and (labels is None or labels_type == 'lex'):
168
+ lex_logits = self.lex(hidden_states)
169
+ logits = lex_logits
170
+ if labels is not None:
171
+ loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
172
+ loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))
173
+
174
+ if self.morph is not None and (labels is None or labels_type == 'morph'):
175
+ loss, morph_logits = self.morph(hidden_states, labels)
176
+ logits = morph_logits
177
+
178
+ # no labels => logits = None
179
+ if labels is None: logits = None
180
+
181
+ return JointParsingOutput(
182
+ loss,
183
+ logits,
184
+ hidden_states=bert_outputs.hidden_states,
185
+ attentions=bert_outputs.attentions,
186
+ # all the predicted logits section
187
+ syntax_logits=syntax_logits,
188
+ ner_logits=ner_logits,
189
+ prefix_logits=prefix_logits,
190
+ lex_logits=lex_logits,
191
+ morph_logits=morph_logits
192
+ )
193
+
194
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False):
195
+ is_single_sentence = isinstance(sentences, str)
196
+ if is_single_sentence:
197
+ sentences = [sentences]
198
+
199
+ # predict the logits for the sentence
200
+ if self.prefix is not None:
201
+ inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
202
+ else:
203
+ inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
204
+
205
+ # Copy the tensors to the right device, and parse!
206
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
207
+ output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
208
+
209
+ final_output = [dict(text=sentence, tokens=[dict(token=t) for t in combine_token_wordpieces(ids, tokenizer)]) for sentence, ids in zip(sentences, inputs['input_ids'])]
210
+ # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
211
+ if output.syntax_logits is not None:
212
+ for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
213
+ merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
214
+ final_output[sent_idx]['root_idx'] = parsed['root_idx']
215
+
216
+ # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
217
+ if output.prefix_logits is not None:
218
+ for sent_idx,parsed in enumerate(prefix_parse_logits(inputs, sentences, tokenizer, output.prefix_logits)):
219
+ merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
220
+
221
+ # Lex logits each sentence gets a list(tuple(word, lexeme))
222
+ if output.lex_logits is not None:
223
+ for sent_idx, parsed in enumerate(lex_parse_logits(inputs, sentences, tokenizer, output.lex_logits)):
224
+ merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')
225
+
226
+ # morph logits each sentences get a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
227
+ if output.morph_logits is not None:
228
+ for sent_idx,parsed in enumerate(morph_parse_logits(inputs, sentences, tokenizer, output.morph_logits)):
229
+ merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')
230
+
231
+ # NER logits each sentence gets a list(tuple(word, ner))
232
+ if output.ner_logits is not None:
233
+ for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
234
+ if per_token_ner:
235
+ merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
236
+ final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
237
+
238
+ if is_single_sentence:
239
+ final_output = final_output[0]
240
+ return final_output
241
+
242
+ def aggregate_ner_tokens(predictions):
243
+ entities = []
244
+ prev = None
245
+ for word,pred in predictions:
246
+ # O does nothing
247
+ if pred == 'O': prev = None
248
+ # B- || I-entity != prev (different entity or none)
249
+ elif pred.startswith('B-') or pred[2:] != prev:
250
+ prev = pred[2:]
251
+ entities.append(([word], prev))
252
+ else: entities[-1][0].append(word)
253
+
254
+ return [dict(phrase=' '.join(words), label=label) for words,label in entities]
255
+
256
+
257
+ def merge_token_list(src, update, key):
258
+ for token_src, token_update in zip(src, update):
259
+ token_src[key] = token_update
260
+
261
+ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
262
+ ret = []
263
+ for token in tokenizer.convert_ids_to_tokens(input_ids):
264
+ if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
265
+ if token.startswith('##'):
266
+ ret[-1] += token[2:]
267
+ else: ret.append(token)
268
+ return ret
269
+
270
+ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
271
+ input_ids = inputs['input_ids']
272
+
273
+ predictions = torch.argmax(logits, dim=-1)
274
+ batch_ret = []
275
+ for batch_idx in range(len(sentences)):
276
+ ret = []
277
+ batch_ret.append(ret)
278
+ for tok_idx in range(input_ids.shape[1]):
279
+ token_id = input_ids[batch_idx, tok_idx]
280
+ # ignore cls, sep, pad
281
+ if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
282
+
283
+ token = tokenizer._convert_id_to_token(token_id)
284
+ # wordpieces should just be appended to the previous word
285
+ if token.startswith('##'):
286
+ ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
287
+ continue
288
+ ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
289
+ return batch_ret
290
+
291
+ def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
292
+ input_ids = inputs['input_ids']
293
+
294
+ predictions = torch.argmax(logits, dim=-1)
295
+ batch_ret = []
296
+ for batch_idx in range(len(sentences)):
297
+ ret = []
298
+ batch_ret.append(ret)
299
+ for tok_idx in range(input_ids.shape[1]):
300
+ token_id = input_ids[batch_idx, tok_idx]
301
+ # ignore cls, sep, pad
302
+ if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
303
+
304
+ token = tokenizer._convert_id_to_token(token_id)
305
+ # wordpieces should just be appended to the previous word
306
+ if token.startswith('##'):
307
+ ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
308
+ continue
309
+ ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
310
+ return batch_ret
BertForMorphTagging.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from operator import itemgetter
3
+ from transformers.utils import ModelOutput
4
+ import torch
5
+ from torch import nn
6
+ from typing import Dict, List, Tuple, Optional
7
+ from dataclasses import dataclass
8
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
9
+
10
+ ALL_POS = ['DET', 'NOUN', 'VERB', 'CCONJ', 'ADP', 'PRON', 'PUNCT', 'ADJ', 'ADV', 'SCONJ', 'NUM', 'PROPN', 'AUX', 'X', 'INTJ', 'SYM']
11
+ ALL_PREFIX_POS = ['SCONJ', 'DET', 'ADV', 'CCONJ', 'ADP', 'NUM']
12
+ ALL_SUFFIX_POS = ['none', 'ADP_PRON', 'PRON']
13
+ ALL_FEATURES = [
14
+ ('Gender', ['none', 'Masc', 'Fem', 'Fem,Masc']),
15
+ ('Number', ['none', 'Sing', 'Plur', 'Plur,Sing', 'Dual', 'Dual,Plur']),
16
+ ('Person', ['none', '1', '2', '3', '1,2,3']),
17
+ ('Tense', ['none', 'Past', 'Fut', 'Pres', 'Imp'])
18
+ ]
19
+
20
+ @dataclass
21
+ class MorphLogitsOutput(ModelOutput):
22
+ prefix_logits: torch.FloatTensor = None
23
+ pos_logits: torch.FloatTensor = None
24
+ features_logits: List[torch.FloatTensor] = None
25
+ suffix_logits: torch.FloatTensor = None
26
+ suffix_features_logits: List[torch.FloatTensor] = None
27
+
28
+ def detach(self):
29
+ return MorphLogitsOutput(self.prefix_logits.detach(), self.pos_logits.detach(), [logits.deatch() for logits in self.features_logits], self.suffix_logits.detach(), [logits.deatch() for logits in self.suffix_features_logits])
30
+
31
+
32
+ @dataclass
33
+ class MorphTaggingOutput(ModelOutput):
34
+ loss: Optional[torch.FloatTensor] = None
35
+ logits: Optional[MorphLogitsOutput] = None
36
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
37
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
38
+
39
+ @dataclass
40
+ class MorphLabels(ModelOutput):
41
+ prefix_labels: Optional[torch.FloatTensor] = None
42
+ pos_labels: Optional[torch.FloatTensor] = None
43
+ features_labels: Optional[List[torch.FloatTensor]] = None
44
+ suffix_labels: Optional[torch.FloatTensor] = None
45
+ suffix_features_labels: Optional[List[torch.FloatTensor]] = None
46
+
47
+ def detach(self):
48
+ return MorphLabels(self.prefix_labels.detach(), self.pos_labels.detach(), [labels.detach() for labels in self.features_labels], self.suffix_labels.detach(), [labels.detach() for labels in self.suffix_features_labels])
49
+
50
+ def to(self, device):
51
+ return MorphLabels(self.prefix_labels.to(device), self.pos_labels.to(device), [feat.to(device) for feat in self.features_labels], self.suffix_labels.to(device), [feat.to(device) for feat in self.suffix_features_labels])
52
+
53
+ class BertMorphTaggingHead(nn.Module):
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.config = config
57
+
58
+ self.num_prefix_classes = len(ALL_PREFIX_POS)
59
+ self.num_pos_classes = len(ALL_POS)
60
+ self.num_suffix_classes = len(ALL_SUFFIX_POS)
61
+ self.num_features_classes = list(map(len, map(itemgetter(1), ALL_FEATURES)))
62
+ # we need a classifier for prefix cls and POS cls
63
+ # the prefix will use BCEWithLogits for multiple labels cls
64
+ self.prefix_cls = nn.Linear(config.hidden_size, self.num_prefix_classes)
65
+ # and pos + feats will use good old cross entropy for single label
66
+ self.pos_cls = nn.Linear(config.hidden_size, self.num_pos_classes)
67
+ self.features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
68
+ # and suffix + feats will also be cross entropy
69
+ self.suffix_cls = nn.Linear(config.hidden_size, self.num_suffix_classes)
70
+ self.suffix_features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ labels: Optional[MorphLabels] = None):
76
+ # run each of the classifiers on the transformed output
77
+ prefix_logits = self.prefix_cls(hidden_states)
78
+ pos_logits = self.pos_cls(hidden_states)
79
+ suffix_logits = self.suffix_cls(hidden_states)
80
+ features_logits = [cls(hidden_states) for cls in self.features_cls]
81
+ suffix_features_logits = [cls(hidden_states) for cls in self.suffix_features_cls]
82
+
83
+ loss = None
84
+ if labels is not None:
85
+ # step 1: prefix labels loss
86
+ loss_fct = nn.BCEWithLogitsLoss(weight=(labels.prefix_labels != -100).float())
87
+ loss = loss_fct(prefix_logits, labels.prefix_labels)
88
+ # step 2: pos labels loss
89
+ loss_fct = nn.CrossEntropyLoss()
90
+ loss += loss_fct(pos_logits.view(-1, self.num_pos_classes), labels.pos_labels.view(-1))
91
+ # step 2b: features
92
+ for feat_logits,feat_labels,num_features in zip(features_logits, labels.features_labels, self.num_features_classes):
93
+ loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
94
+ # step 3: suffix logits loss
95
+ loss += loss_fct(suffix_logits.view(-1, self.num_suffix_classes), labels.suffix_labels.view(-1))
96
+ # step 3b: suffix features
97
+ for feat_logits,feat_labels,num_features in zip(suffix_features_logits, labels.suffix_features_labels, self.num_features_classes):
98
+ loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
99
+
100
+ return loss, MorphLogitsOutput(prefix_logits, pos_logits, features_logits, suffix_logits, suffix_features_logits)
101
+
102
+ class BertForMorphTagging(BertPreTrainedModel):
103
+
104
+ def __init__(self, config):
105
+ super().__init__(config)
106
+
107
+ self.bert = BertModel(config, add_pooling_layer=False)
108
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
109
+ self.morph = BertMorphTaggingHead(config)
110
+
111
+ # Initialize weights and apply final processing
112
+ self.post_init()
113
+
114
+ def forward(
115
+ self,
116
+ input_ids: Optional[torch.Tensor] = None,
117
+ attention_mask: Optional[torch.Tensor] = None,
118
+ token_type_ids: Optional[torch.Tensor] = None,
119
+ position_ids: Optional[torch.Tensor] = None,
120
+ labels: Optional[MorphLabels] = None,
121
+ head_mask: Optional[torch.Tensor] = None,
122
+ inputs_embeds: Optional[torch.Tensor] = None,
123
+ output_attentions: Optional[bool] = None,
124
+ output_hidden_states: Optional[bool] = None,
125
+ return_dict: Optional[bool] = None,
126
+ ):
127
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
128
+
129
+ bert_outputs = self.bert(
130
+ input_ids,
131
+ attention_mask=attention_mask,
132
+ token_type_ids=token_type_ids,
133
+ position_ids=position_ids,
134
+ head_mask=head_mask,
135
+ inputs_embeds=inputs_embeds,
136
+ output_attentions=output_attentions,
137
+ output_hidden_states=output_hidden_states,
138
+ return_dict=return_dict,
139
+ )
140
+
141
+ hidden_states = bert_outputs[0]
142
+ hidden_states = self.dropout(hidden_states)
143
+
144
+ loss, logits = self.morph(hidden_states, labels)
145
+
146
+ if not return_dict:
147
+ return (loss,logits) + bert_outputs[2:]
148
+
149
+ return MorphTaggingOutput(
150
+ loss=loss,
151
+ logits=logits,
152
+ hidden_states=bert_outputs.hidden_states,
153
+ attentions=bert_outputs.attentions,
154
+ )
155
+
156
+ def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
157
+ # tokenize the inputs and convert them to relevant device
158
+ inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt')
159
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
160
+ # calculate the logits
161
+ logits = self.forward(**inputs, return_dict=True).logits
162
+ return parse_logits(inputs, sentences, tokenizer, logits)
163
+
164
+ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: MorphLogitsOutput):
165
+ prefix_logits, pos_logits, feats_logits, suffix_logits, suffix_feats_logits = \
166
+ logits.prefix_logits, logits.pos_logits, logits.features_logits, logits.suffix_logits, logits.suffix_features_logits
167
+
168
+ prefix_predictions = (prefix_logits > 0.5).int() # Threshold at 0.5 for multi-label classification
169
+ pos_predictions = pos_logits.argmax(axis=-1)
170
+ suffix_predictions = suffix_logits.argmax(axis=-1)
171
+ feats_predictions = [logits.argmax(axis=-1) for logits in feats_logits]
172
+ suffix_feats_predictions = [logits.argmax(axis=-1) for logits in suffix_feats_logits]
173
+
174
+ # create the return dictionary
175
+ # for each sentence, return a dict object with the following files { text, tokens }
176
+ # Where tokens is a list of dicts, where each dict is:
177
+ # { pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None}
178
+ special_tokens = set([tokenizer.pad_token, tokenizer.cls_token, tokenizer.sep_token])
179
+ ret = []
180
+ for sent_idx,sentence in enumerate(sentences):
181
+ input_id_strs = tokenizer.convert_ids_to_tokens(inputs['input_ids'][sent_idx])
182
+ # iterate through each token in the sentence, ignoring special tokens
183
+ tokens = []
184
+ for token_idx,token_str in enumerate(input_id_strs):
185
+ if not token_str in special_tokens:
186
+ if token_str.startswith('##'):
187
+ tokens[-1]['token'] += token_str[2:]
188
+ continue
189
+ tokens.append(dict(
190
+ token=token_str,
191
+ pos=ALL_POS[pos_predictions[sent_idx, token_idx]],
192
+ feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
193
+ prefixes=[ALL_PREFIX_POS[idx] for idx,i in enumerate(prefix_predictions[sent_idx, token_idx]) if i > 0],
194
+ suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx, token_idx]]),
195
+ ))
196
+ if tokens[-1]['suffix']:
197
+ tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
198
+ ret.append(dict(text=sentence, tokens=tokens))
199
+ return ret
200
+
201
+ def get_suffix_or_false(suffix):
202
+ return False if suffix == 'none' else suffix
203
+
204
+ def get_features_dict_from_predictions(predictions, idx):
205
+ ret = {}
206
+ for (feat_idx, (feat_name, feat_values)) in enumerate(ALL_FEATURES):
207
+ val = feat_values[predictions[feat_idx][idx]]
208
+ if val != 'none':
209
+ ret[feat_name] = val
210
+ return ret
211
+
212
+
BertForPrefixMarking.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.utils import ModelOutput
2
+ import torch
3
+ from torch import nn
4
+ from typing import Dict, List, Tuple, Optional
5
+ from dataclasses import dataclass
6
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
7
+
8
+ # define the classes, and the possible prefixes for each class
9
+ POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
10
+ # map each individual prefix to it's class number
11
+ PREFIXES_TO_CLASS = {w:i for i,l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
12
+ # keep a list of all the prefixes, sorted by length, so that we can decompose
13
+ # a given prefixes and figure out the classes
14
+ ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
15
+ TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
16
+
17
+ def get_prefixes_from_str(s, greedy=False):
18
+ # keep trimming prefixes from the string
19
+ while len(s) > 0 and s[0] in PREFIXES_TO_CLASS:
20
+ # find the longest string to trim
21
+ next_pre = next((pre for pre in ALL_PREFIX_ITEMS if s.startswith(pre)), None)
22
+ if next_pre is None:
23
+ return
24
+ yield next_pre
25
+ # if the chosen prefix is more than one letter, there is always an option that the
26
+ # prefix is actually just the first letter of the prefix - so offer that up as a valid prefix
27
+ # as well. We will still jump to the length of the longer one, since if the next two/three
28
+ # letters are a prefix, they have to be the longest one
29
+ if not greedy and len(next_pre) > 1:
30
+ yield next_pre[0]
31
+ s = s[len(next_pre):]
32
+
33
+ def get_prefix_classes_from_str(s, greedy=False):
34
+ for pre in get_prefixes_from_str(s, greedy):
35
+ yield PREFIXES_TO_CLASS[pre]
36
+
37
+ @dataclass
38
+ class PrefixesClassifiersOutput(ModelOutput):
39
+ loss: Optional[torch.FloatTensor] = None
40
+ logits: Optional[torch.FloatTensor] = None
41
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
42
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
43
+
44
+ class BertPrefixMarkingHead(nn.Module):
45
+ def __init__(self, config) -> None:
46
+ super().__init__()
47
+ self.config = config
48
+
49
+ # an embedding table containing an embedding for each prefix class + 1 for NONE
50
+ # we will concatenate either the embedding/NONE for each class - and we want the concatenate
51
+ # size to be the hidden_size
52
+ prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
53
+ self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)
54
+
55
+ # one layer for transformation, apply an activation, then another N classifiers for each prefix class
56
+ self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
57
+ self.activation = nn.Tanh()
58
+ self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ prefix_class_id_options: torch.Tensor,
64
+ labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
65
+
66
+ # encode the prefix_class_id_options
67
+ # If input_ids is batch x seq_len
68
+ # Then sequence_output is batch x seq_len x hidden_dim
69
+ # So prefix_class_id_options is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
70
+ # Looking up the embeddings should give us batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x hidden_dim / N
71
+ possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
72
+ # then flatten the final dimension - now we have batch x seq_len x hidden_dim_2
73
+ possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))
74
+
75
+ # concatenate the new class embed into the sequence output before the transform
76
+ pre_transform_output = torch.cat((hidden_states, possible_class_embed), dim=-1) # batch x seq_len x (hidden_dim + hidden_dim_2)
77
+ pre_logits_output = self.activation(self.transform(pre_transform_output))# batch x seq_len x hidden_dim
78
+
79
+ # run each of the classifiers on the transformed output
80
+ logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)
81
+
82
+ loss = None
83
+ if labels is not None:
84
+ loss_fct = nn.CrossEntropyLoss()
85
+ loss = loss_fct(logits.view(-1, 2), labels.view(-1))
86
+
87
+ return (loss, logits)
88
+
89
+
90
+
91
+ class BertForPrefixMarking(BertPreTrainedModel):
92
+
93
+ def __init__(self, config):
94
+ super().__init__(config)
95
+
96
+ self.bert = BertModel(config, add_pooling_layer=False)
97
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
98
+ self.prefix = BertPrefixMarkingHead(config)
99
+
100
+ # Initialize weights and apply final processing
101
+ self.post_init()
102
+
103
+ def forward(
104
+ self,
105
+ input_ids: Optional[torch.Tensor] = None,
106
+ attention_mask: Optional[torch.Tensor] = None,
107
+ token_type_ids: Optional[torch.Tensor] = None,
108
+ prefix_class_id_options: Optional[torch.Tensor] = None,
109
+ position_ids: Optional[torch.Tensor] = None,
110
+ labels: Optional[torch.Tensor] = None,
111
+ head_mask: Optional[torch.Tensor] = None,
112
+ inputs_embeds: Optional[torch.Tensor] = None,
113
+ output_attentions: Optional[bool] = None,
114
+ output_hidden_states: Optional[bool] = None,
115
+ return_dict: Optional[bool] = None,
116
+ ):
117
+ r"""
118
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
119
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
120
+ """
121
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
122
+
123
+ bert_outputs = self.bert(
124
+ input_ids,
125
+ attention_mask=attention_mask,
126
+ token_type_ids=token_type_ids,
127
+ position_ids=position_ids,
128
+ head_mask=head_mask,
129
+ inputs_embeds=inputs_embeds,
130
+ output_attentions=output_attentions,
131
+ output_hidden_states=output_hidden_states,
132
+ return_dict=return_dict,
133
+ )
134
+
135
+ hidden_states = bert_outputs[0]
136
+ hidden_states = self.dropout(hidden_states)
137
+
138
+ loss, logits = self.prefix.forward(hidden_states, prefix_class_id_options, labels)
139
+ if not return_dict:
140
+ return (loss,logits,) + bert_outputs[2:]
141
+
142
+ return PrefixesClassifiersOutput(
143
+ loss=loss,
144
+ logits=logits,
145
+ hidden_states=bert_outputs.hidden_states,
146
+ attentions=bert_outputs.attentions,
147
+ )
148
+
149
+ def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
150
+ # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
151
+ inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
152
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
153
+
154
+ # run through bert
155
+ logits = self.forward(**inputs, return_dict=True).logits
156
+ return parse_logits(inputs, sentences, tokenizer, logits)
157
+
158
+ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
159
+ # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
160
+ logit_preds = torch.argmax(logits, axis=3)
161
+
162
+ ret = []
163
+
164
+ for sent_idx,sent_ids in enumerate(inputs['input_ids']):
165
+ tokens = tokenizer.convert_ids_to_tokens(sent_ids)
166
+ ret.append([])
167
+ for tok_idx,token in enumerate(tokens):
168
+ # If we've reached the pad token, then we are at the end
169
+ if token == tokenizer.pad_token: continue
170
+ if token.startswith('##'): continue
171
+
172
+ # combine the next tokens in? only if it's a breakup
173
+ next_tok_idx = tok_idx + 1
174
+ while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
175
+ token += tokens[next_tok_idx][2:]
176
+ next_tok_idx += 1
177
+
178
+ prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx, tok_idx])
179
+
180
+ if not prefix_len:
181
+ ret[-1].append([token])
182
+ else:
183
+ ret[-1].append([token[:prefix_len], token[prefix_len:]])
184
+ return ret
185
+
186
+ def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
187
+ inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')
188
+
189
+ # create our prefix_id_options array which will be like the input ids shape but with an addtional
190
+ # dimension containing for each prefix whether it can be for that word
191
+ prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
192
+
193
+ # go through each token, and fill in the vector accordingly
194
+ for sent_idx, sent_ids in enumerate(inputs['input_ids']):
195
+ tokens = tokenizer.convert_ids_to_tokens(sent_ids)
196
+ for tok_idx, token in enumerate(tokens):
197
+ # if the first letter isn't a valid prefix letter, nothing to talk about
198
+ if len(token) < 2 or not token[0] in PREFIXES_TO_CLASS: continue
199
+
200
+ # combine the next tokens in? only if it's a breakup
201
+ next_tok_idx = tok_idx + 1
202
+ while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
203
+ token += tokens[next_tok_idx][2:]
204
+ next_tok_idx += 1
205
+
206
+ # find all the possible prefixes - and mark them as 0 (and in the possible mark it as it's value for embed lookup)
207
+ for pre_class in get_prefix_classes_from_str(token):
208
+ prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class
209
+
210
+ inputs['prefix_class_id_options'] = prefix_id_options
211
+ return inputs
212
+
213
+ def get_predicted_prefix_len_from_logits(token, token_logits):
214
+ # Go through each possible prefix, and check if the prefix is yes - and if
215
+ # so increase the counter of the matched length, otherwise break out. That will solve cases
216
+ # of predicting prefix combinations that don't exist on the word.
217
+ # For example, if we have the word ושכשהלכתי and the model predict ו & כש, then we will only
218
+ # take the vuv because in order to get the כש we need the ש as well.
219
+ # Two extra items:
220
+ # 1] Don't allow the same prefix multiple times
221
+ # 2] Always check that the word starts with that prefix - otherwise it's bad
222
+ # (except for the case of multi-letter prefix, where we force the next to be last)
223
+ cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
224
+ for prefix in get_prefixes_from_str(token):
225
+ # Are we skipping this prefix? This will be the case where we matched כש, don't allow ש
226
+ if skip_next:
227
+ skip_next = False
228
+ continue
229
+ # check for duplicate prefixes, we don't allow two of the same prefix
230
+ # if it predicted two of the same, then we will break out
231
+ if prefix in seen_prefixes: break
232
+ seen_prefixes.add(prefix)
233
+
234
+ # check if we predicted this prefix
235
+ if token_logits[PREFIXES_TO_CLASS[prefix]].item():
236
+ cur_len += len(prefix)
237
+ if last_check: break
238
+ skip_next = len(prefix) > 1
239
+ # Otherwise, we predicted no. If we didn't, then this is the end of the prefix
240
+ # and time to break out. *Except* if it's a multi letter prefix, then we allow
241
+ # just the next letter - e.g., if כש doesn't match, then we allow כ, but then we know
242
+ # the word continues with a ש, and if it's not כש, then it's not כ-ש- (invalid)
243
+ elif len(prefix) > 1:
244
+ last_check = True
245
+ else:
246
+ break
247
+
248
+ return cur_len
BertForSyntaxParsing.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from transformers.utils import ModelOutput
3
+ import torch
4
+ from torch import nn
5
+ from typing import Dict, List, Tuple, Optional, Union
6
+ from dataclasses import dataclass
7
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
+
9
+ ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
10
+
11
+ @dataclass
12
+ class SyntaxLogitsOutput(ModelOutput):
13
+ dependency_logits: torch.FloatTensor = None
14
+ function_logits: torch.FloatTensor = None
15
+ dependency_head_indices: torch.LongTensor = None
16
+
17
+ def detach(self):
18
+ return SyntaxTaggingOutput(self.dependency_logits.detach(), self.function_logits.detach(), self.dependency_head_indices.detach())
19
+
20
+ @dataclass
21
+ class SyntaxTaggingOutput(ModelOutput):
22
+ loss: Optional[torch.FloatTensor] = None
23
+ logits: Optional[SyntaxLogitsOutput] = None
24
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
25
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
26
+
27
+ @dataclass
28
+ class SyntaxLabels(ModelOutput):
29
+ dependency_labels: Optional[torch.LongTensor] = None
30
+ function_labels: Optional[torch.LongTensor] = None
31
+
32
+ def detach(self):
33
+ return SyntaxLabels(self.dependency_labels.detach(), self.function_labels.detach())
34
+
35
+ def to(self, device):
36
+ return SyntaxLabels(self.dependency_labels.to(device), self.function_labels.to(device))
37
+
38
+ class BertSyntaxParsingHead(nn.Module):
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.config = config
42
+
43
+ # the attention query & key values
44
+ self.head_size = config.syntax_head_size# int(config.hidden_size / config.num_attention_heads * 2)
45
+ self.query = nn.Linear(config.hidden_size, self.head_size)
46
+ self.key = nn.Linear(config.hidden_size, self.head_size)
47
+ # the function classifier gets two encoding values and predicts the labels
48
+ self.num_function_classes = len(ALL_FUNCTION_LABELS)
49
+ self.cls = nn.Linear(config.hidden_size * 2, self.num_function_classes)
50
+
51
+ def forward(
52
+ self,
53
+ hidden_states: torch.Tensor,
54
+ extended_attention_mask: Optional[torch.Tensor],
55
+ labels: Optional[SyntaxLabels] = None,
56
+ compute_mst: bool = False) -> Tuple[torch.Tensor, SyntaxLogitsOutput]:
57
+
58
+ # Take the dot product between "query" and "key" to get the raw attention scores.
59
+ query_layer = self.query(hidden_states)
60
+ key_layer = self.key(hidden_states)
61
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.head_size)
62
+
63
+ # add in the attention mask
64
+ if extended_attention_mask is not None:
65
+ if extended_attention_mask.ndim == 4:
66
+ extended_attention_mask = extended_attention_mask.squeeze(1)
67
+ attention_scores += extended_attention_mask# batch x seq x seq
68
+
69
+ # At this point take the hidden_state of the word and of the dependency word, and predict the function
70
+ # If labels are provided, use the labels.
71
+ if self.training and labels is not None:
72
+ # Note that the labels can have -100, so just set those to zero with a max
73
+ dep_indices = labels.dependency_labels.clamp_min(0)
74
+ # Otherwise - check if he wants the MST or just the argmax
75
+ elif compute_mst:
76
+ dep_indices = compute_mst_tree(attention_scores)
77
+ else:
78
+ dep_indices = torch.argmax(attention_scores, dim=-1)
79
+
80
+ # After we retrieved the dependency indicies, create a tensor of teh batch indices, and and retrieve the vectors of the heads to calculate the function
81
+ batch_indices = torch.arange(dep_indices.size(0)).view(-1, 1).expand(-1, dep_indices.size(1)).to(dep_indices.device)
82
+ dep_vectors = hidden_states[batch_indices, dep_indices, :] # batch x seq x dim
83
+
84
+ # concatenate that with the last hidden states, and send to the classifier output
85
+ cls_inputs = torch.cat((hidden_states, dep_vectors), dim=-1)
86
+ function_logits = self.cls(cls_inputs)
87
+
88
+ loss = None
89
+ if labels is not None:
90
+ loss_fct = nn.CrossEntropyLoss()
91
+ # step 1: dependency scores loss - this is applied to the attention scores
92
+ loss = loss_fct(attention_scores.view(-1, hidden_states.size(-2)), labels.dependency_labels.view(-1))
93
+ # step 2: function loss
94
+ loss += loss_fct(function_logits.view(-1, self.num_function_classes), labels.function_labels.view(-1))
95
+
96
+ return (loss, SyntaxLogitsOutput(attention_scores, function_logits, dep_indices))
97
+
98
+
99
+ class BertForSyntaxParsing(BertPreTrainedModel):
100
+
101
+ def __init__(self, config):
102
+ super().__init__(config)
103
+
104
+ self.bert = BertModel(config, add_pooling_layer=False)
105
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
106
+ self.syntax = BertSyntaxParsingHead(config)
107
+
108
+ # Initialize weights and apply final processing
109
+ self.post_init()
110
+
111
+ def forward(
112
+ self,
113
+ input_ids: Optional[torch.Tensor] = None,
114
+ attention_mask: Optional[torch.Tensor] = None,
115
+ token_type_ids: Optional[torch.Tensor] = None,
116
+ position_ids: Optional[torch.Tensor] = None,
117
+ labels: Optional[SyntaxLabels] = None,
118
+ head_mask: Optional[torch.Tensor] = None,
119
+ inputs_embeds: Optional[torch.Tensor] = None,
120
+ output_attentions: Optional[bool] = None,
121
+ output_hidden_states: Optional[bool] = None,
122
+ return_dict: Optional[bool] = None,
123
+ compute_syntax_mst: Optional[bool] = None,
124
+ ):
125
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
126
+
127
+ bert_outputs = self.bert(
128
+ input_ids,
129
+ attention_mask=attention_mask,
130
+ token_type_ids=token_type_ids,
131
+ position_ids=position_ids,
132
+ head_mask=head_mask,
133
+ inputs_embeds=inputs_embeds,
134
+ output_attentions=output_attentions,
135
+ output_hidden_states=output_hidden_states,
136
+ return_dict=return_dict,
137
+ )
138
+
139
+ extended_attention_mask = None
140
+ if attention_mask is not None:
141
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
142
+ # apply the syntax head
143
+ loss, logits = self.syntax(self.dropout(bert_outputs[0]), extended_attention_mask, labels, compute_syntax_mst)
144
+
145
+ if not return_dict:
146
+ return (loss,(logits.dependency_logits, logits.function_logits)) + bert_outputs[2:]
147
+
148
+ return SyntaxTaggingOutput(
149
+ loss=loss,
150
+ logits=logits,
151
+ hidden_states=bert_outputs.hidden_states,
152
+ attentions=bert_outputs.attentions,
153
+ )
154
+
155
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, compute_mst=True):
156
+ if isinstance(sentences, str):
157
+ sentences = [sentences]
158
+
159
+ # predict the logits for the sentence
160
+ inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
161
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
162
+ logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
163
+ return parse_logits(inputs, sentences, tokenizer, logits)
164
+
165
+ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
166
+ outputs = []
167
+ for i in range(len(sentences)):
168
+ deps = logits.dependency_head_indices[i].tolist()
169
+ funcs = logits.function_logits.argmax(-1)[i].tolist()
170
+ toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
171
+
172
+ # first, go through the tokens and create a mapping between each dependency index and the index without wordpieces
173
+ # wordpieces. At the same time, append the wordpieces in
174
+ idx_mapping = {-1:-1} # default root
175
+ real_idx = -1
176
+ for i in range(len(toks)):
177
+ if not toks[i].startswith('##'):
178
+ real_idx += 1
179
+ idx_mapping[i] = real_idx
180
+
181
+ # build our tree, keeping tracking of the root idx
182
+ tree = []
183
+ root_idx = 0
184
+ for i in range(len(toks)):
185
+ if toks[i].startswith('##'):
186
+ tree[-1]['word'] += toks[i][2:]
187
+ continue
188
+
189
+ dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
190
+ dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
191
+ dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
192
+
193
+ if dep_head == 'root': root_idx = len(tree)
194
+ tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
195
+ # append the head word
196
+ for d in tree:
197
+ d['dep_head'] = tree[d['dep_head_idx']]['word']
198
+
199
+ outputs.append(dict(tree=tree, root_idx=root_idx))
200
+ return outputs
201
+
202
+
203
+ def compute_mst_tree(attention_scores: torch.Tensor):
204
+ # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
205
+ if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
206
+ if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
207
+ raise ValueError(f'Expected attention scores to be of shape batch x seq x seq, instead got {attention_scores.shape}')
208
+
209
+ batch_size, seq_len, _ = attention_scores.shape
210
+ # start by softmaxing so the scores are comparable
211
+ attention_scores = attention_scores.softmax(dim=-1)
212
+
213
+ # set the values for the CLS and sep to all by very low, so they never get chosen as a replacement arc
214
+ attention_scores[:, 0, :] = -10000
215
+ attention_scores[:, -1, :] = -10000
216
+ attention_scores[:, :, -1] = -10000 # can never predict sep
217
+
218
+ # find the root, and make him super high so we never have a conflict
219
+ root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
220
+ batch_indices = torch.arange(batch_size, device=root_cands.device)
221
+ attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = -10000
222
+ attention_scores[batch_indices, root_cands[:, -1], 0] = 10000
223
+
224
+ # we start by getting the argmax for each score, and then computing the cycles and contracting them
225
+ sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
226
+ indices = sorted_indices[:, :, 0].clone() # take the argmax
227
+
228
+ # go through each batch item and make sure our tree works
229
+ for batch_idx in range(batch_size):
230
+ # We have one root - detect the cycles and contract them. A cycle can never contain the root so really
231
+ # for every cycle, we look at all the nodes, and find the highest arc out of the cycle for any values. Replace that and tada
232
+ has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
233
+ while has_cycle:
234
+ base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, attention_scores[batch_idx])
235
+ indices[batch_idx, base_idx] = head_idx
236
+ # find the next cycle
237
+ has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
238
+
239
+ return indices
240
+
241
+ def detect_cycle(indices: torch.LongTensor):
242
+ # Simple cycle detection algorithm
243
+ # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
244
+ visited = set()
245
+ for node in range(1, len(indices) - 1): # ignore the CLS/SEP tokens
246
+ if node in visited:
247
+ continue
248
+ current_path = set()
249
+ while node not in visited:
250
+ visited.add(node)
251
+ current_path.add(node)
252
+ node = indices[node].item()
253
+ if node == 0: break # roots never point to anything
254
+ if node in current_path:
255
+ return True, current_path # Cycle detected
256
+ return False, None
257
+
258
+ def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: torch.LongTensor, cycle_nodes: set, scores: torch.FloatTensor):
259
+ # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
260
+ # the best arc based on 'scores', avoiding cycles and zero node connections.
261
+ # For each node, we only look at the next highest scoring non-cycling arc
262
+ best_base_idx, best_head_idx = -1, -1
263
+ score = float('-inf')
264
+
265
+ # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
266
+ currents = indices.tolist()
267
+ for base_node in cycle_nodes:
268
+ # we don't want to take anything that has a higher score than the current value - we can end up in an endless loop
269
+ # Since the indices are sorted, as soon as we find our current item, we can move on to the next.
270
+ current = currents[base_node]
271
+ found_current = False
272
+
273
+ for head_node in sorted_indices[base_node].tolist():
274
+ if head_node == current:
275
+ found_current = True
276
+ continue
277
+ if not found_current or head_node in cycle_nodes or head_node == 0:
278
+ continue
279
+
280
+ current_score = scores[base_node, head_node].item()
281
+ if current_score > score:
282
+ best_base_idx, best_head_idx, score = base_node, head_node, current_score
283
+ break
284
+
285
+ return best_base_idx, best_head_idx