lmalek awacke1 commited on
Commit
b5b5466
·
0 Parent(s):

Duplicate from DataScienceEngineering/4-Seq2SeqQAT5

Browse files

Co-authored-by: Aaron C Wacker <[email protected]>

Files changed (5) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +51 -0
  4. qasrl_model_pipeline.py +183 -0
  5. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 🙋AutoQA - 💬NLP Seq2Seq Gradio
3
+ emoji: 4-QAI❓
4
+ colorFrom: green
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.0.5
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: DataScienceEngineering/4-Seq2SeqQAT5
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from qasrl_model_pipeline import QASRL_Pipeline
3
+
4
+ models = ["kleinay/qanom-seq2seq-model-baseline",
5
+ "kleinay/qanom-seq2seq-model-joint"]
6
+ pipelines = {model: QASRL_Pipeline(model) for model in models}
7
+
8
+
9
+ description = f"""Using Seq2Seq T5 model which takes a sequence of items and outputs another sequence this model generates Questions and Answers (QA) with focus on Semantic Role Labeling (SRL)"""
10
+ title="Seq2Seq T5 Questions and Answers (QA) with Semantic Role Labeling (SRL)"
11
+ examples = [[models[0], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "fall"],
12
+ [models[1], "In March and April the patient had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions <p> like anaphylaxis and shortness of breath.", True, "reactions"],
13
+ [models[0], "In March and April the patient had two falls. One was related <p> to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "relate"],
14
+ [models[1], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", False, "fall"]]
15
+
16
+ input_sent_box_label = "Insert sentence here. Mark the predicate by adding the token '<p>' before it."
17
+ verb_form_inp_placeholder = "e.g. 'decide' for the nominalization 'decision', 'teach' for 'teacher', etc."
18
+ links = """<p style='text-align: center'>
19
+ <a href='https://www.qasrl.org' target='_blank'>QASRL Website</a> | <a href='https://huggingface.co/kleinay/qanom-seq2seq-model-baseline' target='_blank'>Model Repo at Huggingface Hub</a>
20
+ </p>"""
21
+ def call(model_name, sentence, is_nominal, verb_form):
22
+ predicate_marker="<p>"
23
+ if predicate_marker not in sentence:
24
+ raise ValueError("You must highlight one word of the sentence as a predicate using preceding '<p>'.")
25
+
26
+ if not verb_form:
27
+ if is_nominal:
28
+ raise ValueError("You should provide the verbal form of the nominalization")
29
+
30
+ toks = sentence.split(" ")
31
+ pred_idx = toks.index(predicate_marker)
32
+ predicate = toks(pred_idx+1)
33
+ verb_form=predicate
34
+ pipeline = pipelines[model_name]
35
+ pipe_out = pipeline([sentence],
36
+ predicate_marker=predicate_marker,
37
+ predicate_type="nominal" if is_nominal else "verbal",
38
+ verb_form=verb_form)[0]
39
+ return pipe_out["QAs"], pipe_out["generated_text"]
40
+ iface = gr.Interface(fn=call,
41
+ inputs=[gr.inputs.Radio(choices=models, default=models[0], label="Model"),
42
+ gr.inputs.Textbox(placeholder=input_sent_box_label, label="Sentence", lines=4),
43
+ gr.inputs.Checkbox(default=True, label="Is Nominalization?"),
44
+ gr.inputs.Textbox(placeholder=verb_form_inp_placeholder, label="Verbal form (for nominalizations)", default='')],
45
+ outputs=[gr.outputs.JSON(label="Model Output - QASRL"), gr.outputs.Textbox(label="Raw output sequence")],
46
+ title=title,
47
+ description=description,
48
+ article=links,
49
+ examples=examples )
50
+
51
+ iface.launch()
qasrl_model_pipeline.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import json
3
+ from argparse import Namespace
4
+ from pathlib import Path
5
+ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
6
+
7
+ def get_markers_for_model(is_t5_model: bool) -> Namespace:
8
+ special_tokens_constants = Namespace()
9
+ if is_t5_model:
10
+ # T5 model have 100 special tokens by default
11
+ special_tokens_constants.separator_input_question_predicate = "<extra_id_1>"
12
+ special_tokens_constants.separator_output_answers = "<extra_id_3>"
13
+ special_tokens_constants.separator_output_questions = "<extra_id_5>" # if using only questions
14
+ special_tokens_constants.separator_output_question_answer = "<extra_id_7>"
15
+ special_tokens_constants.separator_output_pairs = "<extra_id_9>"
16
+ special_tokens_constants.predicate_generic_marker = "<extra_id_10>"
17
+ special_tokens_constants.predicate_verb_marker = "<extra_id_11>"
18
+ special_tokens_constants.predicate_nominalization_marker = "<extra_id_12>"
19
+
20
+ else:
21
+ special_tokens_constants.separator_input_question_predicate = "<question_predicate_sep>"
22
+ special_tokens_constants.separator_output_answers = "<answers_sep>"
23
+ special_tokens_constants.separator_output_questions = "<question_sep>" # if using only questions
24
+ special_tokens_constants.separator_output_question_answer = "<question_answer_sep>"
25
+ special_tokens_constants.separator_output_pairs = "<qa_pairs_sep>"
26
+ special_tokens_constants.predicate_generic_marker = "<predicate_marker>"
27
+ special_tokens_constants.predicate_verb_marker = "<verbal_predicate_marker>"
28
+ special_tokens_constants.predicate_nominalization_marker = "<nominalization_predicate_marker>"
29
+ return special_tokens_constants
30
+
31
+ def load_trained_model(name_or_path):
32
+ import huggingface_hub as HFhub
33
+ tokenizer = AutoTokenizer.from_pretrained(name_or_path)
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
35
+ # load preprocessing_kwargs from the model repo on HF hub, or from the local model directory
36
+ kwargs_filename = None
37
+ if name_or_path.startswith("kleinay/"): # and 'preprocessing_kwargs.json' in HFhub.list_repo_files(name_or_path): # the supported version of HFhub doesn't support list_repo_files
38
+ kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
39
+ elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
40
+ kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
41
+
42
+ if kwargs_filename:
43
+ preprocessing_kwargs = json.load(open(kwargs_filename))
44
+ # integrate into model.config (for decoding args, e.g. "num_beams"), and save also as standalone object for preprocessing
45
+ model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
46
+ model.config.update(preprocessing_kwargs)
47
+ return model, tokenizer
48
+
49
+
50
+ class QASRL_Pipeline(Text2TextGenerationPipeline):
51
+ def __init__(self, model_repo: str, **kwargs):
52
+ model, tokenizer = load_trained_model(model_repo)
53
+ super().__init__(model, tokenizer, framework="pt")
54
+ self.is_t5_model = "t5" in model.config.model_type
55
+ self.special_tokens = get_markers_for_model(self.is_t5_model)
56
+ self.data_args = model.config.preprocessing_kwargs
57
+ # backward compatibility - default keyword values implemeted in `run_summarization`, thus not saved in `preprocessing_kwargs`
58
+ if "predicate_marker_type" not in vars(self.data_args):
59
+ self.data_args.predicate_marker_type = "generic"
60
+ if "use_bilateral_predicate_marker" not in vars(self.data_args):
61
+ self.data_args.use_bilateral_predicate_marker = True
62
+ if "append_verb_form" not in vars(self.data_args):
63
+ self.data_args.append_verb_form = True
64
+ self._update_config(**kwargs)
65
+
66
+ def _update_config(self, **kwargs):
67
+ " Update self.model.config with initialization parameters and necessary defaults. "
68
+ # set default values that will always override model.config, but can overriden by __init__ kwargs
69
+ kwargs["max_length"] = kwargs.get("max_length", 80)
70
+ # override model.config with kwargs
71
+ for k,v in kwargs.items():
72
+ self.model.config.__dict__[k] = v
73
+
74
+ def _sanitize_parameters(self, **kwargs):
75
+ preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
76
+ if "predicate_marker" in kwargs:
77
+ preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
78
+ if "predicate_type" in kwargs:
79
+ preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
80
+ if "verb_form" in kwargs:
81
+ preprocess_kwargs["verb_form"] = kwargs["verb_form"]
82
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
83
+
84
+ def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
85
+ # Here, inputs is string or list of strings; apply string postprocessing
86
+ if isinstance(inputs, str):
87
+ processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
88
+ elif hasattr(inputs, "__iter__"):
89
+ processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
90
+ else:
91
+ raise ValueError("inputs must be str or Iterable[str]")
92
+ # Now pass to super.preprocess for tokenization
93
+ return super().preprocess(processed_inputs)
94
+
95
+ def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
96
+ sent_tokens = seq.split(" ")
97
+ assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
98
+ predicate_idx = sent_tokens.index(predicate_marker)
99
+ sent_tokens.remove(predicate_marker)
100
+ sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
101
+ predicate = sent_tokens[predicate_idx]
102
+ sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx+1, len(sent_tokens))])
103
+
104
+ if self.data_args.predicate_marker_type == "generic":
105
+ predicate_marker = self.special_tokens.predicate_generic_marker
106
+ # In case we want special marker for each predicate type: """
107
+ elif self.data_args.predicate_marker_type == "pred_type":
108
+ assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
109
+ assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
110
+ predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker ,
111
+ "nominal": self.special_tokens.predicate_nominalization_marker
112
+ }[predicate_type]
113
+
114
+ if self.data_args.use_bilateral_predicate_marker:
115
+ seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
116
+ else:
117
+ seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
118
+
119
+ # embed also verb_form
120
+ if self.data_args.append_verb_form and verb_form is None:
121
+ raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
122
+ elif self.data_args.append_verb_form:
123
+ seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
124
+ else:
125
+ seq = f"{seq} "
126
+
127
+ # append source prefix (for t5 models)
128
+ prefix = self._get_source_prefix(predicate_type)
129
+
130
+ return prefix + seq
131
+
132
+ def _get_source_prefix(self, predicate_type: Optional[str]):
133
+ if not self.is_t5_model or self.data_args.source_prefix is None:
134
+ return ''
135
+ if not self.data_args.source_prefix.startswith("<"): # Regular prefix - not dependent on input row x
136
+ return self.data_args.source_prefix
137
+ if self.data_args.source_prefix == "<predicate-type>":
138
+ if predicate_type is None:
139
+ raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
140
+ else:
141
+ return f"Generate QAs for {predicate_type} QASRL: "
142
+
143
+ def _forward(self, *args, **kwargs):
144
+ outputs = super()._forward(*args, **kwargs)
145
+ return outputs
146
+
147
+
148
+ def postprocess(self, model_outputs):
149
+ output_seq = self.tokenizer.decode(
150
+ model_outputs["output_ids"].squeeze(),
151
+ skip_special_tokens=False,
152
+ clean_up_tokenization_spaces=False,
153
+ )
154
+ output_seq = output_seq.strip(self.tokenizer.pad_token).strip(self.tokenizer.eos_token).strip()
155
+ qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
156
+ qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
157
+ return {"generated_text": output_seq,
158
+ "QAs": qas}
159
+
160
+ def _postrocess_qa(self, seq: str) -> str:
161
+ # split question and answers
162
+ if self.special_tokens.separator_output_question_answer in seq:
163
+ question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
164
+ else:
165
+ print("invalid format: no separator between question and answer found...")
166
+ return None
167
+ # question, answer = seq, '' # Or: backoff to only question
168
+ # skip "_" slots in questions
169
+ question = ' '.join(t for t in question.split(' ') if t != '_')
170
+ answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
171
+ return {"question": question, "answers": answers}
172
+
173
+
174
+ if __name__ == "__main__":
175
+ pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
176
+ res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
177
+ res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
178
+ "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
179
+ res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
180
+ print(res1)
181
+ print(res2)
182
+ print(res3)
183
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.15.0
2
+ torch