from collections import defaultdict |
def read_conjunctions(cfg):
    """Load the conjunction-sentence -> original-sentence mapping.

    The file at ``cfg.conjunctions_file`` is read as groups separated by empty
    lines. The first line of a group is the original sentence; every following
    line of the group is a conjunction sentence derived from it.

    Args:
        cfg: object exposing a ``conjunctions_file`` path attribute.

    Returns:
        tuple: ``(conj_sentences, conj2sent)`` where ``conj2sent`` maps each
        conjunction sentence to its original sentence and ``conj_sentences``
        is the list of its keys.
    """
    mapping = {}
    original_text = ''
    # True while the next non-empty line starts a new group.
    at_group_start = True
    with open(cfg.conjunctions_file, 'r') as handle:
        for raw_line in handle:
            if raw_line == '\n':
                at_group_start = True
                continue
            text = raw_line.replace('\n', '')
            if at_group_start:
                original_text = text
                at_group_start = False
            else:
                mapping[text] = original_text
    return list(mapping.keys()), mapping
def print_predictions(outputs, file_path, vocab, sequence_label_domain=None):
    """Write gold labels and predictions of every sentence to a text file.

    Each sentence becomes a group of tab-separated lines (``Token``, ``Text``,
    ``Sequence-Label-True``, ``Ent-Pred``, ``Rel-True`` ...) followed by one
    empty line. Only the keys present in a sentence's output dict are written.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary
        sequence_label_domain (str, optional): sequence label domain. Defaults to None.

    Raises:
        AssertionError: if a sentence has no ``tokens`` entry, or a gold
            entity / relation label decodes to ``'None'``.
    """
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # Only the first ``seq_len`` positions are decoded and printed.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'text' in sent_output:
                print(f"Text\t{sent_output['text']}", file=fout)

            # Sequence labels are written only when both gold and predicted are available.
            if 'sequence_labels' in sent_output and 'sequence_label_preds' in sent_output:
                sequence_labels = [
                    vocab.get_token_from_index(true_sequence_label, sequence_label_domain)
                    for true_sequence_label in sent_output['sequence_labels'][:seq_len]
                ]
                sequence_label_preds = [
                    vocab.get_token_from_index(pred_sequence_label, sequence_label_domain)
                    for pred_sequence_label in sent_output['sequence_label_preds'][:seq_len]
                ]
                print("Sequence-Label-True\t{}".format(' '.join(sequence_labels)), file=fout)
                print("Sequence-Label-Pred\t{}".format(' '.join(sequence_label_preds)), file=fout)

            # Label matrices: one output line per row, cropped to seq_len x seq_len.
            for key, tag in (('joint_label_matrix', "Joint-Label-True"),
                             ('joint_label_preds', "Joint-Label-Pred")):
                if key in sent_output:
                    for row in sent_output[key][:seq_len]:
                        row_labels = [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]]
                        print("{}\t{}".format(tag, ' '.join(row_labels)), file=fout)

            for key, tag in (('separate_positions', "Separate-Position-True"),
                             ('all_separate_position_preds', "Separate-Position-Pred")):
                if key in sent_output:
                    print("{}\t{}".format(tag, ' '.join(map(str, sent_output[key]))), file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'span2ent')
                    assert ent != 'None', "true entity can not be `None`."
                    print("Ent-True\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            if 'all_ent_preds' in sent_output:
                # Predicted entity labels are printed as stored (no vocab lookup).
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Span-Pred\t{}".format(span), file=fout)
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, ' '.join(tokens[span[0]:span[1]])), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'span2rel')
                    assert rel != 'None', "true relation can not be `None`."
                    # A label ending in '<' has its two spans swapped before printing;
                    # the last two characters of the label are never printed.
                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    if rel[-1] == '<':
                        span1, span2 = span2, span1
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel[:-2], span1, span2,
                                                                ' '.join(tokens[span1[0]:span1[1]]),
                                                                ' '.join(tokens[span2[0]:span2[1]])),
                          file=fout)

            # Empty line terminates the sentence group.
            print(file=fout)
