|
from collections import defaultdict |
|
|
|
|
|
def read_conjunctions(cfg):
    """Read the conjunction-splitting file referenced by ``cfg.conjunctions_file``.

    The file is organized in blank-line-separated blocks: the first line of
    each block is an original sentence, and every following line is a
    conjunct sentence derived from it.

    Args:
        cfg: configuration object exposing a ``conjunctions_file`` path.

    Returns:
        tuple: ``(conj_sentences, conj2sent)`` where ``conj_sentences`` is the
        list of conjunct sentences and ``conj2sent`` maps each conjunct
        sentence back to its original sentence.
    """
    conj2sent = {}
    with open(cfg.conjunctions_file, 'r') as fin:
        expecting_original = True  # next non-blank line starts a new block
        original_sentence = ''
        for raw_line in fin:
            if raw_line == '\n':
                # Blank line closes the block; the next line is an original.
                expecting_original = True
            elif expecting_original:
                original_sentence = raw_line.replace('\n', '')
                expecting_original = False
            else:
                conj2sent[raw_line.replace('\n', '')] = original_sentence
    return list(conj2sent.keys()), conj2sent
|
|
|
|
|
def print_predictions(outputs, file_path, vocab, sequence_label_domain=None):
    """print_predictions prints prediction results

    Writes one block per sentence to `file_path`. A block is a series of
    tab-separated records (Token, Text, Sequence-Label-*, Joint-Label-*,
    Separate-Position-*, Ent-*, Rel-*) followed by a blank line acting as a
    record separator; only records whose keys are present in the
    corresponding `sent_output` dict are emitted.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
        sequence_label_domain (str, optional): sequence label domain. Defaults to None.
    """
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'text' in sent_output:
                print(f"Text\t{sent_output['text']}", file=fout)

            if 'sequence_labels' in sent_output and 'sequence_label_preds' in sent_output:
                sequence_labels = [
                    vocab.get_token_from_index(true_sequence_label, sequence_label_domain)
                    for true_sequence_label in sent_output['sequence_labels'][:seq_len]
                ]
                sequence_label_preds = [
                    vocab.get_token_from_index(pred_sequence_label, sequence_label_domain)
                    for pred_sequence_label in sent_output['sequence_label_preds'][:seq_len]
                ]

                print("Sequence-Label-True\t{}".format(' '.join(sequence_labels)), file=fout)
                print("Sequence-Label-Pred\t{}".format(' '.join(sequence_label_preds)), file=fout)

            if 'joint_label_matrix' in sent_output:
                # One line per row of the (seq_len x seq_len) gold label matrix.
                for row in sent_output['joint_label_matrix'][:seq_len]:
                    print("Joint-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'joint_label_preds' in sent_output:
                for row in sent_output['joint_label_preds'][:seq_len]:
                    print("Joint-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'span2ent')
                    # Fixed message: this guards the *entity* label (the old
                    # text was a copy-paste of the relation assertion below).
                    assert ent != 'None', "true entity can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Span-Pred\t{}".format(span), file=fout)
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'span2rel')
                    assert rel != 'None', "true relation can not be `None`."

                    # A trailing '<' marks an inverted relation: swap the two
                    # argument spans. The 2-char direction suffix is always
                    # stripped before printing (rel[:-2]).
                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    # Same direction handling as the gold relations above.
                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            # Blank line terminates the sentence block.
            print(file=fout)
