File size: 2,010 Bytes
1572885 67a58db ac70c4f 67a58db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
from typing import Dict, List, Any
from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict
class EndpointHandler:
    """Inference Endpoint handler for a GECToR grammatical-error-correction model.

    Loads the model, tokenizer, and verb-form dictionaries once at startup,
    then serves correction requests via ``__call__``.
    """

    def __init__(self, path=""):
        # Model weights and tokenizer both live in the endpoint's model directory.
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Verb-form vocabulary maps used by GECToR to encode/decode
        # verb transformations (e.g. tense changes) during correction.
        self.encode, self.decode = load_verb_dict(
            os.path.join(path, "data/verb-form-vocab.txt")
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run GECToR prediction on a batch of input sentences.

        Args:
            data: Request payload with keys:
                - "inputs" (List[str]): sentences to correct (required).
                - "n_iterations" (int, optional): iterative-correction rounds.
                  Defaults to 5.
                - "batch_size" (int, optional): prediction batch size.
                  Defaults to 2.
                - "keep_confidence" (float, optional): confidence bias for
                  keeping the original token. Defaults to 0.0.
                - "min_error_prob" (float, optional): minimum sentence error
                  probability required to apply edits. Defaults to 0.0.

        Returns:
            The per-input prediction results produced by ``gector.predict``.
        """
        # Optional knobs fall back to their defaults when absent from the
        # payload; "inputs" is mandatory and raises KeyError if missing.
        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=data["inputs"],
            encode=self.encode,
            decode=self.decode,
            keep_confidence=data.get("keep_confidence", 0.0),
            min_error_prob=data.get("min_error_prob", 0.0),
            n_iteration=data.get("n_iterations", 5),
            batch_size=data.get("batch_size", 2),
        )
|