|
from typing import Optional |
|
import json |
|
from argparse import Namespace |
|
from pathlib import Path |
|
from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
def get_markers_for_model(is_t5_model: bool) -> Namespace:
    """Return the special-token strings used to format model inputs and parse outputs.

    T5 models use ``<extra_id_N>`` tokens (presumably the tokenizer's existing
    sentinel tokens, so no vocabulary extension is needed); all other models use
    dedicated, descriptive tokens.

    Args:
        is_t5_model: whether the underlying model is a T5-family model
            (only its truthiness is used).

    Returns:
        A ``Namespace`` with one attribute per separator/marker role, e.g.
        ``separator_output_pairs`` or ``predicate_verb_marker``.
    """
    # role -> (token for T5 models, token for other models)
    marker_table = {
        "separator_input_question_predicate": ("<extra_id_1>", "<question_predicate_sep>"),
        "separator_output_answers": ("<extra_id_3>", "<answers_sep>"),
        "separator_output_questions": ("<extra_id_5>", "<question_sep>"),
        "separator_output_question_answer": ("<extra_id_7>", "<question_answer_sep>"),
        "separator_output_pairs": ("<extra_id_9>", "<qa_pairs_sep>"),
        "predicate_generic_marker": ("<extra_id_10>", "<predicate_marker>"),
        "predicate_verb_marker": ("<extra_id_11>", "<verbal_predicate_marker>"),
        "predicate_nominalization_marker": ("<extra_id_12>", "<nominalization_predicate_marker>"),
    }
    column = 0 if is_t5_model else 1
    return Namespace(**{role: tokens[column] for role, tokens in marker_table.items()})
