# srl_pipeline.py (liaad/srl-en_roberta-large_hf)
import logging
from typing import Any, Dict, List, Optional, Tuple
import spacy
import torch
from transformers import Pipeline
from decoder import Decoder
logger = logging.getLogger(__name__)
class SrlPipeline(Pipeline):
"""
A pipeline for Semantic Role Labeling (SRL) using transformers and spaCy models.
This pipeline tokenizes input sentences, finds verbs using POS tagging, and postprocesses
the model outputs using Viterbi decoding to provide human-readable results.
Attributes:
model ``str``: The name or identifier of the underlying transformer model.
tokenizer ``str``: The name or identifier of the tokenizer associated with the model.
        framework ``str``: The framework used for the pipeline ("pt" for PyTorch or "tf" for TensorFlow).
task ``str``: The specific task of the pipeline.
        verb_predictor: The spaCy model used to predict verbs in the input sentences.
Usage:
# Register the SrlPipeline in the pipeline registry
PIPELINE_REGISTRY.register_pipeline(
"srl",
pipeline_class=SrlPipeline,
        pt_model=SRLModel,  # Assuming SRLModel is the model class used
default={"lang": "en"},
type="text",
)
# Load the model and tokenizer
model = AutoModel.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
# Load the SRL pipeline
srl_pipeline = pipeline(
"srl",
model=model,
tokenizer=tokenizer,
        framework="pt",  # "pt" for PyTorch, "tf" for TensorFlow
        lang="en",  # Language specification ("en" or "pt")
)
# Example text input
text = ["The cat jumps over the fence.", "She quickly eats the delicious cake."]
# Perform semantic role labeling
results = srl_pipeline(text)
"""
def __init__(self, model: str, tokenizer: str, framework: str, task: str, **kwargs):
"""
Initializes the Semantic Role Labeling pipeline.
Parameters:
- model ``str``: The model name or identifier.
- tokenizer ``str``: The tokenizer name or identifier.
- framework ``str``: The framework used.
- task ``str``: The specific task of the pipeline.
- **kwargs: Additional keyword arguments.
            - lang ``str``, optional: Language specification ("en" for English, "pt" for Portuguese; Portuguese is the default).
"""
super().__init__(model, tokenizer=tokenizer)
if "lang" in kwargs and kwargs["lang"] == "en":
logger.info("Loading English verb predictor model...")
self.verb_predictor = spacy.load("en_core_web_trf")
else:
logger.info("Loading Portuguese verb predictor model...")
self.verb_predictor = spacy.load("pt_core_news_lg")
logger.info("Got verb prediction model\n")
def _sanitize_parameters(
self, **kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""
Sanitizes and organizes additional parameters.
Parameters:
- **kwargs: Additional keyword arguments.
Returns:
- ``Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]``: Three dictionaries of sanitized parameters for preprocess, _forward, and postprocess.
"""
return {}, {}, {}
def preprocess(self, sentence: str) -> List[Dict[str, Any]]:
"""
Preprocesses a sentence for semantic role labeling.
Parameters:
- sentence ``str``: The input sentence to be processed.
Returns:
- ``List[Dict[str, Any]]``: A list of dictionaries containing model inputs for each verb in the sentence.
"""
# Extract sentence verbs
doc = self.verb_predictor(sentence)
verbs = {token.text for token in doc if token.pos_ == "VERB"}
# If the sentence only contains auxiliary verbs, consider those as the
# main verbs
if not verbs:
verbs = {token.text for token in doc if token.pos_ == "AUX"}
# Tokenize sentence
tokens = self.tokenizer.encode_plus(
sentence,
truncation=True,
return_token_type_ids=False,
return_offsets_mapping=True,
)
tokens_lst = tokens.tokens()
offsets = tokens["offset_mapping"]
input_ids = torch.tensor([tokens["input_ids"]], dtype=torch.long)
attention_mask = torch.tensor([tokens["attention_mask"]], dtype=torch.long)
model_input = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": [],
"tokens": tokens_lst,
"verb": "",
}
        model_inputs = [
            {**model_input, "token_type_ids": []} for _ in verbs
        ]  # Create a new dictionary, with its own token_type_ids list, for each verb
        for i, verb in enumerate(verbs):
            model_inputs[i]["verb"] = verb
            token_type_ids = model_inputs[i]["token_type_ids"]
            token_type_ids.append([])
            curr_word_offsets: Optional[Tuple[int, int]] = None
for j in range(len(tokens_lst)):
curr_offsets = offsets[j]
curr_slice = sentence[curr_offsets[0] : curr_offsets[1]]
if not curr_slice:
token_type_ids[-1].append(0)
# Check if new token still belongs to same word
elif (
curr_word_offsets
and curr_offsets[0] >= curr_word_offsets[0]
and curr_offsets[1] <= curr_word_offsets[1]
):
# Extend previous token type
token_type_ids[-1].append(token_type_ids[-1][-1])
else:
curr_word_offsets = self._find_word(sentence, start=curr_offsets[0])
curr_word = sentence[curr_word_offsets[0] : curr_word_offsets[1]]
token_type_ids[-1].append(
int(curr_word != "" and curr_word == verb)
)
model_inputs[i]["token_type_ids"] = torch.tensor(
token_type_ids, dtype=torch.long
)
return model_inputs
def _forward(self, model_inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Internal method to forward model inputs for prediction.
Parameters:
- model_inputs ``List[Dict[str, Any]]``: List of dictionaries containing model inputs.
Returns:
- ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.
"""
outputs = []
for model_input in model_inputs:
output = self.model(
input_ids=model_input["input_ids"],
attention_mask=model_input["attention_mask"],
token_type_ids=model_input["token_type_ids"],
)
output["verb"] = model_input["verb"]
output["tokens"] = model_input["tokens"]
outputs.append(output)
return outputs
def postprocess(self, model_outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Postprocesses model outputs to human-readable format.
Parameters:
- model_outputs ``List[Dict[str, Any]]``: List of dictionaries containing model outputs.
Returns:
- ``List[Dict[str, Any]]``: List of dictionaries containing processed results.
Each dictionary entry represents a verb with its associated labels and token-label pairs.
Example format: {verb: (labels, List[(token, label)])}
"""
result = []
id2label = {int(k): str(v) for k, v in self.model.config.id2label.items()}
evaluator = Decoder(id2label)
for model_output in model_outputs:
class_probabilities = model_output["class_probabilities"]
attention_mask = model_output["attention_mask"]
output_dict = evaluator.make_output_human_readable(
class_probabilities, attention_mask
)
# Here we always fetch the first list because in a pipeline every
# sentence is processed one at a time
wordpiece_label_ids = output_dict["wordpiece_label_ids"][0]
labels = list(map(lambda idx: id2label[idx], wordpiece_label_ids))
result.append(
{
model_output["verb"]: (
labels,
list(zip(model_output["tokens"], labels)),
)
}
)
return result
def _find_word(self, s: str, start: int = 0) -> Tuple[int, int]:
"""
Helper method to find the boundaries of a word in a string.
        Assumes a non-alphabetic character marks the end of a word.
Parameters:
- s ``str``: The input string.
        - start ``int``, optional: Index at which to start looking for the word. Defaults to 0.
Returns:
- ``Tuple[int, int]``: Start and end indices of the word.
"""
for i, char in enumerate(s[start:], start):
if not char.isalpha():
return start, i
return start, len(s)
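

if __name__ == "__main__":
    # Minimal smoke test; a sketch only, assuming the liaad/srl-en_roberta-large_hf
    # checkpoint (and its remote code) is reachable and the required spaCy model
    # (en_core_web_trf) is installed. It mirrors the usage in the class docstring.
    from transformers import AutoModel, AutoTokenizer, pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    # Register the custom task so that pipeline("srl", ...) resolves to SrlPipeline.
    PIPELINE_REGISTRY.register_pipeline("srl", pipeline_class=SrlPipeline, pt_model=AutoModel)

    model = AutoModel.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("liaad/srl-en_roberta-large_hf", trust_remote_code=True)
    srl_pipeline = pipeline("srl", model=model, tokenizer=tokenizer, framework="pt", lang="en")

    for entry in srl_pipeline(["The cat jumps over the fence."]):
        print(entry)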