|
|
|
|
|
def print_extractions_allennlp_format(cfg, outputs, file_path, vocab):
    """Write Open IE extractions to `file_path` in a tab-separated format.

    Each emitted line is ``<sentence>\\t<arg1> S </arg1> <rel> R </rel>
    <arg2> O </arg2>``. Sentences produced by conjunction splitting are
    mapped back to their original sentence via ``read_conjunctions``.

    Args:
        cfg: configuration object providing ``conjunctions_file``.
        outputs (list): prediction outputs, one dict per sentence.
        file_path (str): output file path.
        vocab (Vocabulary): vocabulary used to map token ids back to text.
    """
    conj_sentences, conj2sent = read_conjunctions(cfg)
    # Extractions already emitted, de-duplicated across ALL sentences.
    ext_texts = []
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            # Maps a relation span -> defaultdict(list) with "Subject" /
            # "Object" argument-span lists.
            extractions = {}
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # NOTE(review): seq_len - 6 presumably strips special marker
            # tokens appended to every sequence — confirm against the
            # data pipeline before changing.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len-6]]
            sentence = ' '.join(tokens)
            # NOTE(review): membership test on a list is O(n); conj2sent
            # could be queried directly instead.
            if sentence in conj_sentences:
                sentence = conj2sent[sentence]

            if 'all_rel_preds' in sent_output:
                # Group predicted (span, span) relation pairs by the span
                # whose entity prediction is 'Relation'; the other span is
                # an argument filed under the rel label ("Subject"/"Object").
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    if rel == '' or rel == ' ':
                        continue
                    if sent_output['all_ent_preds'][span1] == 'Relation':
                        # NOTE(review): bare `except` is used as control flow
                        # for "span not seen yet" (KeyError on the outer
                        # dict); it also swallows any other error.
                        try:
                            if span2 in extractions[span1][rel]:
                                continue
                        except:
                            pass
                        try:
                            extractions[span1][rel].append(span2)
                        except:
                            extractions[span1] = defaultdict(list)
                            extractions[span1][rel].append(span2)
                    else:
                        # span2 is the relation span; span1 is the argument.
                        try:
                            if span1 in extractions[span2][rel]:
                                continue
                        except:
                            pass
                        try:
                            extractions[span2][rel].append(span1)
                        except:
                            extractions[span2] = defaultdict(list)
                            extractions[span2][rel].append(span1)
            # Merge relation spans that share identical Subject and Object
            # argument lists into a single concatenated relation span.
            to_remove_rel_spans = set()
            expand_rel = {}
            to_add = {}
            for rel_span1, d1 in extractions.items():
                for rel_span2, d2 in extractions.items():
                    if rel_span1 != rel_span2 and not (rel_span1 in to_remove_rel_spans or rel_span2 in to_remove_rel_spans):
                        if d1["Subject"] == d2["Subject"] and d1["Object"] == d2["Object"]:
                            # NOTE(review): the outer guard above already
                            # excludes spans in to_remove_rel_spans, so the
                            # next two branches appear unreachable and only
                            # the final `else` ever runs — confirm intent
                            # before relying on the chained-merge paths.
                            if rel_span1 in to_remove_rel_spans:
                                to_add[expand_rel[rel_span1] + rel_span2] = d1
                                to_remove_rel_spans.add(rel_span2)
                                to_remove_rel_spans.add(expand_rel[rel_span1])
                                expand_rel[rel_span2] = expand_rel[rel_span1] + rel_span2
                                expand_rel[rel_span1] = expand_rel[rel_span1] + rel_span2
                            elif rel_span2 in to_remove_rel_spans:
                                to_add[expand_rel[rel_span2] + rel_span1] = d1
                                to_remove_rel_spans.add(rel_span1)
                                to_remove_rel_spans.add(expand_rel[rel_span2])
                                expand_rel[rel_span1] = expand_rel[rel_span2] + rel_span1
                                expand_rel[rel_span2] = expand_rel[rel_span2] + rel_span1
                            else:
                                # Tuple concatenation: the merged key keeps
                                # both spans' sub-spans.
                                to_add[rel_span1 + rel_span2] = d1
                                expand_rel[rel_span1] = rel_span1 + rel_span2
                                expand_rel[rel_span2] = rel_span1 + rel_span2
                                to_remove_rel_spans.add(rel_span1)
                                to_remove_rel_spans.add(rel_span2)
            for tm in to_remove_rel_spans:
                del extractions[tm]
            for k, v in to_add.items():
                extractions[k] = v
            # Render each extraction. Argument spans are tuples of
            # (start, end) sub-spans (discontiguous mentions).
            for rel_sp, d in extractions.items():
                if len(d["Subject"]) > 1:
                    # Multiple subject arguments: keep only the first
                    # sub-span of each, ordered by start position.
                    sorted_d_subject = sorted(d["Subject"], key=lambda x: x[0][0])
                    sorted_d_subject = [x[0] for x in sorted_d_subject]
                    subject_text = " ".join([" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in sorted_d_subject])
                elif len(d["Subject"]) == 1:
                    # Single subject argument: join all of its sub-spans.
                    subject_text = " ".join([" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in d["Subject"][0]])
                else:
                    subject_text = ""
                if len(d["Object"]) > 1:
                    sorted_d_object = sorted(d["Object"], key=lambda x: x[0][0])
                    sorted_d_object = [x[0] for x in sorted_d_object]
                    object_text = " ".join([" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in sorted_d_object])
                elif len(d["Object"]) == 1:
                    object_text = " ".join([" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in d["Object"][0]])
                else:
                    object_text = ""
                # '[unused1]' is rendered as the copula 'is' in the relation.
                rel_text = " ".join([" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in rel_sp]).replace('[unused1]', 'is')
                ext = f'<arg1> {subject_text} </arg1> <rel> {rel_text} </rel> <arg2> {object_text} </arg2>'
                # Skip duplicates and extractions missing a relation/subject.
                if ext not in ext_texts and (rel_text != '' and subject_text != ''):
                    print("{}\t{}".format(sentence, ext), file=fout)
                    ext_texts.append(ext)
|
|
|
|
|
def print_predictions_for_joint_decoding(outputs, file_path, vocab):
    """print_predictions_for_joint_decoding prints prediction results

    Writes one blank-line-separated block per sentence containing the joint
    label matrices, separate positions, entity spans and relations present
    in each `sent_output`. Spans here are tuples of (start, end) sub-spans,
    so entity/relation text is joined per sub-span.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
    """
    # Note: the old docstring documented a `sequence_label_domain` parameter
    # that this function never had; it has been removed.
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'joint_label_matrix' in sent_output:
                # One line per row of the (seq_len x seq_len) gold matrix.
                for row in sent_output['joint_label_matrix'][:seq_len]:
                    print("Joint-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'joint_label_preds' in sent_output:
                for row in sent_output['joint_label_preds'][:seq_len]:
                    print("Joint-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    # Fixed message: this guards the *entity* label (the old
                    # text was a copy-paste of the relation assertion below).
                    assert ent != 'None', "true entity can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join([' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."
                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)

            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)

            # Blank line terminates the sentence block.
            print(file=fout)
|
|
|
|
|
def print_predictions_for_entity_rel_decoding(outputs, file_path, vocab):
    """print_predictions_for_entity_rel_decoding prints prediction results

    Writes one blank-line-separated block per sentence with separate entity
    and relation label matrices, plus entity spans and relations, for models
    that decode entities and relations with distinct label tables. Spans are
    tuples of (start, end) sub-spans.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
    """
    # Note: the old docstring documented a `sequence_label_domain` parameter
    # that this function never had; it has been removed.
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'entity_label_preds' in sent_output:
                for row in sent_output['entity_label_preds'][:seq_len]:
                    print("Ent-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'relation_label_matrix' in sent_output:
                # Relation label ids are stored offset by 2 relative to the
                # 'ent_rel_id' vocab, with 0 meaning no relation ("None").
                for row in sent_output['relation_label_matrix'][:seq_len]:
                    print("Rel-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item + 2, 'ent_rel_id') if item != 0 else "None" for item in row[:seq_len]])),
                          file=fout)

            if 'relation_label_preds' in sent_output:
                for row in sent_output['relation_label_preds'][:seq_len]:
                    print("Rel-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item + 2, 'ent_rel_id') if item != 0 else "None" for item in row[:seq_len]])),
                          file=fout)

            if 'separate_positions' in sent_output:
                print("Separate-Position-True\t{}".format(' '.join(map(str, sent_output['separate_positions']))),
                      file=fout)

            if 'all_separate_position_preds' in sent_output:
                print("Separate-Position-Pred\t{}".format(' '.join(map(str,
                                                                       sent_output['all_separate_position_preds']))),
                      file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    # Fixed message: this guards the *entity* label (the old
                    # text was a copy-paste of the relation assertion below).
                    assert ent != 'None', "true entity can not be `None`."

                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            if 'all_ent_preds' in sent_output:
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(
                        [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."

                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)),
                          file=fout)
            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    span1_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span1]
                    span2_text_list = [' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span2]
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2, ' '.join(span1_text_list),
                                                                ' '.join(span2_text_list)), file=fout)

            # Blank line terminates the sentence block.
            print(file=fout)
|
|
|
def print_predictions_for_relation_decoding(outputs, file_path, vocab):
    """Print gold and predicted relation label matrices per sentence.

    Writes one blank-line-separated block per sentence: the Token line
    followed by one Relation-Label-True / Relation-Label-Pred line per
    matrix row, mirroring the format of the other print_predictions_*
    writers in this module.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
    """
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)
            if 'relation_label_matrix' in sent_output:
                for row in sent_output['relation_label_matrix'][:seq_len]:
                    print("Relation-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            if 'relation_label_preds' in sent_output:
                for row in sent_output['relation_label_preds'][:seq_len]:
                    print("Relation-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)

            # Blank line between sentence blocks, consistent with the other
            # writers in this module (previously missing here, which fused
            # consecutive sentences into a single record).
            print(file=fout)
|
|