Upload BertForJointParsing.py

BertForJointParsing.py CHANGED (+20 -26)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-import
+import re
 from operator import itemgetter
 import torch
 from torch import nn
@@ -187,25 +187,6 @@ class BertForJointParsing(BertPreTrainedModel):
         )
 
     def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
-        """
-        Predicts various linguistic features using the DictaBERT model.
-
-        This function takes a sentence or a list of sentences in Hebrew and applies the BERT model to predict multiple linguistic attributes simultaneously. These include syntax, named entity recognition (NER), morphological analysis, lexical information, and text segmentation.
-
-        Parameters:
-        sentences (Union[str, List[str]]): A single sentence or a list of sentences in Hebrew.
-        tokenizer (BertTokenizerFast): The tokenizer used for preprocessing the input sentences.
-        padding (str, optional): The strategy for padding sentences. Defaults to 'longest'.
-        truncation (bool, optional): Flag to enable or disable truncation. Defaults to True.
-        compute_syntax_mst (bool, optional): If True, computes the maximum spanning tree for syntax prediction. Defaults to True.
-        per_token_ner (bool, optional): If True, performs NER for each token. Defaults to False.
-        output_style (Literal['json', 'ud', 'iahlt_ud'], optional): The format of the output. Choices are 'json', 'ud' (Universal Dependencies), or 'iahlt_ud' (UD in the style of IAHLT). Defaults to 'json'.
-
-        Returns:
-        Depending on the output_style chosen, returns the linguistic analysis in the specified format.
-
-        The function is integral for comprehensive linguistic analysis in applications involving Hebrew text, catering to a variety of NLP tasks.
-        """
         is_single_sentence = isinstance(sentences, str)
         if is_single_sentence:
             sentences = [sentences]
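For reference, a minimal usage sketch of the predict entry point; it assumes the file ships as a trust_remote_code model on the Hugging Face Hub, and the dicta-il/dictabert-joint checkpoint name is an assumption, not part of this commit:

# Sketch only: the checkpoint name and trust_remote_code loading are assumptions.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint')
model = AutoModel.from_pretrained('dicta-il/dictabert-joint', trust_remote_code=True)
model.eval()

# predict() accepts a single sentence or a list of sentences (Hebrew text)
predictions = model.predict(['ืฉืื ืขืืื'], tokenizer, output_style='json')
print(predictions)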
@@ -315,11 +296,10 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
 def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
     input_ids = inputs['input_ids']
 
-    predictions = torch.argmax(logits, dim=-1)
+    predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3]
     batch_ret = []
     for batch_idx in range(len(sentences)):
-        ret = []
-        batch_ret.append(ret)
+        intermediate_ret = []
         for tok_idx in range(input_ids.shape[1]):
             token_id = input_ids[batch_idx, tok_idx]
             # ignore cls, sep, pad
@@ -328,9 +308,23 @@ def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
             token = tokenizer._convert_id_to_token(token_id)
             # wordpieces should just be appended to the previous word
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
+                intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
                 continue
-            ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
+            intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx, tok_idx])))
+
+        # build the final output taking into account valid letters
+        ret = []
+        batch_ret.append(ret)
+        for (token, lexemes) in intermediate_ret:
+            # must overlap on at least 2 non ืืืื letters
+            possible_lets = set(c for c in token if c not in 'ืืืื')
+            final_lex = '[BLANK]'
+            for lex in lexemes:
+                if sum(c in possible_lets for c in lex) >= min([2, len(possible_lets), len([c for c in lex if c not in 'ืืืื'])]):
+                    final_lex = lex
+                    break
+            ret.append((token, final_lex))
+
     return batch_ret
 
 ud_prefixes_to_pos = {
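The rewritten lex_parse_logits keeps the three highest-scoring lexeme candidates per token (torch.argsort(logits, dim=-1, descending=True)[..., :3] matches torch.topk(logits, 3).indices up to tie-breaking) and accepts the first candidate that shares enough informative letters with the surface token, discounting the matres lectionis ืืืื. A standalone sketch of that acceptance test; pick_lexeme is an illustrative name, not part of the file:

# Illustrative sketch of the acceptance heuristic above.
MATRES = 'ืืืื'  # letters treated as uninformative for matching

def pick_lexeme(token: str, candidates: list) -> str:
    # informative letters of the surface token
    possible_lets = set(c for c in token if c not in MATRES)
    for lex in candidates:  # candidates are ordered best-first
        overlap = sum(c in possible_lets for c in lex)
        # require ~2 shared letters, relaxed when either side has fewer to offer
        if overlap >= min(2, len(possible_lets), len([c for c in lex if c not in MATRES])):
            return lex
    return '[BLANK]'  # no candidate passed the check

# token 'ืฉืืจืื' and candidate 'ืฉืืจ' share ืฉ and ืจ, so 'ืฉืืจ' is accepted
print(pick_lexeme('ืฉืืจืื', ['ืฉืืจ']))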
@@ -437,7 +431,7 @@ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
             suf_feats = word['morph']['suffix_feats']
             suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_ืืื")
             # for HTB, if the function is poss, then add a shel pointing to the next word
-            if func == 'nmod:poss':
+            if func == 'nmod:poss' and s_lex != 'ืฉื':
                 intermediate_output.append(dict(word='_ืฉื_', lex='ืฉื', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
             # add the main suffix in
             intermediate_output.append(dict(word=suf, lex='ืืื', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k,v in word['morph']['suffix_feats'].items())))
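The added s_lex != 'ืฉื' guard stops the converter from inserting an extra '_ืฉื_' case marker when the suffixed word is itself the preposition ืฉื (e.g. ืฉืื), which would otherwise duplicate the marker. A small illustration of the condition; needs_shel_insertion is a hypothetical name:

# Hypothetical helper; only the condition itself comes from the file.
def needs_shel_insertion(func: str, s_lex: str) -> bool:
    # insert an explicit '_ืฉื_' ADP only for possessives whose host word
    # is not already the preposition ืฉื
    return func == 'nmod:poss' and s_lex != 'ืฉื'

assert needs_shel_insertion('nmod:poss', 'ืืืช')      # ืืืชื -> ืืืช + _ืฉื_ + ืืื
assert not needs_shel_insertion('nmod:poss', 'ืฉื')   # ืฉืื already carries ืฉื
assert not needs_shel_insertion('obj', 'ืืืช')        # non-possessive function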