# Compact_Facts/inputs/dataset_readers/oie4_reader_for_table_decoding.py
# NOTE: the original file began with non-Python residue from a file-hosting
# web UI (uploader name, "Upload 108 files", commit hash); commented out so
# the module is importable.
import json
from collections import defaultdict
import logging
# Module-level logger, following the standard one-logger-per-module convention.
logger = logging.getLogger(__name__)
def split_span(span):
    """Split a sorted sequence of token indices into runs of consecutive indices.

    Arguments:
        span {list} -- sorted token indices, e.g. [1, 2, 3, 5, 6, 8]

    Returns:
        list -- list of consecutive sub-spans, e.g. [[1, 2, 3], [5, 6], [8]]
    """
    if not span:
        # BUGFIX: the original indexed span[0] unconditionally and raised
        # IndexError on an empty input.
        return []
    sub_spans = [[span[0]]]
    # Walk adjacent pairs: extend the current run while indices are
    # consecutive, otherwise start a new run.
    for prev, cur in zip(span, span[1:]):
        if cur == prev + 1:
            sub_spans[-1].append(cur)
        else:
            sub_spans.append([cur])
    return sub_spans
class OIE4ReaderForJointDecoding():
    """Define text data reader and preprocess data for entity relation
    joint decoding on ACE dataset.

    The input file contains one JSON record per line; iterating over the
    reader yields one preprocessed sentence dict per valid record.
    """
    def __init__(self, file_path, is_test=False, max_len=None):
        """This function defines file path and some settings

        Arguments:
            file_path {str} -- file path

        Keyword Arguments:
            is_test {bool} -- indicate training or testing (default: {False})
            max_len {dict} -- max length for some namespace
                (default: {None}, treated as an empty dict)
        """
        self.file_path = file_path
        self.is_test = is_test
        # BUGFIX: the original used a mutable default (`max_len=dict()`);
        # `None` avoids the shared-mutable-default pitfall while keeping the
        # same effective behavior (an empty mapping, copied defensively).
        self.max_len = dict(max_len) if max_len is not None else {}
        # Observed sequence lengths per namespace, for corpus statistics.
        self.seq_lens = defaultdict(list)

    def __iter__(self):
        """Generator function.

        Yields one assembled sentence dict per JSON line. Malformed records
        are skipped; records exceeding the configured max lengths are
        skipped only in training mode (except label namespaces, which are
        filtered in every mode, matching the original behavior).
        """
        with open(self.file_path, 'r') as fin:
            for line in fin:
                line = json.loads(line)
                sentence = {}

                state, results = self.get_tokens(line)
                if not state:
                    # BUGFIX: the original recorded len(results['tokens'])
                    # before checking `state`, so a malformed record raised
                    # KeyError on the empty results dict. Skip it instead.
                    continue
                self.seq_lens['tokens'].append(len(results['tokens']))
                if (not self.is_test and 'tokens' in self.max_len
                        and len(results['tokens']) > self.max_len['tokens']):
                    # Too long for training; test mode keeps the sentence.
                    continue
                sentence.update(results)

                state, results = self.get_wordpiece_tokens(line)
                if not state:
                    # Same malformed-record guard as above (original crashed
                    # with KeyError on 'wordpiece_tokens' here).
                    continue
                self.seq_lens['wordpiece_tokens'].append(len(results['wordpiece_tokens']))
                if (not self.is_test and 'wordpiece_tokens' in self.max_len
                        and len(results['wordpiece_tokens']) > self.max_len['wordpiece_tokens']):
                    continue
                sentence.update(results)

                # Cross-field consistency: one wordpiece index per token,
                # one segment id per wordpiece.
                if len(sentence['tokens']) != len(sentence['wordpiece_tokens_index']):
                    logger.error(
                        "sentence id: {} wordpiece_tokens_index length is not equal to tokens.".format(line['sentId']))
                    continue
                if len(sentence['wordpiece_tokens']) != len(sentence['wordpiece_segment_ids']):
                    logger.error(
                        "sentence id: {} wordpiece_tokens length is not equal to wordpiece_segment_ids.".
                        format(line['sentId']))
                    continue

                state, results = self.get_entity_relation_label(line, len(sentence['tokens']))
                for key, result in results.items():
                    self.seq_lens[key].append(len(result))
                    if key in self.max_len and len(result) > self.max_len[key]:
                        # Over-length labels invalidate the record in every
                        # mode (no is_test exemption, as in the original).
                        state = False
                if not state:
                    continue
                sentence.update(results)
                yield sentence

    def get_tokens(self, line):
        """This function splits text into tokens

        Arguments:
            line {dict} -- text

        Returns:
            bool -- execute state
            dict -- results: text, tokens
        """
        results = {}
        if 'sentText' not in line:
            logger.error("sentence id: {} doesn't contain 'sentText'.".format(
                line['sentId']))
            return False, results
        results['text'] = line['sentText']
        if 'tokens' in line:
            # Pre-tokenized input takes precedence.
            results['tokens'] = line['tokens']
        else:
            # Fall back to whitespace tokenization.
            results['tokens'] = line['sentText'].strip().split(' ')
        return True, results

    def get_wordpiece_tokens(self, line):
        """This function splits wordpiece text into wordpiece tokens

        Arguments:
            line {dict} -- text

        Returns:
            bool -- execute state
            dict -- results: wordpiece_tokens, wordpiece_tokens_index,
                wordpiece_segment_ids
        """
        results = {}
        if 'wordpieceSentText' not in line or 'wordpieceTokensIndex' not in line or 'wordpieceSegmentIds' not in line:
            logger.error(
                "sentence id: {} doesn't contain 'wordpieceSentText' or 'wordpieceTokensIndex' or 'wordpieceSegmentIds'."
                .format(line['sentId']))
            return False, results
        wordpiece_tokens = line['wordpieceSentText'].strip().split(' ')
        results['wordpiece_tokens'] = wordpiece_tokens
        # Only the first wordpiece position of each token is kept.
        results['wordpiece_tokens_index'] = [span[0] for span in line['wordpieceTokensIndex']]
        results['wordpiece_segment_ids'] = list(line['wordpieceSegmentIds'])
        return True, results

    def get_entity_relation_label(self, line, sentence_length):
        """This function constructs mapping relation from span to entity label
        and span pair to relation label, and joint entity relation label matrix.

        Arguments:
            line {dict} -- text
            sentence_length {int} -- sentence length

        Returns:
            bool -- execute state
            dict -- results: separate_positions (sorted split positions),
                span2ent: entity span mapping to entity label,
                span2rel: two entity span mapping to relation label,
                joint_label_matrix: joint entity relation label matrix
        """
        results = {}
        if 'entityMentions' not in line:
            logger.error("sentence id: {} doesn't contain 'entityMentions'.".format(line['sentId']))
            return False, results

        # Occupancy map over token positions, used to detect overlapping
        # entity spans.
        entity_pos = [0] * sentence_length
        idx2ent = {}
        span2ent = {}
        separate_positions = []
        for entity in line['entityMentions']:
            entity_sub_spans = []
            st, ed = entity['span_ids'][0], entity['span_ids'][-1]
            sub_spans = split_span(entity['span_ids'])
            if len(sub_spans) == 1:
                # Continuous span: mark split positions just before and just
                # after the span, unless at a sentence boundary.
                if st > 0:
                    separate_positions.append(st - 1)
                if ed < sentence_length - 1:
                    separate_positions.append(ed)
                entity_sub_spans.append((st, ed + 1))
            else:
                # Non-continuous span: each consecutive run contributes its
                # own boundaries (half-open (start, end) per run).
                for sub in sub_spans:
                    if sub[0] > 0:
                        separate_positions.append(sub[0] - 1)
                    if sub[-1] < sentence_length - 1:
                        separate_positions.append(sub[-1])
                    entity_sub_spans.append((sub[0], sub[-1] + 1))
            idx2ent[entity['emId']] = (tuple(entity_sub_spans), entity['text'])
            span2ent[tuple(entity_sub_spans)] = entity['label']
            for s_i in entity['span_ids']:
                if entity_pos[s_i] != 0:
                    logger.error("sentence id: {} entity span overlap. {}".format(line['sentId'], entity['span_ids']))
                    return False, results
                entity_pos[s_i] = 1
        # Deduplicate and sort split positions (original: list(set(...)) then
        # sorted(...)).
        results['separate_positions'] = sorted(set(separate_positions))
        results['span2ent'] = span2ent

        if 'relationMentions' not in line:
            logger.error("sentence id: {} doesn't contain 'relationMentions'.".format(line['sentId']))
            return False, results
        span2rel = {}
        for relation in line['relationMentions']:
            if relation['arg1']['emId'] not in idx2ent or relation['arg2']['emId'] not in idx2ent:
                logger.error("sentence id: {} entity not exists .".format(line['sentId']))
                continue
            entity1_span, entity1_text = idx2ent[relation['arg1']['emId']]
            entity2_span, entity2_text = idx2ent[relation['arg2']['emId']]
            if entity1_text != relation['arg1']['text'] or entity2_text != relation['arg2']['text']:
                # BUGFIX: the original returned (False, None) here, which made
                # __iter__ crash on results.items(); return the partial dict
                # like every other failure path. Also fixed the "realtiaon"
                # typo in the message.
                logger.error("sentence id: {} entity text doesn't match relation text.".format(line['sentId']))
                return False, results
            span2rel[(entity1_span, entity2_span)] = relation['label']
        results['span2rel'] = span2rel

        if 'jointLabelMatrix' not in line:
            logger.error("sentence id: {} doesn't contain 'jointLabelMatrix'.".format(line['sentId']))
            return False, results
        results['joint_label_matrix'] = line['jointLabelMatrix']
        return True, results

    def get_seq_lens(self):
        """Return the recorded sequence lengths per namespace ({str: [int]})."""
        return self.seq_lens