File size: 744 Bytes
7ffa8af
e423171
7ffa8af
 
 
 
e423171
7ffa8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e423171
7ffa8af
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from typing import Dict, List, Any

from comet import load_from_checkpoint


class EndpointHandler:
    """Inference-endpoint handler that scores samples with a COMET checkpoint."""

    def __init__(self, path: str = ""):
        """Load ``model.ckpt`` from the endpoint directory.

        Args:
            path: Directory passed in by the endpoint runtime (presumably the
                model repository root -- confirm against the deployment).
                ``model.ckpt`` is resolved relative to it; the default ``""``
                resolves it against the current working directory, as before.
        """
        # Honour `path` instead of relying on the process's working directory.
        self.model = load_from_checkpoint(os.path.join(path, "model.ckpt"))

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Score a batch of samples with the loaded model.

        Args:
            data: Request payload. ``data["inputs"]`` is a dict with keys:
                ``data`` -- samples forwarded to ``self.model.predict``
                (presumably a list of ``{"src", "mt", "ref"}`` dicts -- confirm);
                ``batch_size`` (optional) -- forwarded as ``batch_size``;
                ``workers`` (optional) -- forwarded as ``num_workers``;
                ``gpus`` (optional, default 0) -- forwarded as ``gpus``.

        Returns:
            The ``"scores"`` entry of the model output, serialized by the
            endpoint runtime.

        Raises:
            KeyError: if ``inputs`` or ``inputs["data"]`` is missing.
        """
        # Index instead of pop so the caller's payload is left intact.
        inputs = data["inputs"]
        samples = inputs["data"]

        # gpus=0 (CPU) stays the default; batching options are forwarded only
        # when supplied, otherwise whatever defaults `predict` declares apply.
        options: Dict[str, Any] = {"gpus": inputs.get("gpus", 0)}
        if "batch_size" in inputs:
            options["batch_size"] = inputs["batch_size"]
        if "workers" in inputs:
            options["num_workers"] = inputs["workers"]

        model_output = self.model.predict(samples, **options)
        return model_output["scores"]