from functools import lru_cache

from typing import cast, Any, Callable, Dict, Iterable, List, Optional
from typing import Sequence, Tuple, Union
from collections import Counter
from copy import deepcopy
from itertools import islice
import numpy as np

import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
from thinc.types import Floats2d, Ints2d

from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
from spacy.pipeline._edit_tree_internals.schemas import validate_edit_tree
from spacy.pipeline.lemmatizer import lemmatizer_score
from spacy.pipeline.trainable_pipe import TrainablePipe
from spacy.errors import Errors
from spacy.language import Language
from spacy.tokens import Doc, Token
from spacy.training import Example, validate_examples, validate_get_examples
from spacy.vocab import Vocab
from spacy import util


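# Cutoff value for ``top_k`` above which ``predict`` switches from repeated
# argmax calls to the argsort-based guardrail implementation below.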
TOP_K_GUARDRAIL = 20


default_model_config = """
[model]
@architectures = "spacy.Tagger.v2"

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]


@Language.factory(
    "trainable_lemmatizer_v2",
    assigns=["token.lemma"],
    requires=[],
    default_config={
        "model": DEFAULT_EDIT_TREE_LEMMATIZER_MODEL,
        "backoff": "orth",
        "min_tree_freq": 3,
        "overwrite": False,
        "top_k": 1,
        "overwrite_labels": True,
        "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
    },
    default_score_weights={"lemma_acc": 1.0},
)
def make_edit_tree_lemmatizer(
    nlp: Language,
    name: str,
    model: Model,
    backoff: Optional[str],
    min_tree_freq: int,
    overwrite: bool,
    top_k: int,
    overwrite_labels: bool,
    scorer: Optional[Callable],
):
    """Construct an EditTreeLemmatizer component."""
    return EditTreeLemmatizer(
        nlp.vocab,
        model,
        name,
        backoff=backoff,
        min_tree_freq=min_tree_freq,
        overwrite=overwrite,
        top_k=top_k,
        overwrite_labels=overwrite_labels,
        scorer=scorer,
    )


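# Example usage (sketch): the factory above registers the component as
# "trainable_lemmatizer_v2", so it can be added to a pipeline like any other
# trainable component. The language and config values below are illustrative
# assumptions, not requirements of this module.
#
#     import spacy
#     nlp = spacy.blank("en")
#     lemmatizer = nlp.add_pipe("trainable_lemmatizer_v2", config={"top_k": 3})
#     # after nlp.initialize()/training on examples with gold lemmas:
#     #     doc = nlp("She was reading")
#     #     lemmas = [token.lemma_ for token in doc]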


def debug(*args):
    # No-op debug helper: accepts any arguments and does nothing.
    pass


class EditTreeLemmatizer(TrainablePipe):
    """
    Lemmatizer that lemmatizes each word using a predicted edit tree.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "trainable_lemmatizer",
        *,
        backoff: Optional[str] = "orth",
        min_tree_freq: int = 3,
        overwrite: bool = False,
        top_k: int = 1,
        overwrite_labels: bool,
        scorer: Optional[Callable] = lemmatizer_score,
    ):
        """
        Construct an edit tree lemmatizer.

        backoff (Optional[str]): backoff to use when the predicted edit trees
            are not applicable. Must be an attribute of Token or None (leave the
            lemma unset).
        min_tree_freq (int): prune trees that are applied less than this
            frequency in the training data.
        overwrite (bool): overwrite existing lemma annotations.
        top_k (int): try to apply at most the k most probable edit trees.
        overwrite_labels (bool): if True, rebuild the label set during
            ``initialize`` from the training data (or from labels passed to
            ``initialize``); if False, leave existing labels untouched.
        scorer (Optional[Callable]): the scoring method. Defaults to
            ``lemmatizer_score``.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.backoff = backoff
        self.min_tree_freq = min_tree_freq
        self.overwrite = overwrite
        self.top_k = top_k
        self.overwrite_labels = overwrite_labels

        self.trees = EditTrees(self.vocab.strings)
        self.tree2label: Dict[int, int] = {}

        self.cfg: Dict[str, Any] = {"labels": []}
        self.scorer = scorer
        self.numpy_ops = NumpyOps()

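    # ``self.cfg["labels"]`` stores edit tree identifiers in prediction order;
    # ``self.tree2label`` maps a tree identifier back to its position in that
    # list (i.e. the class label the model predicts).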
    def get_loss(
        self, examples: Iterable[Example], scores: List[Floats2d]
    ) -> Tuple[float, List[Floats2d]]:
        validate_examples(examples, "EditTreeLemmatizer.get_loss")
        loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)

        truths = []
        for eg in examples:
            eg_truths = []
            for (predicted, gold_lemma, gold_pos, gold_sent_start) in zip(
                eg.predicted,
                eg.get_aligned("LEMMA", as_string=True),
                eg.get_aligned("POS", as_string=True),
                eg.get_aligned_sent_starts(),
            ):
                if gold_lemma is None:
                    # Tokens without a gold lemma are marked as missing (-1)
                    # and ignored by the loss.
                    label = -1
                else:
                    form = self._get_true_cased_form(
                        predicted.text, gold_sent_start, gold_pos
                    )
                    tree_id = self.trees.add(form, gold_lemma)
                    # Trees that are not in the label set (e.g. pruned by
                    # min_tree_freq) fall back to label 0.
                    label = self.tree2label.get(tree_id, 0)
                eg_truths.append(label)

            truths.append(eg_truths)

        d_scores, loss = loss_func(scores, truths)
        if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))

        return float(loss), d_scores

    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
        # Pick the guess strategy based on top_k (see the helpers below).
        if self.top_k == 1:
            scores2guesses = self._scores2guesses_top_k_equals_1
        elif self.top_k <= TOP_K_GUARDRAIL:
            scores2guesses = self._scores2guesses_top_k_greater_1
        else:
            scores2guesses = self._scores2guesses_top_k_guardrail

        n_docs = len(list(docs))
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.
            n_labels = len(self.cfg["labels"])
            guesses: List[Ints2d] = [self.model.ops.alloc2i(0, n_labels) for _ in docs]
            assert len(guesses) == n_docs
            return guesses
        scores = self.model.predict(docs)
        assert len(scores) == n_docs
        guesses = scores2guesses(docs, scores)
        assert len(guesses) == n_docs
        return guesses

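    # The three helpers below differ only in how candidate trees are drawn
    # from the scores: a single argmax per token (top_k == 1), repeated argmax
    # with masking of rejected candidates (top_k <= TOP_K_GUARDRAIL), or one
    # argsort that yields all top-k candidates at once (above the guardrail).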
    def _scores2guesses_top_k_equals_1(self, docs, scores):
        guesses = []
        for doc, doc_scores in zip(docs, scores):
            doc_guesses = doc_scores.argmax(axis=1)
            doc_guesses = self.numpy_ops.asarray(doc_guesses)

            doc_compat_guesses = []
            for i, token in enumerate(doc):
                tree_id = self.cfg["labels"][doc_guesses[i]]
                form: str = self._get_true_cased_form_of_token(token)
                if self.trees.apply(tree_id, form) is not None:
                    doc_compat_guesses.append(tree_id)
                else:
                    doc_compat_guesses.append(-1)
            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def _scores2guesses_top_k_greater_1(self, docs, scores):
        guesses = []
        top_k = min(self.top_k, len(self.labels))
        for doc, doc_scores in zip(docs, scores):
            doc_scores = self.numpy_ops.asarray(doc_scores)
            doc_compat_guesses = []
            for i, token in enumerate(doc):
                for _ in range(top_k):
                    candidate = int(doc_scores[i].argmax())
                    candidate_tree_id = self.cfg["labels"][candidate]
                    form: str = self._get_true_cased_form_of_token(token)
                    if self.trees.apply(candidate_tree_id, form) is not None:
                        doc_compat_guesses.append(candidate_tree_id)
                        break
                    doc_scores[i, candidate] = np.finfo(np.float32).min
                else:
                    doc_compat_guesses.append(-1)
            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def _scores2guesses_top_k_guardrail(self, docs, scores):
        guesses = []
        for doc, doc_scores in zip(docs, scores):
            doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
            doc_guesses = self.numpy_ops.asarray(doc_guesses)

            doc_compat_guesses = []
            for token, candidates in zip(doc, doc_guesses):
                tree_id = -1
                for candidate in candidates:
                    candidate_tree_id = self.cfg["labels"][candidate]
                    form: str = self._get_true_cased_form_of_token(token)
                    if self.trees.apply(candidate_tree_id, form) is not None:
                        tree_id = candidate_tree_id
                        break
                doc_compat_guesses.append(tree_id)

            guesses.append(np.array(doc_compat_guesses))

        return guesses

    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
        for i, doc in enumerate(docs):
            doc_tree_ids = batch_tree_ids[i]
            if hasattr(doc_tree_ids, "get"):
                doc_tree_ids = doc_tree_ids.get()
            for j, tree_id in enumerate(doc_tree_ids):
                if self.overwrite or doc[j].lemma == 0:
                    # If no applicable tree could be found during prediction,
                    # the special identifier -1 is used. Otherwise the tree
                    # is guaranteed to be applicable.
                    if tree_id == -1:
                        if self.backoff is not None:
                            doc[j].lemma = getattr(doc[j], self.backoff)
                    else:
                        form = self._get_true_cased_form_of_token(doc[j])
                        lemma = self.trees.apply(tree_id, form) or form
                        doc[j].lemma_ = lemma

    @property
    def labels(self) -> Tuple[int, ...]:
        """Returns the labels currently added to the component."""
        return tuple(self.cfg["labels"])

    @property
    def hide_labels(self) -> bool:
        return True

    @property
    def label_data(self) -> Dict:
        trees = []
        for tree_id in range(len(self.trees)):
            tree = self.trees[tree_id]
            if "orig" in tree:
                tree["orig"] = self.vocab.strings[tree["orig"]]
            if "subst" in tree:
                tree["subst"] = self.vocab.strings[tree["subst"]]
            trees.append(tree)
        return dict(trees=trees, labels=tuple(self.cfg["labels"]))

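    # Sketch of how ``label_data`` round-trips (pipe names are illustrative):
    # the dict returned above can be passed back into ``initialize`` of
    # another instance via its ``labels`` argument, e.g.
    #
    #     labels = nlp_src.get_pipe("trainable_lemmatizer_v2").label_data
    #     nlp_dst.get_pipe("trainable_lemmatizer_v2").initialize(
    #         get_examples, nlp=nlp_dst, labels=labels
    #     )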
    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        labels: Optional[Dict] = None,
    ):
        validate_get_examples(get_examples, "EditTreeLemmatizer.initialize")

        if self.overwrite_labels:
            if labels is None:
                self._labels_from_data(get_examples)
            else:
                self._add_labels(labels)

        doc_sample = []
        label_sample = []
        for example in islice(get_examples(), 10):
            doc_sample.append(example.x)
            gold_labels: List[List[float]] = []
            for token in example.reference:
                if token.lemma == 0:
                    gold_label = None
                else:
                    gold_label = self._pair2label(token.text, token.lemma_)

                gold_labels.append(
                    [
                        1.0 if label == gold_label else 0.0
                        for label in self.cfg["labels"]
                    ]
                )

            gold_labels = cast(Floats2d, gold_labels)
            label_sample.append(self.model.ops.asarray(gold_labels, dtype="float32"))

        self._require_labels()
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        assert len(label_sample) > 0, Errors.E923.format(name=self.name)

        self.model.initialize(X=doc_sample, Y=label_sample)

    def from_bytes(self, bytes_data, *, exclude=tuple()):
        deserializers = {
            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
            "model": lambda b: self.model.from_bytes(b),
            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
            "trees": lambda b: self.trees.from_bytes(b),
        }

        util.from_bytes(bytes_data, deserializers, exclude)

        return self

    def to_bytes(self, *, exclude=tuple()):
        serializers = {
            "cfg": lambda: srsly.json_dumps(self.cfg),
            "model": lambda: self.model.to_bytes(),
            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
            "trees": lambda: self.trees.to_bytes(),
        }

        return util.to_bytes(serializers, exclude)

    def to_disk(self, path, exclude=tuple()):
        path = util.ensure_path(path)
        serializers = {
            "cfg": lambda p: srsly.write_json(p, self.cfg),
            "model": lambda p: self.model.to_disk(p),
            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
            "trees": lambda p: self.trees.to_disk(p),
        }
        util.to_disk(path, serializers, exclude)

    def from_disk(self, path, exclude=tuple()):
        def load_model(p):
            try:
                with open(p, "rb") as mfile:
                    self.model.from_bytes(mfile.read())
            except AttributeError:
                raise ValueError(Errors.E149) from None

        deserializers = {
            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
            "model": load_model,
            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
            "trees": lambda p: self.trees.from_disk(p),
        }

        util.from_disk(path, deserializers, exclude)
        return self

    def _add_labels(self, labels: Dict):
        if "labels" not in labels:
            raise ValueError(Errors.E857.format(name="labels"))
        if "trees" not in labels:
            raise ValueError(Errors.E857.format(name="trees"))

        self.cfg["labels"] = list(labels["labels"])
        trees = []
        for tree in labels["trees"]:
            errors = validate_edit_tree(tree)
            if errors:
                raise ValueError(Errors.E1026.format(errors="\n".join(errors)))

            tree = dict(tree)
            if "orig" in tree:
                tree["orig"] = self.vocab.strings[tree["orig"]]
            if "subst" in tree:
                tree["subst"] = self.vocab.strings[tree["subst"]]

            trees.append(tree)

        self.trees.from_json(trees)

        for label, tree in enumerate(self.labels):
            self.tree2label[tree] = label

    def _labels_from_data(self, get_examples: Callable[[], Iterable[Example]]):
        # Count corpus tree frequencies in ad-hoc storage to avoid cluttering
        # the final pipe/string store.
        vocab = Vocab()
        trees = EditTrees(vocab.strings)
        tree_freqs: Counter = Counter()
        repr_pairs: Dict = {}
        for example in get_examples():
            for token in example.reference:
                if token.lemma != 0:
                    form = self._get_true_cased_form_of_token(token)
                    tree_id = trees.add(form, token.lemma_)
                    tree_freqs[tree_id] += 1
                    repr_pairs[tree_id] = (form, token.lemma_)

        # Construct trees that make the frequency cut-off using representative
        # form/lemma pairs.
        for tree_id, freq in tree_freqs.items():
            if freq >= self.min_tree_freq:
                form, lemma = repr_pairs[tree_id]
                self._pair2label(form, lemma, add_label=True)

    @lru_cache()
    def _get_true_cased_form(self, token: str, is_sent_start: bool, pos: str) -> str:
        # Lower-case sentence-initial tokens unless they are tagged as proper
        # nouns.
        if is_sent_start and pos != "PROPN":
            return token.lower()
        else:
            return token

    def _get_true_cased_form_of_token(self, token: Token) -> str:
        return self._get_true_cased_form(token.text, token.is_sent_start, token.pos_)

    def _pair2label(self, form, lemma, add_label=False):
        """
        Look up the edit tree identifier for a form/lemma pair. If the edit
        tree is unknown and "add_label" is set, the edit tree will be added to
        the labels.
        """
        tree_id = self.trees.add(form, lemma)
        if tree_id not in self.tree2label:
            if not add_label:
                return None

            self.tree2label[tree_id] = len(self.cfg["labels"])
            self.cfg["labels"].append(tree_id)
        return self.tree2label[tree_id]
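

# Minimal sketch of the edit-tree round trip underlying ``_pair2label`` (for
# illustration only; it assumes a throwaway Vocab, as in ``_labels_from_data``):
#
#     trees = EditTrees(Vocab().strings)
#     tree_id = trees.add("running", "run")  # learn an edit tree from a pair
#     assert trees.apply(tree_id, "running") == "run"
#     # ``apply`` returns None when the tree does not fit a form.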