# Hugging Face file-viewer header (page residue, not part of the program):
# commit a154b3b by jpohhhh - "Return embeddings as array instead of nested in JSON"; file size 1.67 kB.
from typing import Dict, List, Any
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
import torch.nn.functional as F
import torch
# copied from the model card
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable model result; element 0 holds the token
            embeddings (presumably shaped (batch, seq_len, hidden) --
            TODO confirm against the model).
        attention_mask: Tensor with one fewer dimension than the token
            embeddings; entries are used as multiplicative weights, so 0
            removes a position from the average.

    Returns:
        Tensor of the mask-weighted sums over dimension 1 divided by the
        summed mask values.
    """
    token_embeddings = model_output[0]
    # Add a trailing axis and broadcast the mask to the embeddings' shape so
    # it can be multiplied element-wise.
    mask = attention_mask.unsqueeze(-1).expand_as(token_embeddings).float()
    weighted_sum = (token_embeddings * mask).sum(1)
    # Clamp the divisor so a fully masked row divides by 1e-9 instead of 0.
    mask_total = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / mask_total
class EndpointHandler():
    """Callable handler that turns input text into L2-normalized embeddings."""

    def __init__(self, path=""):
        """Load the ONNX model and its tokenizer from ``path``.

        Args:
            path: Location handed to both ``from_pretrained`` calls; the model
                is read from the file named "model-quantized.onnx".
        """
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            path,
            file_name="model-quantized.onnx",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Any) -> List[List[float]]:
        """Embed the inputs carried by ``data``.

        Args:
            data: Request payload exposing ``.get`` (presumably a dict --
                verify against the caller). The value under "inputs" is
                tokenized; when that key is missing, ``data`` itself is
                passed to the tokenizer.

        Returns:
            The normalized, mean-pooled embeddings converted to nested
            Python lists via ``tolist()``.
        """
        texts = data.get("inputs", data)

        # Tokenize with padding and truncation, returning PyTorch tensors.
        batch = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        # Forward pass, then mask-aware mean pooling over the token embeddings.
        model_output = self.model(**batch)
        pooled = mean_pooling(model_output, batch['attention_mask'])

        # L2-normalize along dim 1 and return plain lists rather than a tensor.
        return F.normalize(pooled, p=2, dim=1).tolist()