File size: 4,550 Bytes
901aaed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from transformers import Pipeline
import torch.nn.functional as F
import torch
from ts.torch_handler.base_handler import BaseHandler
import logging
import os
import transformers
from transformers import AutoTokenizer
# Module-level logger used by the functions and classes in this file.
logger = logging.getLogger(__name__)
# Log the installed transformers version once, when this module is imported.
logger.info("Transformers version %s", transformers.__version__)
from optimum.onnxruntime import ORTModelForFeatureExtraction

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions kept by ``attention_mask``.

    Args:
        model_output: model result whose element ``[0]`` holds the token
            embeddings. The average is taken over dimension 1.
        attention_mask: mask shaped like the token embeddings without their
            last dimension. Positions equal to 0 contribute nothing.

    Returns:
        The masked mean over dimension 1. The divisor is clamped to at least
        1e-9, so an all-zero mask yields zeros instead of NaN.
    """
    hidden = model_output[0]
    # Give the mask a trailing axis and stretch it across the embedding width.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
 
class SentenceEmbeddingHandler(BaseHandler):
    """TorchServe ``BaseHandler`` subclass for the sentence-embedding model.

    NOTE(review): only ``__init__`` is defined here, so this class inherits
    BaseHandler's default initialize/preprocess/inference/postprocess. The
    custom versions of those methods in this file belong to
    ``SentenceEmbeddingPipeline``. Confirm which class is meant to be served.
    """

    def __init__(self):
        super().__init__()
        # Filled in later; start out unset/uninitialized.
        self._context = None
        self.initialized = False
class SentenceEmbeddingPipeline(Pipeline):
    """Sentence-embedding serving logic: tokenize, run the ONNX model, mean-pool.

    The methods follow the TorchServe handler shape (``initialize`` /
    ``preprocess`` / ``inference`` / ``postprocess``) while the class
    subclasses ``transformers.Pipeline``.

    NOTE(review): ``__init__`` is not overridden; ``initialize`` sets
    ``self.model``, ``self.tokenizer`` and ``self.device`` itself. Confirm how
    this class is constructed.
    """

    def initialize(self, context):
        """Load the ONNX feature-extraction model and its tokenizer.

        Args:
            context: server context object. ``context.system_properties``
                supplies ``model_dir`` and ``gpu_id``; ``context.manifest``
                supplies ``['model']['modelFile']``.

        Raises:
            RuntimeError: if the manifest's model file is not present in
                ``model_dir``, or the tokenizer could not be loaded.
        """
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")

        # Use the assigned GPU when CUDA is available, otherwise the CPU.
        gpu_id = properties.get("gpu_id")
        self.device = torch.device(
            "cuda:" + str(gpu_id)
            if torch.cuda.is_available() and gpu_id is not None
            else "cpu"
        )
        logger.info("Using device %s", self.device)

        model_file = self.manifest['model']['modelFile']
        model_path = os.path.join(model_dir, model_file)
        # NOTE(review): existence is checked for the manifest's modelFile, but the
        # weights actually loaded are "model_optimized.onnx" from model_dir.
        # Confirm that this mismatch is intended.
        if not os.path.isfile(model_path):
            raise RuntimeError('Missing the model file')
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_dir, file_name="model_optimized.onnx"
        )
        self.model.to(self.device)
        logger.info("Successfully loaded model from %s", model_file)

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer is None:
            raise RuntimeError('Missing tokenizer')
        logger.info('Successfully loaded tokenizer')

        self.initialized = True

    def _sanitize_parameters(self, **kwargs):
        """Return empty keyword dicts for the preprocess, forward and postprocess stages.

        This pipeline has no call-time hyperparameters, so all ``kwargs`` are
        ignored.
        """
        return {}, {}, {}

    def preprocess_text(self, inputs):
        """Tokenize ``inputs`` (padded, truncated) into PyTorch tensors."""
        return self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')

    def preprocess(self, requests):
        """Turn the first request of a batch into model inputs.

        The payload is read from ``'body'``, falling back to ``'data'``. It must
        be a JSON object with one of these keys:

        * ``'input'``: a string or list of strings, tokenized with
          ``self.tokenizer``.
        * ``'encodings'``: pre-tokenized ids, e.g.
          ``{'input_ids': [[101, 102]], 'token_type_ids': [[0, 0]],
          'attention_mask': [[1, 1]]}``.

        A payload given as ``bytes``/``bytearray``/``str`` is first decoded as
        UTF-8 JSON.

        Args:
            requests: list of request dicts such as ``[{'body': payload}]`` or
                ``[{'data': payload}]``. Only ``requests[0]`` is used.

        Returns:
            A ``transformers.BatchEncoding`` of tensors.

        Raises:
            ValueError: if the payload is missing, is not a JSON object, or has
                neither ``'input'`` nor ``'encodings'`` (including malformed
                JSON, since ``json.JSONDecodeError`` is a ``ValueError``).
        """
        import json
        from collections.abc import Mapping

        data = requests[0].get('body')
        if data is None:
            data = requests[0].get('data')
        # Raw payloads would otherwise fail on the .get() lookups below.
        if isinstance(data, (bytes, bytearray)):
            data = data.decode('utf-8')
        if isinstance(data, str):
            data = json.loads(data)
        if not isinstance(data, Mapping):
            raise ValueError("unsupported payload")

        texts = data.get('input')
        if texts is not None:
            logger.info('Text provided')
            return self.preprocess_text(texts)

        encodings = data.get('encodings')
        if encodings is not None:
            logger.info('Encodings provided')
            return transformers.BatchEncoding(
                data={name: torch.tensor(values) for name, values in encodings.items()}
            )

        raise ValueError("unsupported payload")

    def inference(self, model_inputs):
        """Embed a tokenized batch: forward pass, masked mean pooling, L2 normalization.

        Args:
            model_inputs: mapping from input name to tensor, as returned by
                ``preprocess``. Must include ``'attention_mask'``.

        Returns:
            Tensor of embeddings, L2-normalized along ``dim=1``.
        """
        # Move every input to the model's device so the attention mask used in
        # mean_pooling lives where the model outputs presumably do
        # (self.model was moved to self.device in initialize).
        model_inputs = {name: tensor.to(self.device) for name, tensor in model_inputs.items()}
        outputs = self.model(**model_inputs)
        sentence_embeddings = mean_pooling(outputs, model_inputs['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def postprocess(self, outputs):
        """Wrap the embedding tensor in a single success response.

        All embeddings go into one response entry, which matches
        ``preprocess`` reading only the first request.

        Returns:
            ``[{"status": "success", "data": outputs.tolist()}]``
        """
        return [{"status": "success", "data": outputs.tolist()}]