import os
import numpy as np
from typing import Any, List, Optional, Dict, Union
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger


class OptimumEncoder(BaseEncoder):
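    """HuggingFace embedding encoder served through ONNX Runtime via Optimum,
    using the TensorRT execution provider for GPU inference."""
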
    name: str = "mixedbread-ai/mxbai-embed-large-v1"
    type: str = "huggingface"
    score_threshold: float = 0.5
    tokenizer_kwargs: Dict = {}
    model_kwargs: Dict = {}
    device: Optional[str] = None
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _torch: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._tokenizer, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
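        """Load the tokenizer and ONNX model, then warm up the TensorRT engine."""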
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install it with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install Pytorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch

        tokenizer = AutoTokenizer.from_pretrained(
            self.name,
            **self.tokenizer_kwargs,
        )

        # TensorRT provider options: cache compiled engines on disk and run in fp16.
        provider_options = {
            "trt_engine_cache_enable": True,
            "trt_engine_cache_path": os.getenv("HF_HOME"),
            "trt_fp16_enable": True,
        }

        session_options = ort.SessionOptions()
        session_options.log_severity_level = 0

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            file_name="model_fp16.onnx",
            subfolder="onnx",
            provider="TensorrtExecutionProvider",
            provider_options=provider_options,
            session_options=session_options,
            **self.model_kwargs,
        )

        print("Building engine for a short sequence...")
        short_text = ["short"]
        short_encoded_input = tokenizer(
            short_text, padding=True, truncation=True, return_tensors="pt"
        ).to("cuda")
        short_output = ort_model(**short_encoded_input)
        
        print("Building engine for a long sequence...")
        long_text = ["a very long input just for demo purpose, this is very long" * 10]
        long_encoded_input = tokenizer(
            long_text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)
        long_output = ort_model(**long_encoded_input)

        text = ["Replace me by any text you'd like."]
        encoded_input = tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)
        
        for i in range(3):
            output = ort_model.generate(**encoded_input)

        return tokenizer, model

    def __call__(
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
        matryoshka_dim: int = 1024,
        convert_to_numpy: bool = False,
    ) -> Union[List[List[float]], List[np.ndarray]]:
        all_embeddings = []
        for i in range(0, len(docs), batch_size):
            batch_docs = docs[i : i + batch_size]

            encoded_input = self._tokenizer(
                batch_docs, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with self._torch.no_grad():
                model_output = self._model(**encoded_input)

            if pooling_strategy == "mean":
                embeddings = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            elif pooling_strategy == "max":
                embeddings = self._max_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            else:
                raise ValueError(
                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
                )

            # Truncate to the requested Matryoshka dimension *before*
            # normalizing, so the truncated vectors still have unit norm.
            if embeddings.size(1) > matryoshka_dim:
                embeddings = embeddings[:, :matryoshka_dim]

            if normalize_embeddings:
                embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)

            embeddings = embeddings.detach().cpu()
            if convert_to_numpy:
                all_embeddings.extend(embeddings.numpy())
            else:
                all_embeddings.extend(embeddings.tolist())

        return all_embeddings

    def _mean_pooling(self, model_output, attention_mask):
        # Average the token embeddings, weighting padding positions to zero.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return self._torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
        # Max-pool over tokens, masking padding positions with a large
        # negative value so they never win the max.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        token_embeddings[input_mask_expanded == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
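

# A minimal usage sketch (an illustration, not part of the original module):
# it assumes a CUDA-capable machine with `optimum[onnxruntime-gpu]` installed
# and a model repo that ships an `onnx/model_fp16.onnx` export. The first call
# is slow while TensorRT compiles its engines; subsequent calls hit the cache.
if __name__ == "__main__":
    encoder = OptimumEncoder(device="cuda")
    vectors = encoder(["hello world", "route this query"], matryoshka_dim=512)
    print(f"{len(vectors)} embeddings, dim={len(vectors[0])}")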