import logging
import os

import torch
import torch.nn.functional as F
import transformers
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings, weighting each token by its attention mask."""
    token_embeddings = model_output[0]  # first element holds the token-level embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the unmasked token embeddings and divide by the count of real tokens;
    # the clamp guards against division by zero for all-padding rows.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
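
# A minimal usage sketch for mean_pooling (illustrative only; `model` is an
# already-loaded feature-extraction model, and the checkpoint name
# "sentence-transformers/all-MiniLM-L6-v2" is an assumed example, not a
# requirement of this handler):
#
#   tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#   encoded = tokenizer(["hello world"], padding=True, return_tensors="pt")
#   outputs = model(**encoded)            # outputs[0]: (batch, seq_len, hidden)
#   emb = mean_pooling(outputs, encoded["attention_mask"])  # -> (batch, hidden)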


class SentenceEmbeddingHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """
        Load the model and the tokenizer.

        Args:
            context: TorchServe context object containing information
                pertaining to the model artifact parameters.

        Raises:
            RuntimeError: If the model or the tokenizer is missing.
        """
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")

        # Use the GPU assigned by TorchServe when available, else fall back to CPU.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        logger.info(f"Using device {self.device}")

        model_file = self.manifest["model"]["modelFile"]
        model_path = os.path.join(model_dir, model_file)

        if os.path.isfile(model_path):
            # Load the optimized ONNX graph through ONNX Runtime.
            self.model = ORTModelForFeatureExtraction.from_pretrained(
                model_dir, file_name="model_optimized.onnx"
            )
            self.model.to(self.device)
            logger.info(f"Successfully loaded model from {model_file}")
        else:
            raise RuntimeError("Missing the model file")

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer is not None:
            logger.info("Successfully loaded tokenizer")
        else:
            raise RuntimeError("Missing tokenizer")

        self.initialized = True
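
    # Expected model_dir layout, implied by the loaders above:
    #   model_optimized.onnx  - optimized ONNX graph loaded by ORTModelForFeatureExtraction
    #   tokenizer/config files (e.g. tokenizer.json, tokenizer_config.json, config.json)
    #                           saved alongside it for AutoTokenizer.from_pretrained()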
    def preprocess_text(self, inputs):
        # Tokenize a list of strings into padded/truncated tensors.
        encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
        return encoded_inputs

    def preprocess(self, requests):
        """
        Tokenize the input text with the model's tokenizer, or rebuild a
        BatchEncoding from pre-computed token ids.

        If token ids are provided (under the 'encodings' key), the JSON must
        be of the form
        {'input_ids': [[101, 102]], 'token_type_ids': [[0, 0]], 'attention_mask': [[1, 1]]}

        Args:
            requests: A list containing a dictionary, which may be of the form
                [{'body': json_file}] or [{'data': json_file}]
        Returns:
            A BatchEncoding containing the batch of token tensors.
        """
        # TorchServe delivers the payload under either 'body' or 'data'.
        data = requests[0].get("body")
        if data is None:
            data = requests[0].get("data")

        texts = data.get("input")
        if texts is not None:
            logger.info("Text provided")
            return self.preprocess_text(texts)

        # Pre-tokenized input: rebuild a BatchEncoding from the raw id lists.
        encodings = data.get("encodings")
        if encodings is not None:
            logger.info("Encodings provided")
            return transformers.BatchEncoding(data={k: torch.tensor(v) for k, v in encodings.items()})

        raise ValueError("Unsupported payload: expected an 'input' or 'encodings' key")
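
    # Example request payloads accepted by preprocess, matching the two
    # branches above:
    #   {"input": ["hello world", "another sentence"]}
    #   {"encodings": {"input_ids": [[101, 102]], "token_type_ids": [[0, 0]],
    #                  "attention_mask": [[1, 1]]}}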

    def inference(self, model_inputs):
        # Move the batch to the model's device, run the ONNX model, mean-pool
        # the token embeddings, and L2-normalize so cosine similarity reduces
        # to a plain dot product.
        model_inputs = model_inputs.to(self.device)
        outputs = self.model(**model_inputs)
        sentence_embeddings = mean_pooling(outputs, model_inputs["attention_mask"])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings

    def postprocess(self, outputs):
        # outputs is a single tensor of sentence embeddings; wrap it in a list
        # so the response count matches the single request from preprocess.
        formatted_outputs = []
        data = [outputs.tolist()]
        for dat in data:
            formatted_outputs.append({"status": "success", "data": dat})
        return formatted_outputs
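

# A minimal local smoke test, assuming a model directory laid out as
# initialize() expects (model_optimized.onnx plus tokenizer files).
# MockContext and the "./model_store/my-model" path are hypothetical
# stand-ins for the real TorchServe context, for illustration only.
if __name__ == "__main__":

    class MockContext:
        system_properties = {"model_dir": "./model_store/my-model", "gpu_id": None}
        manifest = {"model": {"modelFile": "model_optimized.onnx"}}

    handler = SentenceEmbeddingHandler()
    handler.initialize(MockContext())
    batch = handler.preprocess([{"body": {"input": ["hello world"]}}])
    embeddings = handler.inference(batch)
    print(handler.postprocess(embeddings))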