def _join_sub_spans(tokens, sub_spans):
    """Return the token text of each ``(start, end)`` sub-span, joined by spaces."""
    return " ".join(" ".join(tokens[sub_span[0]:sub_span[1]]) for sub_span in sub_spans)


def _render_arguments(tokens, arg_spans):
    """Return the text for one role's list of argument spans ("" if empty)."""
    if len(arg_spans) > 1:
        # Several arguments: order them by the start of their first sub-span.
        # Only the first sub-span of each argument is rendered in this case.
        ordered = sorted(arg_spans, key=lambda x: x[0][0])
        return _join_sub_spans(tokens, [x[0] for x in ordered])
    if len(arg_spans) == 1:
        # A single argument is rendered with all of its sub-spans.
        return _join_sub_spans(tokens, arg_spans[0])
    return ""


def _group_relation_arguments(sent_output):
    """Group predicted relations as ``{relation_span: {role: [argument_span, ...]}}``."""
    extractions = {}
    if 'all_rel_preds' not in sent_output:
        return extractions
    for (span1, span2), rel in sent_output['all_rel_preds'].items():
        if rel == '' or rel == ' ':
            continue
        # Whichever side is predicted as 'Relation' becomes the grouping key;
        # if span1 is not, span2 is used as the key.
        if sent_output['all_ent_preds'][span1] == 'Relation':
            rel_span, arg_span = span1, span2
        else:
            rel_span, arg_span = span2, span1
        if rel_span not in extractions:
            extractions[rel_span] = defaultdict(list)
        role_args = extractions[rel_span][rel]
        if arg_span not in role_args:
            role_args.append(arg_span)
    return extractions


def _merge_equivalent_relations(extractions):
    """Merge, in place, pairs of relation spans with equal Subject and Object lists."""
    merged = {}
    consumed = set()
    for rel_span1, d1 in extractions.items():
        for rel_span2, d2 in extractions.items():
            # A span that already took part in a merge is never merged again.
            if rel_span1 == rel_span2 or rel_span1 in consumed or rel_span2 in consumed:
                continue
            if d1["Subject"] == d2["Subject"] and d1["Object"] == d2["Object"]:
                # The merged relation is the concatenation of both spans and
                # keeps the arguments of the first one.
                merged[rel_span1 + rel_span2] = d1
                consumed.add(rel_span1)
                consumed.add(rel_span2)
    for rel_span in consumed:
        del extractions[rel_span]
    extractions.update(merged)


def print_extractions_allennlp_format(cfg, outputs, file_path, vocab):
    """Write predicted extractions as ``sentence<TAB>extraction`` lines.

    For each sentence the predicted relations are grouped by relation span.
    Relation spans whose ``"Subject"`` and ``"Object"`` argument lists are
    equal are merged pairwise, and every resulting group is rendered as
    ``<arg1> ... </arg1> <rel> ... </rel> <arg2> ... </arg2>``. A sentence
    found in the conjunction mapping is replaced by its original sentence in
    the output line. Extractions with an empty relation or subject text are
    skipped, and each distinct extraction string is written once per file.

    Args:
        cfg: configuration object, passed to ``read_conjunctions``.
        outputs (list): prediction outputs, one dict per sentence. ``seq_len``
            and ``tokens`` are required; ``all_rel_preds`` (with
            ``all_ent_preds``) is used when present.
        file_path (str): output file path
        vocab (Vocabulary): vocabulary

    Raises:
        AssertionError: if a sentence output has no ``tokens`` entry.
    """
    _, conj2sent = read_conjunctions(cfg)
    # Shared across all sentences: an extraction string is emitted only once.
    seen_extractions = set()
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # NOTE(review): the last 6 positions are dropped; presumably they are
            # special tokens appended to the input - confirm against the data reader.
            tokens = [vocab.get_token_from_index(token, 'tokens')
                      for token in sent_output['tokens'][:seq_len - 6]]
            sentence = ' '.join(tokens)
            # Report extractions of a conjunction sentence under its original sentence.
            sentence = conj2sent.get(sentence, sentence)

            extractions = _group_relation_arguments(sent_output)
            _merge_equivalent_relations(extractions)

            for rel_sp, d in extractions.items():
                subject_text = _render_arguments(tokens, d["Subject"])
                object_text = _render_arguments(tokens, d["Object"])
                rel_text = _join_sub_spans(tokens, rel_sp).replace('[unused1]', 'is')
                ext = f'<arg1> {subject_text} </arg1> <rel> {rel_text} </rel> <arg2> {object_text} </arg2>'
                if rel_text != '' and subject_text != '' and ext not in seen_extractions:
                    print("{}\t{}".format(sentence, ext), file=fout)
                    seen_extractions.add(ext)
