|
import os
from typing import Any, Dict, List, Optional, Union

import numpy as np
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger


class OptimumEncoder(BaseEncoder):
    """Embed documents with a HuggingFace model exported to ONNX, run through
    ONNX Runtime's TensorRT execution provider via HuggingFace Optimum."""
|
    name: str = "mixedbread-ai/mxbai-embed-large-v1"
    type: str = "huggingface"
    score_threshold: float = 0.5
    tokenizer_kwargs: Dict = {}
    model_kwargs: Dict = {}
    device: Optional[str] = None
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _torch: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._tokenizer, self._model = self._initialize_hf_model()
|
    def _initialize_hf_model(self):
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install them with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install PyTorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch
|
        tokenizer = AutoTokenizer.from_pretrained(
            self.name,
            **self.tokenizer_kwargs,
        )

        # Cache compiled TensorRT engines under HF_HOME so they survive across
        # runs, and execute inference in FP16.
        provider_options = {
            "trt_engine_cache_enable": True,
            "trt_engine_cache_path": os.getenv("HF_HOME"),
            "trt_fp16_enable": True,
        }

        session_options = ort.SessionOptions()
        session_options.log_severity_level = 0  # 0 = verbose ORT logging

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            file_name="model_fp16.onnx",
            subfolder="onnx",
            provider="TensorrtExecutionProvider",
            provider_options=provider_options,
            session_options=session_options,
            **self.model_kwargs,
        )
|
print("Building engine for a short sequence...") |
|
short_text = ["short"] |
|
short_encoded_input = tokenizer( |
|
short_text, padding=True, truncation=True, return_tensors="pt" |
|
).to(self.device) |
|
short_output = ort_model(**short_encoded_input) |
|
|
|
print("Building engine for a long sequence...") |
|
long_text = ["a very long input just for demo purpose, this is very long" * 10] |
|
long_encoded_input = tokenizer( |
|
long_text, padding=True, truncation=True, return_tensors="pt" |
|
).to(self.device) |
|
long_output = ort_model(**long_encoded_input) |
|
|
|
text = ["Replace me by any text you'd like."] |
|
encoded_input = tokenizer( |
|
text, padding=True, truncation=True, return_tensors="pt" |
|
).to(self.device) |
|
|
|
for i in range(3): |
|
output = ort_model(**encoded_input) |
|
|
|
return tokenizer, ort_model |
|
|
|
    def __call__(
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
        convert_to_numpy: bool = False,
    ) -> Union[List[List[float]], List[np.ndarray]]:
        all_embeddings = []
        for i in range(0, len(docs), batch_size):
            batch_docs = docs[i : i + batch_size]

            encoded_input = self._tokenizer(
                batch_docs, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with self._torch.no_grad():
                model_output = self._model(**encoded_input)

            if pooling_strategy == "mean":
                embeddings = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            elif pooling_strategy == "max":
                embeddings = self._max_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            else:
                raise ValueError(
                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
                )

            if normalize_embeddings:
                embeddings = self._torch.nn.functional.normalize(
                    embeddings, p=2, dim=1
                )

            # Assign the conversion results: `.numpy()` and `.tolist()` return
            # new objects, and the original code discarded them, so callers
            # always received raw tensors regardless of `convert_to_numpy`.
            if convert_to_numpy:
                embeddings = embeddings.detach().cpu().numpy()
            else:
                embeddings = embeddings.tolist()

            all_embeddings.extend(embeddings)

        return all_embeddings
|
    def _mean_pooling(self, model_output, attention_mask):
        # Mean of token embeddings, with padding excluded via the attention
        # mask; the clamp guards against division by zero on empty masks.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return self._torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
    def _max_pooling(self, model_output, attention_mask):
        # Element-wise max over tokens; padded positions are first set to
        # -1e9 so they can never win the max.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        token_embeddings[input_mask_expanded == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
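

if __name__ == "__main__":
    # Minimal usage sketch, not part of the encoder. Assumes a CUDA-capable
    # machine with TensorRT plus the optional dependencies installed, and
    # that the model repo ships the fp16 ONNX file loaded above; the device
    # string and example texts are placeholders. The first call is slow
    # because TensorRT builds and caches its engines during init.
    encoder = OptimumEncoder(device="cuda")
    vectors = encoder(
        ["semantic routing picks a route per query", "embeddings power it"],
        pooling_strategy="mean",
        convert_to_numpy=True,
    )
    print(f"Encoded {len(vectors)} docs; dim = {vectors[0].shape[0]}")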