|
|
|
def load_trained_model(name_or_path):
    """Load a seq2seq model, its tokenizer, and (if available) its preprocessing settings.

    Preprocessing settings are looked up in two places:
    * for ``"kleinay/..."`` repo ids, ``preprocessing_kwargs.json`` is downloaded
      from the Hugging Face Hub repository;
    * for a local directory, ``experiment_kwargs.json`` inside it is used if present.

    When a settings file is found, its JSON content is attached to the model config
    twice: as ``model.config.preprocessing_kwargs`` (a ``Namespace``, used for input
    formatting) and merged into the config itself via ``model.config.update``.
    When no file is found, the config is left untouched.

    Args:
        name_or_path: Hub repo id or local model directory.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Imported lazily so the module can be imported without touching the hub client.
    import huggingface_hub as HFhub

    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)

    kwargs_filename = None
    if name_or_path.startswith("kleinay/"):
        kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
    elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
        kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"

    if kwargs_filename:
        # Close the file deterministically and decode explicitly as UTF-8
        # instead of relying on the platform default encoding.
        with open(kwargs_filename, encoding="utf-8") as kwargs_file:
            preprocessing_kwargs = json.load(kwargs_file)
        # Keep a Namespace copy for preprocessing, and merge the same keys into the
        # config so config-level settings (e.g. decoding args) see them too.
        model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
        model.config.update(preprocessing_kwargs)

    return model, tokenizer
|
|
|
|
|
class QASRL_Pipeline(Text2TextGenerationPipeline):
    """Seq2seq pipeline that generates QA-SRL question-answer pairs for one marked predicate.

    Each input sentence is split on single spaces and must contain a
    predicate-marker token (``"<predicate>"`` by default) directly before the
    target predicate word, e.g.
    ``"The student was interested in Luke 's <predicate> research ."``.
    The sentence is rewritten with the model's own special markers (see
    ``get_markers_for_model``) according to the model's preprocessing settings.
    The generated sequence is parsed into a list of
    ``{"question": str, "answers": [str, ...]}`` dicts, with ``None`` for a pair
    lacking a question/answer separator.

    The call-time keyword arguments ``predicate_marker``, ``predicate_type`` and
    ``verb_form`` control input formatting. Any other call-time keyword arguments
    (e.g. ``num_beams``) go through the parent pipeline's parameter handling and
    are forwarded to generation.
    """

    # Fallbacks for preprocessing settings missing from the model's preprocessing kwargs.
    _DATA_ARG_DEFAULTS = (
        ("predicate_marker_type", "generic"),
        ("use_bilateral_predicate_marker", True),
        ("append_verb_form", True),
        ("source_prefix", None),
    )

    # Call-time kwargs consumed by ``preprocess`` rather than by generation.
    _PREPROCESS_KWARG_NAMES = ("predicate_marker", "predicate_type", "verb_form")

    def __init__(self, model_repo: str, **kwargs):
        """Load the model/tokenizer from ``model_repo`` and prepare preprocessing settings.

        Args:
            model_repo: Hub repo id or local directory, passed to ``load_trained_model``.
            **kwargs: attributes written onto ``self.model.config`` by
                ``_update_config`` (``max_length`` defaults to 80).
        """
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt")
        self.is_t5_model = "t5" in model.config.model_type
        self.special_tokens = get_markers_for_model(self.is_t5_model)
        self.data_args = self._resolve_data_args(model.config)
        self._update_config(**kwargs)

    @classmethod
    def _resolve_data_args(cls, config) -> Namespace:
        """Return the model's preprocessing settings as a Namespace, with defaults filled in.

        ``config.preprocessing_kwargs`` may be absent (e.g. when no kwargs file was
        found for the model). In that case an empty Namespace is used instead of
        raising ``AttributeError``. A plain dict is converted to a Namespace.
        Defaults are written onto the returned Namespace in place.
        """
        data_args = getattr(config, "preprocessing_kwargs", None)
        if data_args is None:
            data_args = Namespace()
        elif isinstance(data_args, dict):
            data_args = Namespace(**data_args)
        for name, default in cls._DATA_ARG_DEFAULTS:
            if name not in vars(data_args):
                setattr(data_args, name, default)
        return data_args

    def _update_config(self, **kwargs):
        """Update self.model.config with initialization parameters and necessary defaults.

        ``max_length`` defaults to 80 when not given.
        """
        kwargs.setdefault("max_length", 80)
        for key, value in kwargs.items():
            # setattr (rather than writing __dict__) honours any attribute aliasing
            # the config class implements in __setattr__.
            setattr(self.model.config, key, value)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) parameter dicts.

        ``predicate_marker``, ``predicate_type`` and ``verb_form`` go to
        ``preprocess``. Everything else is handed to the parent implementation so
        generation arguments such as ``num_beams`` reach the forward pass instead of
        being silently dropped. The parent's preprocess/postprocess params are
        discarded because this class's ``preprocess``/``postprocess`` do not accept them.
        """
        preprocess_kwargs = {
            name: kwargs.pop(name)
            for name in self._PREPROCESS_KWARG_NAMES
            if name in kwargs
        }
        _, forward_kwargs, _ = super()._sanitize_parameters(**kwargs)
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
        """Format one sentence (or an iterable of sentences), then tokenize via the parent.

        Raises:
            ValueError: if ``inputs`` is neither a str nor an iterable, or if a
                sentence cannot be formatted (see ``_preprocess_string``).
        """
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
        """Rewrite one marked sentence into the model's input format.

        The user's ``predicate_marker`` token is removed. The model's own marker is
        placed before the predicate (and after it as well, when
        ``use_bilateral_predicate_marker`` is set). The verb form is appended
        after the question/predicate separator when ``append_verb_form`` is set.
        The source prefix, if any, is prepended.

        Raises:
            ValueError: if the marker is missing or not followed by a word, if a
                required ``predicate_type`` / ``verb_form`` is missing, or if
                ``predicate_type`` is not 'verbal'/'nominal'.
        """
        sent_tokens = seq.split(" ")
        if predicate_marker not in sent_tokens:
            raise ValueError(f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word")
        predicate_idx = sent_tokens.index(predicate_marker)
        del sent_tokens[predicate_idx]
        if predicate_idx >= len(sent_tokens):
            # The marker was the last token, so there is no predicate word after it.
            raise ValueError(f"The predicate-marker token ('{predicate_marker}') must be followed by the target predicate word")
        sentence_before_predicate = " ".join(sent_tokens[:predicate_idx])
        predicate = sent_tokens[predicate_idx]
        sentence_after_predicate = " ".join(sent_tokens[predicate_idx + 1:])

        marker_type = self.data_args.predicate_marker_type
        if marker_type == "generic":
            predicate_marker = self.special_tokens.predicate_generic_marker
        elif marker_type == "pred_type":
            if predicate_type is None:
                raise ValueError("For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it")
            if predicate_type not in ("verbal", "nominal"):
                raise ValueError(f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'")
            predicate_marker = {
                "verbal": self.special_tokens.predicate_verb_marker,
                "nominal": self.special_tokens.predicate_nominalization_marker,
            }[predicate_type]
        # NOTE(review): for any other marker type, the user's own marker string is
        # kept as-is (unchanged from the original behavior) — confirm intended.

        if self.data_args.use_bilateral_predicate_marker:
            seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
        else:
            seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"

        if self.data_args.append_verb_form:
            if verb_form is None:
                raise ValueError("For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
            seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
        else:
            seq = f"{seq} "

        return self._get_source_prefix(predicate_type) + seq

    def _get_source_prefix(self, predicate_type: Optional[str]) -> str:
        """Return the text prefix prepended to the model input (always '' for non-T5 models).

        A ``source_prefix`` not starting with '<' is used literally. The placeholder
        ``'<predicate-type>'`` expands to a per-predicate-type instruction.

        Raises:
            ValueError: if the placeholder needs a ``predicate_type`` that was not
                given, or if ``source_prefix`` is an unknown '<...>' placeholder
                (previously this returned None and crashed later with a TypeError).
        """
        source_prefix = getattr(self.data_args, "source_prefix", None)
        if not self.is_t5_model or source_prefix is None:
            return ''
        if not source_prefix.startswith("<"):
            return source_prefix
        if source_prefix == "<predicate-type>":
            if predicate_type is None:
                raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
            return f"Generate QAs for {predicate_type} QASRL: "
        raise ValueError(f"Unsupported source_prefix placeholder: '{source_prefix}'")

    def _forward(self, *args, **kwargs):
        # Pass-through to the parent implementation; kept as an explicit override point.
        return super()._forward(*args, **kwargs)

    @staticmethod
    def _strip_special_tokens(text: str, tokens) -> str:
        """Remove whole occurrences of ``tokens`` and surrounding whitespace from both ends of ``text``.

        Unlike ``str.strip(token)``, which treats its argument as a *set of
        characters*, this only removes complete token strings, so content such
        as the trailing "s" of "animals" is never eaten. ``None``/empty tokens are ignored.
        """
        tokens = [tok for tok in tokens if tok]
        text = text.strip()
        changed = True
        while changed:  # repeat to remove runs such as "</s><pad><pad>"
            changed = False
            for tok in tokens:
                if text.startswith(tok):
                    text = text[len(tok):].strip()
                    changed = True
                if text.endswith(tok):
                    text = text[:-len(tok)].strip()
                    changed = True
        return text

    def postprocess(self, model_outputs):
        """Decode generated ids and parse them into question/answer pairs.

        Returns:
            ``{"generated_text": str, "QAs": list}``, where each ``QAs`` item is
            the result of ``_postrocess_qa`` (a dict, or None for a malformed pair).
        """
        output_seq = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            # Special tokens are kept, presumably because the separators are
            # registered special tokens that must survive decoding for parsing.
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        output_seq = self._strip_special_tokens(
            output_seq, (self.tokenizer.pad_token, self.tokenizer.eos_token)
        )
        qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
        qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
        return {"generated_text": output_seq,
                "QAs": qas}

    def _postrocess_qa(self, seq: str) -> Optional[dict]:
        """Parse one generated QA pair into ``{"question": str, "answers": [str, ...]}``.

        Returns None (and prints a notice) when ``seq`` has no question/answer
        separator. Only the first two separator-delimited fields are used.
        """
        qa_separator = self.special_tokens.separator_output_question_answer
        if qa_separator not in seq:
            print("invalid format: no separator between question and answer found...")
            return None
        question, answer = seq.split(qa_separator)[:2]
        # Drop "_" placeholder tokens (presumably empty QA-SRL question slots) and
        # normalize whitespace, mirroring the stripping applied to the answers.
        question = ' '.join(t for t in question.split() if t != '_')
        answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
        return {"question": question, "answers": answers}
|
|
|
|
|
if __name__ == "__main__":
    # Demo: load a QANom checkpoint and generate QAs for a few marked predicates.
    pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
    demo_calls = [
        (
            "The student was interested in Luke 's <predicate> research about sea animals .",
            dict(verb_form="research", predicate_type="nominal"),
        ),
        (
            ["The doctor was interested in Luke 's <predicate> treatment .",
             "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."],
            dict(verb_form="treat", predicate_type="nominal", num_beams=10),
        ),
        (
            "A number of professions have <predicate> developed that specialize in the treatment of mental disorders .",
            dict(verb_form="develop", predicate_type="verbal"),
        ),
    ]
    # Run every call first, then print, matching the original ordering of side effects.
    results = [pipe(demo_inputs, **call_kwargs) for demo_inputs, call_kwargs in demo_calls]
    for result in results:
        print(result)
|
|