import json
from typing import Any, Dict, List

from FlagEmbedding import BGEM3FlagModel


class EndpointHandler:
    """Inference-endpoint handler that returns BGE-M3 dense embeddings."""

    def __init__(self, path: str = ""):
        """Load the BGE-M3 model once so every request can reuse it.

        Args:
            path: Model location, forwarded as-is to ``BGEM3FlagModel``.
        """
        # use_fp16=True speeds up computation at the cost of a slight
        # degradation in embedding quality.
        self.model = BGEM3FlagModel(path, use_fp16=True)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Encode the request inputs and return their dense embeddings.

        Args:
            data: Request payload. ``data["inputs"]`` holds what is passed to
                ``BGEM3FlagModel.encode``; if that key is absent, the rest of
                the payload itself is used as the inputs. The optional
                ``data["parameters"]`` mapping is forwarded to ``encode`` as
                keyword arguments (e.g. ``batch_size``, ``max_length``).
                The caller's dict is not modified.

        Returns:
            A JSON string: an array built from the elements of the model's
            ``dense_vecs`` output, each converted with ``.tolist()``
            (presumably one float vector per input text — TODO confirm the
            shape for a single-string input).
        """
        # Work on a shallow copy so popping keys does not mutate the caller's
        # payload; the pop semantics below match the original handler exactly.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        parameters = payload.pop("parameters", None)
        if parameters is None:
            parameters = {}

        embeddings = self.model.encode(inputs, **parameters)["dense_vecs"]

        # Convert array rows to plain Python lists so json can serialize them.
        vectors = [row.tolist() for row in embeddings]
        # NOTE(review): many inference toolkits JSON-serialize the handler's
        # return value themselves, in which case this pre-serialized string
        # may be double-encoded. Kept as a string for backward compatibility
        # with existing clients — confirm before changing.
        return json.dumps(vectors)