andrewrreed's picture
andrewrreed HF staff
fix data dir
ac70c4f
raw
history blame
2.01 kB
import os
from typing import Dict, List, Any
from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict
class EndpointHandler:
    """Inference Endpoints handler wrapping a GECToR grammatical-error-correction model."""

    def __init__(self, path=""):
        """Load the GECToR model, its tokenizer, and the verb-form dictionary from *path*.

        Args:
            path (str): Model repository/directory; must contain the weights,
                tokenizer files, and ``data/verb-form-vocab.txt``.
        """
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Verb-form vocabulary ships alongside the model weights.
        verb_vocab_file = os.path.join(path, "data/verb-form-vocab.txt")
        self.encode, self.decode = load_verb_dict(verb_vocab_file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run grammar correction over the request payload.

        Args:
            data (Dict[str, Any]): Request payload with keys:
                - "inputs" (List[str]): sentences to correct (required).
                - "n_iterations" (int, optional): correction passes. Defaults to 5.
                - "batch_size" (int, optional): prediction batch size. Defaults to 2.
                - "keep_confidence" (float, optional): confidence threshold for
                  keeping predictions. Defaults to 0.0.
                - "min_error_prob" (float, optional): minimum error probability for
                  keeping predictions. Defaults to 0.0.

        Returns:
            List[Dict[str, Any]]: One prediction result per input sentence,
            as produced by ``gector.predict``.
        """
        sentences = data["inputs"]
        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=sentences,
            encode=self.encode,
            decode=self.decode,
            keep_confidence=data.get("keep_confidence", 0.0),
            min_error_prob=data.get("min_error_prob", 0.0),
            n_iteration=data.get("n_iterations", 5),
            batch_size=data.get("batch_size", 2),
        )