def print_predictions_for_joint_decoding(outputs, file_path, vocab):
    """Write joint-decoding gold labels and predictions to a text file.

    Each sentence becomes a group of tab-separated lines followed by one empty
    line. Only the keys present in a sentence's output dict are written.
    Entity and relation spans are iterated as sequences of ``(start, end)``
    sub-spans whose token texts are joined with spaces.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary

    Raises:
        AssertionError: if a sentence has no ``tokens`` entry, or a gold
            entity / relation label decodes to ``'None'``.
    """

    def span_text(tokens, span):
        # Text of every sub-span of ``span``, joined by single spaces.
        return ' '.join([' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])

    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # Only the first ``seq_len`` positions are decoded and printed.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            # Label matrices: one output line per row, cropped to seq_len x seq_len.
            for key, tag in (('joint_label_matrix', "Joint-Label-True"),
                             ('joint_label_preds', "Joint-Label-Pred")):
                if key in sent_output:
                    for row in sent_output[key][:seq_len]:
                        row_labels = [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]]
                        print("{}\t{}".format(tag, ' '.join(row_labels)), file=fout)

            for key, tag in (('separate_positions', "Separate-Position-True"),
                             ('all_separate_position_preds', "Separate-Position-Pred")):
                if key in sent_output:
                    print("{}\t{}".format(tag, ' '.join(map(str, sent_output[key]))), file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    assert ent != 'None', "true entity can not be `None`."
                    print("Ent-True\t{}\t{}\t{}".format(ent, span, span_text(tokens, span)), file=fout)

            if 'all_ent_preds' in sent_output:
                # Predicted labels are printed as stored (no vocab lookup).
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, span_text(tokens, span)), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2,
                                                                span_text(tokens, span1),
                                                                span_text(tokens, span2)),
                          file=fout)

            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2,
                                                                span_text(tokens, span1),
                                                                span_text(tokens, span2)),
                          file=fout)

            # Empty line terminates the sentence group.
            print(file=fout)
