|
import csv |
|
import json |
|
import torch |
|
from transformers import BertTokenizer |
|
|
|
|
|
class CNerTokenizer(BertTokenizer):
    """Character-level BERT tokenizer.

    Every character of the input becomes exactly one token: the character
    itself when it is in the vocabulary, otherwise the unknown token. Token
    positions therefore line up one-to-one with character positions.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        # Some `transformers` releases define `do_lower_case` on BertTokenizer as a
        # read-only property derived from the value passed to super().__init__ above;
        # assigning to it there raises AttributeError. Older releases have no such
        # attribute, so set it when possible to keep `self.do_lower_case` readable.
        try:
            self.do_lower_case = do_lower_case
        except AttributeError:
            pass

    def tokenize(self, text, **kwargs):
        """Split *text* into single-character tokens.

        Args:
            text: input string.
            **kwargs: accepted for compatibility with the base-class
                ``tokenize`` signature; ignored.

        Returns:
            list of str: one token per character of *text*; characters not in
            the vocabulary are replaced by ``self.unk_token``.
        """
        lower = self.do_lower_case
        unk = self.unk_token
        tokens = []
        for c in text:
            # Lower-case character by character rather than the whole string:
            # str.lower() on a full string can be context-sensitive (e.g. Greek
            # final sigma) or change the length, which would break alignment.
            if lower:
                c = c.lower()
            tokens.append(c if c in self.vocab else unk)
        return tokens
|
|
|
|
|
class DataProcessor(object):
    """Base class for data converters for sequence labeling data sets.

    Subclasses implement the ``get_*`` hooks; the ``_read_*`` class methods are
    shared file readers.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Returns:
            list of rows, each row a list of column strings.
        """
        # utf-8-sig transparently drops a leading BOM if one is present.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    @classmethod
    def _read_text(cls, input_file):
        """Reads a file with one ``token [... label]`` line per token.

        Sentences are separated by empty lines or lines starting with
        ``-DOCSTART-``. Each line is split on single spaces; the first field is
        the word and the last field is its label. A line with no space gets the
        label ``"O"``.

        Returns:
            list of ``{"words": [...], "labels": [...]}`` dicts, one per sentence.
        """
        lines = []
        words = []
        labels = []
        with open(input_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                # Remove the line terminator first so it can never leak into a
                # word (previously an unlabelled token kept its trailing "\n").
                line = raw_line.rstrip('\r\n')
                if line.startswith("-DOCSTART-") or not line:
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words = []
                        labels = []
                    continue
                splits = line.split(" ")
                words.append(splits[0])
                labels.append(splits[-1] if len(splits) > 1 else "O")
        if words:
            lines.append({"words": words, "labels": labels})
        return lines

    @classmethod
    def _read_json(cls, input_file):
        """Reads a JSON-lines file and converts entity spans to BIOES tags.

        Each non-blank line is a JSON object with a ``"text"`` string and an
        optional ``"label"`` mapping of
        ``entity_type -> {mention_text: [[start, end], ...]}``, where ``start``
        and ``end`` are inclusive character offsets into ``text``. Blank lines
        are skipped.

        Returns:
            list of ``{"words": [...], "labels": [...]}`` dicts, one per record,
            where ``words`` are the characters of ``text``.

        Raises:
            AssertionError: if a span's characters do not spell its mention.
        """
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for raw_line in f:
                stripped = raw_line.strip()
                if not stripped:
                    continue
                record = json.loads(stripped)
                words = list(record['text'])
                labels = ['O'] * len(words)
                label_entities = record.get('label', None)
                if label_entities is not None:
                    for entity_type, mentions in label_entities.items():
                        for mention, spans in mentions.items():
                            for start_index, end_index in spans:
                                assert ''.join(words[start_index:end_index + 1]) == mention, (
                                    "mention %r does not match text at [%d, %d]"
                                    % (mention, start_index, end_index))
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + entity_type
                                else:
                                    labels[start_index] = 'B-' + entity_type
                                    # Size the inner run from the offsets so the label
                                    # list can never change length (empty for 2-char spans).
                                    labels[start_index + 1:end_index] = (
                                        ['I-' + entity_type] * (end_index - start_index - 1))
                                    labels[end_index] = 'E-' + entity_type
                lines.append({"words": words, "labels": labels})
        return lines
|
|
|
|
|
def get_entity_bios(seq, id2label=None, middle_prefix='I-'):
    """Gets entities from a BIOS-tagged sequence.

    ``S-`` tags yield single-token entities. ``B-`` opens an entity that is
    only emitted once at least one matching ``middle_prefix`` tag extends it.
    Any other tag closes the open entity.

    Args:
        seq (list): sequence of labels, either tag strings or tag ids.
        id2label: mapping/indexable from tag id to tag string; only consulted
            for non-string elements of ``seq``.
        middle_prefix (str): prefix of inside-entity tags (default ``'I-'``).

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (inclusive end).

    Example:
        >>> get_entity_bios(['B-PER', 'I-PER', 'O', 'S-LOC'])
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    last = len(seq) - 1
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # Split on the first '-' only so hyphenated types survive intact.
            chunks.append([tag.split('-', 1)[1], indx, indx])
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # End stays -1 until a matching middle tag extends the chunk.
            chunk = [tag.split('-', 1)[1], indx, -1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
            # Only emit a complete chunk at end of sequence (never one with end -1).
            if indx == last and chunk[2] != -1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
|
|
|
|
|
def get_entity_bio(seq, id2label=None, middle_prefix='I-'):
    """Gets entities from a BIO-tagged sequence.

    A ``B-`` tag opens a single-token entity, which following
    ``middle_prefix`` tags of the same type extend. A middle tag of a different
    type neither extends nor closes it. Any other tag closes the open entity.

    Args:
        seq (list): sequence of labels, either tag strings or tag ids.
        id2label: mapping/indexable from tag id to tag string; only consulted
            for non-string elements of ``seq``.
        middle_prefix (str): prefix of inside-entity tags (default ``'I-'``).

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (inclusive end).

    Example:
        >>> get_entity_bio(['B-PER', 'I-PER', 'O', 'B-LOC'])
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    last = len(seq) - 1
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # Split on the first '-' only so hyphenated types survive intact.
            chunk = [tag.split('-', 1)[1], indx, indx]
            if indx == last:
                chunks.append(chunk)
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
            if indx == last:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
|
|
|
|
|
def get_entity_bioes(seq, id2label=None, middle_prefix='I-'):
    """Gets entities from a BIOES-tagged sequence.

    ``S-`` tags yield single-token entities. ``B-`` opens an entity that
    matching ``middle_prefix``/``E-`` tags extend. A matching ``E-`` tag closes
    it. An entity is emitted only once it has been extended at least once.

    Args:
        seq (list): sequence of labels, either tag strings or tag ids.
        id2label: mapping/indexable from tag id to tag string; only consulted
            for non-string elements of ``seq``.
        middle_prefix (str): prefix of inside-entity tags (default ``'I-'``).

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (inclusive end).

    Example:
        >>> get_entity_bioes(['B-PER', 'I-PER', 'E-PER', 'O', 'S-LOC'])
        [['PER', 0, 2], ['LOC', 4, 4]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    last = len(seq) - 1
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # Split on the first '-' only so hyphenated types survive intact.
            chunks.append([tag.split('-', 1)[1], indx, indx])
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # End stays -1 until a matching middle/end tag extends the chunk.
            chunk = [tag.split('-', 1)[1], indx, -1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
                if tag.startswith("E-"):
                    # A matching E- tag terminates the entity; later tags must
                    # not stretch it.
                    chunks.append(chunk)
                    chunk = [-1, -1, -1]
            # Only emit a complete, still-open chunk at end of sequence.
            if indx == last and chunk[2] != -1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
|
|
|
|
|
def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    '''
    Decode entity chunks from a tag sequence with the decoder for ``markup``.

    :param seq: sequence of tag strings or tag ids.
    :param id2label: lookup from tag id to tag string, used for non-string items.
    :param markup: tagging scheme: 'bio', 'bios' or 'bioes'.
    :param middle_prefix: prefix of inside-entity tags, forwarded to the decoder.
    :return: whatever the selected get_entity_* decoder returns.
    '''
    assert markup in ['bio', 'bios', 'bioes']
    # Anything other than 'bio'/'bios' falls through to the BIOES decoder.
    decoder = {
        'bio': get_entity_bio,
        'bios': get_entity_bios,
    }.get(markup, get_entity_bioes)
    return decoder(seq, id2label, middle_prefix)
|
|
|
|
|
def bert_extract_item(start_logits, end_logits):
    """Decode (label, start, end) spans from start/end pointer logits.

    Only the first element along the leading axis is decoded, and the first
    and last positions are dropped. Presumably these are the [CLS]/[SEP]
    positions; TODO confirm with callers. Positions in the result are
    relative to that trimmed sequence. Logits are assumed to be shaped
    (batch, seq_len, num_labels); TODO confirm.

    For every position whose predicted start label is non-zero, the span ends
    at the first position at or after it whose predicted end label is the
    same. Starts with no matching end are dropped.

    Returns:
        list of tuples ``(label, start, end)`` with inclusive ``end``.
    """
    def _labels(logits):
        # argmax over the last (label) axis, first batch element, trim both ends
        return torch.argmax(logits, -1).cpu().numpy()[0][1:-1]

    starts = _labels(start_logits)
    ends = _labels(end_logits)
    spans = []
    for begin, label in enumerate(starts):
        if label == 0:
            continue
        offset = next(
            (k for k, end_label in enumerate(ends[begin:]) if end_label == label),
            None,
        )
        if offset is not None:
            spans.append((label, begin, begin + offset))
    return spans
|
|