File size: 2,010 Bytes
1572885
67a58db
 
 
 
 
 
 
 
 
ac70c4f
 
 
67a58db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
from typing import Dict, List, Any
from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict


class EndpointHandler:
    """Callable request handler that runs ``gector.predict`` on a payload.

    The model, tokenizer and verb-form dictionary are loaded once in
    ``__init__`` and reused for every call.
    """

    def __init__(self, path=""):
        """Load the model artifacts stored under ``path``.

        Args:
            path: Directory passed to ``GECToR.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``; it must also contain
                ``data/verb-form-vocab.txt``.
        """
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # load_verb_dict yields a pair; both halves are forwarded to predict().
        self.encode, self.decode = load_verb_dict(
            os.path.join(path, "data/verb-form-vocab.txt")
        )

    @staticmethod
    def _as_source_list(inputs: Any) -> Any:
        """Wrap a bare string in a one-element list; pass anything else through."""
        # A str is itself iterable, so forwarding it unchanged would make a
        # consumer that loops over "sentences" see individual characters.
        if isinstance(inputs, str):
            return [inputs]
        return inputs

    @staticmethod
    def _option(data: Dict[str, Any], name: str, default: Any) -> Any:
        """Fetch an option from ``data``, then ``data["parameters"]``, else ``default``."""
        # Top-level keys win so existing callers see no change in behavior.
        if name in data:
            return data[name]
        nested = data.get("parameters")
        if isinstance(nested, dict):
            return nested.get(name, default)
        return default

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return the predicted results.

        Args:
            data (Dict[str, Any]): The request payload. Recognized keys:
                - "inputs" (List[str] or str): The input strings to process.
                  A single string is treated as a one-element list.
                - "n_iterations" (int, optional): The number of iterations for prediction. Defaults to 5.
                - "batch_size" (int, optional): The batch size for prediction. Defaults to 2.
                - "keep_confidence" (float, optional): Forwarded to ``predict``. Defaults to 0.0.
                - "min_error_prob" (float, optional): Forwarded to ``predict``. Defaults to 0.0.
                - "parameters" (Dict[str, Any], optional): Alternative location for
                  the four options above; consulted only for options that are
                  absent from the top level.

        Returns:
            The value returned by ``predict`` for the given inputs, or an
            empty list when "inputs" is an empty list or tuple.
            NOTE(review): the annotation says ``List[Dict[str, Any]]``, but the
            actual element type is decided by ``gector.predict`` - confirm.

        Raises:
            KeyError: If ``data`` has no "inputs" key.
        """
        srcs = self._as_source_list(data["inputs"])

        # Nothing to predict for an empty batch; skip the model call entirely.
        if isinstance(srcs, (list, tuple)) and not srcs:
            return []

        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=srcs,
            encode=self.encode,
            decode=self.decode,
            keep_confidence=self._option(data, "keep_confidence", 0.0),
            min_error_prob=self._option(data, "min_error_prob", 0.0),
            # The payload key is plural, but predict() takes singular `n_iteration`.
            n_iteration=self._option(data, "n_iterations", 5),
            batch_size=self._option(data, "batch_size", 2),
        )