# Martino Mensio
# Updated to spaCy v3:
# d6504ae unverified
from itertools import groupby
import numpy as np
class EntityClassifier:
    """Pick the single best candidate entity from a collection of candidates.

    Candidates are narrowed in three stages (see ``__call__``):

    1. keep those whose span is longest, as measured by
       ``len(entity.get_span())``;
    2. among those, keep the ones whose span text differs from
       ``entity.get_original_alias()`` in the fewest character positions;
    3. return the one with the highest ``entity.get_prior()``.

    At every stage ties are broken by the order of the input, so the
    earliest candidate wins.
    """

    @staticmethod
    def _span_length(entity):
        """Return ``len(entity.get_span())``, the key used for length ranking."""
        return len(entity.get_span())

    def _get_grouped_by_length(self, entities):
        """Group candidates by span length.

        Returns a dict mapping each span length to the list of candidates
        with that length. Keys are inserted in descending length order, and
        each list keeps the candidates' input order (``sorted`` is stable,
        including with ``reverse=True``).
        """
        longest_first = sorted(entities, key=self._span_length, reverse=True)
        return {
            length: list(group)
            for length, group in groupby(longest_first, key=self._span_length)
        }

    def _filter_max_length(self, entities):
        """Return the candidates whose span is the longest, in input order.

        ``entities`` may be any iterable; it is materialised into a list
        because later stages index into and re-iterate the result.

        Raises:
            ValueError: if ``entities`` is empty.
        """
        candidates = list(entities)
        if not candidates:
            raise ValueError("cannot classify an empty collection of candidate entities")
        lengths = [self._span_length(entity) for entity in candidates]
        max_length = max(lengths)
        return [
            entity
            for entity, length in zip(candidates, lengths)
            if length == max_length
        ]

    def _select_max_prior(self, entities):
        """Return the candidate with the highest prior; the first one wins ties."""
        priors = [entity.get_prior() for entity in entities]
        # np.argmax returns the index of the first occurrence of the maximum.
        return entities[np.argmax(priors)]

    def _get_casing_difference(self, word1, original):
        """Count the character positions at which ``word1`` and ``original`` differ.

        Positions present in only one of the two strings (when their lengths
        differ) also count as differences. Without that, ``zip`` would stop at
        the shorter string and score a truncated or extended alias as a
        perfect match.
        """
        mismatches = sum(1 for a, b in zip(word1, original) if a != b)
        return mismatches + abs(len(word1) - len(original))

    def _filter_most_similar(self, entities):
        """Return the candidates whose span text is closest to their original alias.

        Closeness is measured by ``_get_casing_difference``. Input order is
        preserved among the returned candidates.

        Raises:
            ValueError: if ``entities`` is empty (from ``min``).
        """
        differences = [
            self._get_casing_difference(entity.get_span().text, entity.get_original_alias())
            for entity in entities
        ]
        best = min(differences)
        return [
            entity
            for entity, difference in zip(entities, differences)
            if difference == best
        ]

    def __call__(self, entities):
        """Return the best candidate from ``entities``.

        Raises:
            ValueError: if ``entities`` is empty.
        """
        longest = self._filter_max_length(entities)
        closest = self._filter_most_similar(longest)
        return self._select_max_prior(closest)