import os
from typing import Any, List, Optional

from tqdm import tqdm

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra


class OptimumEncoder(BaseModel, Embeddings):
    """LangChain ``Embeddings`` implementation backed by an Optimum/ONNX
    Runtime feature-extraction model running on the CUDA execution provider."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.allow

    _tokenizer: Any
    _model: Any
    _torch: Any

    def __init__(
        self, 
        name: str = "mixedbread-ai/mxbai-embed-large-v1", 
        device: Optional[str] = None, 
        cache_dir: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.name = name
        self.device = device
        self.cache_dir = cache_dir

        self._tokenizer, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
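        """Load the tokenizer and the ONNX model through Optimum's ONNX
        Runtime integration, checking the optional dependencies first."""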
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install it with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install Pytorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch

        tokenizer = AutoTokenizer.from_pretrained(
            self.name, cache_dir=self.cache_dir
        )

        # Optional provider options for the TensorRT execution provider
        # (kept for reference; unused with the CUDAExecutionProvider below):
        # provider_options = {
        #     "trt_engine_cache_enable": True,
        #     "trt_engine_cache_path": os.getenv('HF_HOME'),
        #     "trt_fp16_enable": True
        # }

        session_options = ort.SessionOptions()
        # Severity 0 is the most verbose ONNX Runtime log level; raise it to
        # quieten session logs.
        session_options.log_severity_level = 0

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            file_name='model_fp16.onnx',
            subfolder='onnx',
            provider='CUDAExecutionProvider',
            use_io_binding=True,
            cache_dir=self.cache_dir,
            # provider_options=provider_options,
            session_options=session_options
        )

        # print("Building engine for a short sequence...")
        # short_text = ["short"]
        # short_encoded_input = tokenizer(
        #     short_text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        # short_output = ort_model(**short_encoded_input)
        
        # print("Building engine for a long sequence...")
        # long_text = ["a very long input just for demo purpose, this is very long" * 10]
        # long_encoded_input = tokenizer(
        #     long_text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        # long_output = ort_model(**long_encoded_input)

        # text = ["Replace me by any text you'd like."]
        # encoded_input = tokenizer(
        #     text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        
        # for i in range(3):
        #     output = ort_model(**encoded_input)

        return tokenizer, ort_model

    def embed_documents(
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean"
    ) -> List[List[float]]:
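        """Embed a list of documents in batches of ``batch_size``, pooling the
        token embeddings with the requested strategy and optionally
        L2-normalizing the result."""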
        all_embeddings = []
        for i in tqdm(range(0, len(docs), batch_size)):
            batch_docs = docs[i : i + batch_size]

            encoded_input = self._tokenizer(
                batch_docs, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with self._torch.no_grad():
                model_output = self._model(**encoded_input)

            if pooling_strategy == "mean":
                embeddings = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            elif pooling_strategy == "max":
                embeddings = self._max_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            else:
                raise ValueError(
                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
                )

            if normalize_embeddings:
                embeddings = self._torch.nn.functional.normalize(
                    embeddings, p=2, dim=1
                )

            all_embeddings.extend(embeddings.tolist())

        return all_embeddings

    def embed_query(
        self,
        text: str,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean"
    ) -> List[float]:
        """Embed a single query string and return a flat list of floats."""
        encoded_input = self._tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)

        with self._torch.no_grad():
            model_output = self._model(**encoded_input)

        if pooling_strategy == "mean":
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )
        elif pooling_strategy == "max":
            embeddings = self._max_pooling(
                model_output, encoded_input["attention_mask"]
            )
        else:
            raise ValueError(
                "Invalid pooling_strategy. Please use 'mean' or 'max'."
            )

        if normalize_embeddings:
            embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
        # A single query yields a (1, dim) tensor; return the first row so the
        # result matches the declared List[float] return type.
        return embeddings[0].tolist()

    def _mean_pooling(self, model_output, attention_mask):
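        # Mean-pool the token embeddings, weighting by the attention mask so
        # padding positions do not contribute to the average.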
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return self._torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
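        # Max-pool over the sequence dimension after masking padded positions
        # with a large negative value.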
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        token_embeddings[input_mask_expanded == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
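

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # it assumes a CUDA-capable machine with `optimum[onnxruntime-gpu]`
    # installed and a model repository that ships an `onnx/model_fp16.onnx`
    # export, as configured in `_initialize_hf_model` above.
    encoder = OptimumEncoder(device="cuda")
    doc_vectors = encoder.embed_documents(["first document", "second document"])
    query_vector = encoder.embed_query("a short search query")
    print(f"embedded {len(doc_vectors)} documents, dimension {len(query_vector)}")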