# mxbai-optimized/handler.py: TorchServe handler serving ONNX-optimized sentence embeddings
import logging
import os

import torch
import torch.nn.functional as F
import transformers
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings over the sequence dimension, using the
    # attention mask so that padding tokens do not contribute to the mean.
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
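
# Worked example (illustrative values, not from this repo): with one real token
# and one padding token,
#   token_embeddings = [[[1.0, 2.0], [9.0, 9.0]]]   # shape (1, 2, 2)
#   attention_mask   = [[1, 0]]
# the padded position is masked out, so mean_pooling returns [[1.0, 2.0]].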

class SentenceEmbeddingHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self._context = None
        self.initialized = False
    def initialize(self, context):
        """
        Loads the ONNX model and the tokenizer.

        Args:
            context: A JSON object containing information about the model
                artifacts and their parameters.

        Raises:
            RuntimeError: If the model or the tokenizer is missing.
        """
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")

        # use the GPU if one is available
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        logger.info(f'Using device {self.device}')

        # load the optimized ONNX model through Optimum's ONNX Runtime wrapper
        model_file = self.manifest['model']['modelFile']
        model_path = os.path.join(model_dir, model_file)
        if os.path.isfile(model_path):
            self.model = ORTModelForFeatureExtraction.from_pretrained(model_dir, file_name="model_optimized.onnx")
            self.model.to(self.device)
            logger.info(f'Successfully loaded model from {model_file}')
        else:
            raise RuntimeError('Missing the model file')

        # load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer is not None:
            logger.info('Successfully loaded tokenizer')
        else:
            raise RuntimeError('Missing tokenizer')
        self.initialized = True
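
    # The extracted model archive (model_dir above) must contain
    # model_optimized.onnx plus the tokenizer files saved alongside it;
    # initialize() loads both from that directory.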

    def _sanitize_parameters(self, **kwargs):
        # no hyperparameters to sanitize
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess_text(self, inputs):
        encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
        return encoded_inputs
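    # For instance (illustrative), self.preprocess_text(["hello world"]) returns a
    # transformers.BatchEncoding whose 'input_ids' and 'attention_mask' entries are
    # tensors of shape (1, sequence_length).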

    def preprocess(self, requests):
        """
        Tokenize the input text with the model's tokenizer and convert it
        to tensors.

        If pre-computed encodings are provided under the 'encodings' key,
        they must be of the form
        {'input_ids': [[101, 102]], 'token_type_ids': [[0, 0]], 'attention_mask': [[1, 1]]}

        Args:
            requests: A list containing a dictionary, which may take the form
                [{'body': json_file}] or [{'data': json_file}]

        Returns:
            A BatchEncoding holding the batch of token tensors.
        """
        # unpack the request data
        data = requests[0].get('body')
        if data is None:
            data = requests[0].get('data')

        # raw text provided: run the tokenizer
        texts = data.get('input')
        if texts is not None:
            logger.info('Text provided')
            return self.preprocess_text(texts)

        # pre-computed encodings provided: wrap them in tensors
        encodings = data.get('encodings')
        if encodings is not None:
            logger.info('Encodings provided')
            return transformers.BatchEncoding(data={k: torch.tensor(v) for k, v in encodings.items()})

        raise ValueError("Unsupported payload: expected 'input' or 'encodings'")
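    # Example request bodies this method accepts (illustrative, not from this repo):
    #   {"input": ["first sentence", "second sentence"]}
    #   {"encodings": {"input_ids": [[101, 102]], "token_type_ids": [[0, 0]],
    #                  "attention_mask": [[1, 1]]}}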

    def inference(self, model_inputs):
        # move the inputs onto the model's device before the forward pass
        model_inputs = model_inputs.to(self.device)
        outputs = self.model(**model_inputs)
        # pool the token embeddings into one vector per sentence, then L2-normalize
        sentence_embeddings = mean_pooling(outputs, model_inputs['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
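    # Because the embeddings are L2-normalized above, the cosine similarity of two
    # returned vectors reduces to a dot product, e.g. torch.mm(emb_a, emb_b.T).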

    def postprocess(self, outputs):
        # TorchServe expects one response entry per request; this handler reads a
        # single request in preprocess, so wrap the embeddings in a one-element list.
        return [{"status": "success", "data": outputs.tolist()}]
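

# A minimal smoke test against a running TorchServe instance (a sketch; the model
# name "mxbai", the .mar file name, and the default port 8080 are assumptions
# about how this handler is packaged and served, not part of this repo):
#
#   torchserve --start --ncs --model-store model_store --models mxbai=mxbai.mar
#   curl -X POST http://127.0.0.1:8080/predictions/mxbai \
#       -H "Content-Type: application/json" \
#       -d '{"input": ["The weather is lovely today."]}'
#
# The expected response body looks like {"status": "success", "data": [[...]]}.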