def print_predictions_for_entity_rel_decoding(outputs, file_path, vocab):
    """Write entity/relation-decoding gold labels and predictions to a text file.

    Each sentence becomes a group of tab-separated lines followed by one empty
    line. Only the keys present in a sentence's output dict are written.
    Entity and relation spans are iterated as sequences of ``(start, end)``
    sub-spans whose token texts are joined with spaces.

    Args:
        outputs (list): prediction outputs
        file_path (str): output file path
        vocab (Vocabulary): vocabulary

    Raises:
        AssertionError: if a sentence has no ``tokens`` entry, or a gold
            entity / relation label decodes to ``'None'``.
    """

    def span_text(tokens, span):
        # Text of every sub-span of ``span``, joined by single spaces.
        return ' '.join([' '.join(tokens[sub_span[0]:sub_span[1]]) for sub_span in span])

    def relation_label(item):
        # Id 0 is printed as "None"; any other id is looked up shifted by 2.
        # NOTE(review): the shift presumably skips leading non-relation entries
        # of the 'ent_rel_id' namespace - confirm against the vocabulary layout.
        return vocab.get_token_from_index(item + 2, 'ent_rel_id') if item != 0 else "None"

    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # Only the first ``seq_len`` positions are decoded and printed.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)

            if 'entity_label_preds' in sent_output:
                for row in sent_output['entity_label_preds'][:seq_len]:
                    row_labels = [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]]
                    print("Ent-Label-Pred\t{}".format(' '.join(row_labels)), file=fout)

            # Relation matrices: one output line per row, cropped to seq_len x seq_len.
            for key, tag in (('relation_label_matrix', "Rel-Label-True"),
                             ('relation_label_preds', "Rel-Label-Pred")):
                if key in sent_output:
                    for row in sent_output[key][:seq_len]:
                        row_labels = [relation_label(item) for item in row[:seq_len]]
                        print("{}\t{}".format(tag, ' '.join(row_labels)), file=fout)

            for key, tag in (('separate_positions', "Separate-Position-True"),
                             ('all_separate_position_preds', "Separate-Position-Pred")):
                if key in sent_output:
                    print("{}\t{}".format(tag, ' '.join(map(str, sent_output[key]))), file=fout)

            if 'all_ent_span_preds' in sent_output:
                for span in sent_output['all_ent_span_preds']:
                    print("Ent-Span-Pred\t{}".format(span), file=fout)

            if 'span2ent' in sent_output:
                for span, ent in sent_output['span2ent'].items():
                    ent = vocab.get_token_from_index(ent, 'ent_rel_id')
                    assert ent != 'None', "true entity can not be `None`."
                    print("Ent-True\t{}\t{}\t{}".format(ent, span, span_text(tokens, span)), file=fout)

            if 'all_ent_preds' in sent_output:
                # Predicted labels are printed as stored (no vocab lookup).
                for span, ent in sent_output['all_ent_preds'].items():
                    print("Ent-Pred\t{}\t{}\t{}".format(ent, span, span_text(tokens, span)), file=fout)

            if 'span2rel' in sent_output:
                for (span1, span2), rel in sent_output['span2rel'].items():
                    rel = vocab.get_token_from_index(rel, 'ent_rel_id')
                    assert rel != 'None', "true relation can not be `None`."
                    print("Rel-True\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2,
                                                                span_text(tokens, span1),
                                                                span_text(tokens, span2)),
                          file=fout)

            if 'all_rel_preds' in sent_output:
                for (span1, span2), rel in sent_output['all_rel_preds'].items():
                    print("Rel-Pred\t{}\t{}\t{}\t{}\t{}".format(rel, span1, span2,
                                                                span_text(tokens, span1),
                                                                span_text(tokens, span2)),
                          file=fout)

            # Empty line terminates the sentence group.
            print(file=fout)
def print_predictions_for_relation_decoding(outputs, file_path, vocab):
    """Write the tokens and relation label matrices of each sentence to a file.

    Args:
        outputs (list): prediction outputs, one dict per sentence
        file_path (str): output file path (opened in write mode)
        vocab (Vocabulary): vocabulary
    """
    with open(file_path, 'w') as fout:
        for sent_output in outputs:
            seq_len = sent_output['seq_len']
            assert 'tokens' in sent_output
            # Only the first ``seq_len`` token ids are decoded and printed.
            tokens = [vocab.get_token_from_index(token, 'tokens') for token in sent_output['tokens'][:seq_len]]
            print("Token\t{}".format(' '.join(tokens)), file=fout)
            # Gold relation labels: one line per row, cropped to seq_len x seq_len.
            if 'relation_label_matrix' in sent_output:
                for row in sent_output['relation_label_matrix'][:seq_len]:
                    print("Relation-Label-True\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)
            # Predicted relation labels, same layout as the gold matrix.
            # NOTE(review): the code shown here emits no separator line after a
            # sentence - confirm whether that is intended.
            if 'relation_label_preds' in sent_output:
                for row in sent_output['relation_label_preds'][:seq_len]:
                    print("Relation-Label-Pred\t{}".format(' '.join(
                        [vocab.get_token_from_index(item, 'ent_rel_id') for item in row[:seq_len]])),
                          file=fout)