bge-m3-h / handler.py
zkwang's picture
Update handler.py
201c562 verified
raw
history blame
1.29 kB
import json
from typing import Dict, List, Any
from FlagEmbedding import BGEM3FlagModel
class EndpointHandler:
    """Custom inference-endpoint handler that returns BGE-M3 dense embeddings."""

    def __init__(self, path=""):
        """Load the BGE-M3 model once so every request can reuse it.

        Args:
            path: Model location forwarded to ``BGEM3FlagModel`` (presumably
                the model repository directory supplied by the endpoint
                runtime -- confirm against the deployment).
        """
        # use_fp16=True speeds up computation with a slight degradation
        # in embedding quality.
        self.model = BGEM3FlagModel(path, use_fp16=True)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Encode the request inputs and return their dense vectors as JSON.

        Args:
            data: Request payload. ``data["inputs"]`` holds what is passed to
                ``BGEM3FlagModel.encode`` (presumably a string or a list of
                strings). If the key is absent, the remainder of the payload
                is used as the inputs. The optional ``data["parameters"]``
                mapping is forwarded verbatim as keyword arguments to
                ``encode`` (e.g. ``batch_size``, ``max_length``); a missing or
                ``None`` value means "use encode's defaults". The caller's
                dict is not modified.

        Returns:
            A JSON-encoded *string* (not a Python list) of the model's
            ``dense_vecs``, each row converted with ``.tolist()``.
        """
        # Pop from a shallow copy so the caller's payload is left untouched.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        parameters = payload.pop("parameters", None)
        if parameters is None:
            parameters = {}

        embeddings = self.model.encode(inputs, **parameters)["dense_vecs"]

        # .tolist() turns each array-like row into plain Python numbers,
        # which json can serialize.
        list_of_lists = [vec.tolist() for vec in embeddings]
        # Returned as a JSON string (not a list) to stay compatible with
        # existing clients that decode this string themselves.
        return json.dumps(list_of_lists)