|
import os
from typing import Any, Dict, List

from comet import load_from_checkpoint
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that scores samples with a COMET model.

    The model is loaded once in ``__init__`` and reused by every call.
    """

    def __init__(self, path=""):
        """Load the COMET checkpoint.

        Args:
            path: Directory containing ``model.ckpt``. The default ``""``
                resolves the checkpoint relative to the current working
                directory.
        """
        # Honour ``path`` instead of ignoring it; os.path.join("", "model.ckpt")
        # is "model.ckpt", so the default argument keeps the same lookup.
        self.model = load_from_checkpoint(os.path.join(path, "model.ckpt"))

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Score the samples contained in a request payload.

        Args:
            data: Request payload of the form
                ``{"inputs": {"data": [...], "batch_size": ..., "workers": ...}}``.
                ``inputs["data"]`` is forwarded unchanged to
                ``self.model.predict``. ``batch_size`` and ``workers`` are
                optional; when omitted they are not passed, so ``predict``'s
                own defaults apply.

        Returns:
            The ``"scores"`` entry of the prediction output (presumably one
            score per sample -- confirm against the installed COMET version).

        Raises:
            KeyError: If ``"inputs"`` or ``inputs["data"]`` is missing.
        """
        # Plain indexing rather than ``pop`` so the caller's payload is not mutated.
        inputs = data["inputs"]
        samples = inputs["data"]

        predict_kwargs: Dict[str, Any] = {"gpus": 0}  # request no GPUs
        if "batch_size" in inputs:
            predict_kwargs["batch_size"] = inputs["batch_size"]
        if "workers" in inputs:
            # The request key is "workers"; predict() calls it "num_workers".
            predict_kwargs["num_workers"] = inputs["workers"]

        model_output = self.model.predict(samples, **predict_kwargs)
        return model_output["scores"]
|
|