rita443 committed
Commit a5131f2 · verified · 1 Parent(s): e101cf4

Upload srl_pipeline.py

Files changed (1)
  1. srl_pipeline.py +242 -0
srl_pipeline.py ADDED
@@ -0,0 +1,242 @@
import logging
from typing import Any, Dict, List, Optional, Tuple

import spacy
import torch
from transformers import Pipeline

from decoder import Decoder

logger = logging.getLogger(__name__)


class SrlPipeline(Pipeline):
    """
    A pipeline for Semantic Role Labeling (SRL) built on transformers and spaCy models.

    The pipeline tokenizes input sentences, finds verbs using POS tagging, and
    postprocesses the model outputs with Viterbi decoding to produce
    human-readable results.

    Attributes:
        model: The underlying transformer model (or its name/identifier).
        tokenizer: The tokenizer associated with the model (or its name/identifier).
        framework ``str``: The framework used by the pipeline ("pt" for PyTorch, "tf" for TensorFlow).
        task ``str``: The specific task of the pipeline.
        verb_predictor: The spaCy model used for predicting verbs in the input sentences.

    Usage:
        # Register the SrlPipeline in the pipeline registry
        PIPELINE_REGISTRY.register_pipeline(
            "srl",
            pipeline_class=SrlPipeline,
            pt_model=SRLModel,  # assuming SRLModel is the model class used
            default={"lang": "en"},
            type="text",
        )

        # Load the model and tokenizer
        model = AutoModel.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)

        # Load the SRL pipeline
        srl_pipeline = pipeline(
            "srl",
            model=model,
            tokenizer=tokenizer,
            framework="pt",  # PyTorch
            task="semantic_role_labeling",  # replace with actual task name
            lang="en",  # language specification
        )

        # Example text input
        text = ["The cat jumps over the fence.", "She quickly eats the delicious cake."]

        # Perform semantic role labeling
        results = srl_pipeline(text)
    """

    def __init__(self, model, tokenizer, framework: str, task: str, **kwargs):
        """
        Initializes the Semantic Role Labeling pipeline.

        Parameters:
        - model: The model or its name/identifier.
        - tokenizer: The tokenizer or its name/identifier.
        - framework ``str``: The framework used ("pt" or "tf").
        - task ``str``: The specific task of the pipeline.
        - **kwargs: Additional keyword arguments.
            - lang ``str``, optional: Language specification: "en" for English or "pt" for Portuguese (the default).
        """
        super().__init__(model, tokenizer=tokenizer)
        if "lang" in kwargs and kwargs["lang"] == "en":
            logger.info("Loading English verb predictor model...")
            self.verb_predictor = spacy.load("en_core_web_trf")
        else:
            logger.info("Loading Portuguese verb predictor model...")
            self.verb_predictor = spacy.load("pt_core_news_lg")
        logger.info("Got verb prediction model")

    def _sanitize_parameters(
        self, **kwargs: Any
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """
        Sanitizes and organizes additional parameters.

        Parameters:
        - **kwargs: Additional keyword arguments.

        Returns:
        - ``Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]``: Three dictionaries of
          sanitized parameters for preprocess, _forward, and postprocess, respectively.
        """
        # All configuration is handled in __init__; no extra runtime parameters are used.
        return {}, {}, {}

    def preprocess(self, sentence: str) -> List[Dict[str, Any]]:
        """
        Preprocesses a sentence for semantic role labeling.

        Parameters:
        - sentence ``str``: The input sentence to be processed.

        Returns:
        - ``List[Dict[str, Any]]``: A list of dictionaries containing model inputs,
          one per verb found in the sentence.
        """
        # Extract sentence verbs
        doc = self.verb_predictor(sentence)

        verbs = {token.text for token in doc if token.pos_ == "VERB"}
        # If the sentence only contains auxiliary verbs, consider those as the
        # main verbs
        if not verbs:
            verbs = {token.text for token in doc if token.pos_ == "AUX"}

        # Tokenize sentence
        tokens = self.tokenizer.encode_plus(
            sentence,
            truncation=True,
            return_token_type_ids=False,
            return_offsets_mapping=True,
        )
        tokens_lst = tokens.tokens()
        offsets = tokens["offset_mapping"]
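        # offset_mapping holds the (start, end) character span of each token in
        # the original sentence; special tokens map to an empty span, which is
        # how they are recognized (and labeled 0) in the loop below.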

        input_ids = torch.tensor([tokens["input_ids"]], dtype=torch.long)
        attention_mask = torch.tensor([tokens["attention_mask"]], dtype=torch.long)

        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": [],
            "tokens": tokens_lst,
            "verb": "",
        }

        model_inputs = [
            {**model_input} for _ in verbs
        ]  # Create a new dictionary for each verb

        for i, verb in enumerate(verbs):
            model_inputs[i]["verb"] = verb
            # Build a fresh indicator row for each verb; mutating the list from
            # the template dict would leak state across the shallow-copied
            # inputs above and produce a wrongly shaped tensor.
            token_type_ids: List[int] = []
            curr_word_offsets: Optional[Tuple[int, int]] = None

            for j in range(len(tokens_lst)):
                curr_offsets = offsets[j]
                curr_slice = sentence[curr_offsets[0] : curr_offsets[1]]
                if not curr_slice:
                    # Special token (empty span): never part of the predicate
                    token_type_ids.append(0)
                # Check if new token still belongs to same word
                elif (
                    curr_word_offsets
                    and curr_offsets[0] >= curr_word_offsets[0]
                    and curr_offsets[1] <= curr_word_offsets[1]
                ):
                    # Extend previous token type
                    token_type_ids.append(token_type_ids[-1])
                else:
                    curr_word_offsets = self._find_word(sentence, start=curr_offsets[0])
                    curr_word = sentence[curr_word_offsets[0] : curr_word_offsets[1]]

                    token_type_ids.append(
                        int(curr_word != "" and curr_word == verb)
                    )

            model_inputs[i]["token_type_ids"] = torch.tensor(
                [token_type_ids], dtype=torch.long
            )

        return model_inputs

    def _forward(self, model_inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Internal method to forward model inputs for prediction.

        Parameters:
        - model_inputs ``List[Dict[str, Any]]``: List of dictionaries containing model inputs.

        Returns:
        - ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.
        """
        outputs = []
        for model_input in model_inputs:
            output = self.model(
                input_ids=model_input["input_ids"],
                attention_mask=model_input["attention_mask"],
                token_type_ids=model_input["token_type_ids"],
            )
            output["verb"] = model_input["verb"]
            output["tokens"] = model_input["tokens"]
            outputs.append(output)
        return outputs

    def postprocess(self, model_outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Postprocesses model outputs into a human-readable format.

        Parameters:
        - model_outputs ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.

        Returns:
        - ``List[Dict[str, Any]]``: List of dictionaries containing the processed results.
          Each dictionary maps a verb to its labels and token-label pairs.
          Example format: {verb: (labels, List[(token, label)])}
        """
        result = []
        id2label = {int(k): str(v) for k, v in self.model.config.id2label.items()}
        evaluator = Decoder(id2label)
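        # id2label keys are cast to int above because they are typically stored
        # as strings in the JSON model config; Decoder then performs the
        # Viterbi decoding over the per-token label distributions.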

        for model_output in model_outputs:
            class_probabilities = model_output["class_probabilities"]
            attention_mask = model_output["attention_mask"]
            output_dict = evaluator.make_output_human_readable(
                class_probabilities, attention_mask
            )
            # Always fetch the first list: in a pipeline, every sentence is
            # processed one at a time
            wordpiece_label_ids = output_dict["wordpiece_label_ids"][0]
            labels = [id2label[idx] for idx in wordpiece_label_ids]
            result.append(
                {
                    model_output["verb"]: (
                        labels,
                        list(zip(model_output["tokens"], labels)),
                    )
                }
            )
        return result

    def _find_word(self, s: str, start: int = 0) -> Tuple[int, int]:
        """
        Helper method that finds the boundaries of a word in a string.
        Assumes a non-alphabetic character marks the end of a word.

        Parameters:
        - s ``str``: The input string.
        - start ``int``, optional: Index at which to start looking for the word. Defaults to 0.

        Returns:
        - ``Tuple[int, int]``: Start and end indices of the word.
        """
        for i, char in enumerate(s[start:], start):
            if not char.isalpha():
                return start, i
        return start, len(s)
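
A rough sketch of the output shape for one sentence, assuming a BIO-style label inventory (the actual label strings come entirely from the model's id2label config, and the paired tokens are the tokenizer's wordpieces, including special tokens):

    # results = srl_pipeline(["The cat jumps over the fence."])
    # Hypothetical shape, one dict per detected verb:
    # [
    #     {"jumps": (labels, [(wordpiece, label), ...])}
    # ]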