File size: 2,661 Bytes
4fb0bd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import re
from collections import defaultdict
def parse_entity_label(entity_label):
"""This function parses entity label string
Arguments:
entity_label {str} -- entity label string
Returns:
tuple -- (chunk_tag, chunk_type)
"""
res = re.match(r'^([^-]*)-(.*)$', entity_label)
return res.groups() if res else (entity_label, '')
def start_of_chunk(pre_chunk_tag, pre_chunk_type, cur_chunk_tag, cur_chunk_type):
"""This function judges whether the start of chunk
Arguments:
pre_chunk_tag {str} -- previous chunk tag
pre_chunk_type {str} -- previous chunk type
cur_chunk_tag {str} -- current chunk tag
cur_chunk_type {str} -- current chunk type
Returns:
bool -- the chunk is starting or not
"""
# `O` must not be starting chunk
if cur_chunk_tag == 'O':
return False
# type of two consecutive chunks are different, current chunk must be starting chunk
if cur_chunk_type != pre_chunk_type:
return True
# `B` and `U` must be starting chunk
if cur_chunk_tag == 'B' or cur_chunk_tag == 'U':
return True
# before `I` and `E` must be `B` or `I`
if cur_chunk_tag == 'I' or cur_chunk_tag == 'E':
if pre_chunk_tag == 'E' or pre_chunk_tag == 'O' or pre_chunk_tag == 'U':
return True
return False
def get_entity_span(entity_labels):
"""This function gets entity span
Arguments:
entity_labels {list} -- entity labels
Returns:
dict -- entity span index dict
"""
pre_chunk_tag = 'O'
pre_chunk_type = ''
chunk_list = []
span2ent = defaultdict(str)
for idx, entity_label in enumerate(entity_labels):
cur_chunk_tag, cur_chunk_type = parse_entity_label(entity_label)
is_start = start_of_chunk(pre_chunk_tag, pre_chunk_type, cur_chunk_tag, cur_chunk_type)
if is_start:
if chunk_list:
span2ent[(chunk_list[1], chunk_list[-1] + 1)] = chunk_list[0]
chunk_list = [cur_chunk_type, idx]
elif chunk_list and cur_chunk_type == chunk_list[0]:
chunk_list.append(idx)
pre_chunk_tag = cur_chunk_tag
pre_chunk_type = cur_chunk_type
if chunk_list:
span2ent[(chunk_list[1], chunk_list[-1] + 1)] = chunk_list[0]
return span2ent
if __name__ == '__main__':
span2ent = get_entity_span(['B-1', 'I-1', 'U-1', 'E-1', 'I-1', 'O', 'B-1', 'O', 'O', 'U-3', 'E-2', 'O', 'I-1'])
assert span2ent == {(0, 2): '1', (2, 3): '1', (3, 4): '1', (4, 5): '1', (6, 7): '1', (9, 10): '3', (10, 11): '2', (12, 13): '1'}
|