import os
import numpy as np
from tqdm import tqdm
from typing import Any, List, Optional, Dict
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder


class OptimumEncoder(BaseEncoder):
    """Dense encoder that runs a Hugging Face embedding model through
    Optimum's ONNX Runtime backend (fp16 ONNX weights on the CUDA execution provider)."""

    name: str = "mixedbread-ai/mxbai-embed-large-v1"
    type: str = "huggingface"
    score_threshold: float = 0.5
    tokenizer_kwargs: Dict = {}
    model_kwargs: Dict = {}
    device: Optional[str] = None
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _torch: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._tokenizer, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install them with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install PyTorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch

        tokenizer = AutoTokenizer.from_pretrained(
            self.name,
            **self.tokenizer_kwargs,
        )

        # Provider options for the TensorRT execution provider (kept for reference):
        # provider_options = {
        #     "trt_engine_cache_enable": True,
        #     "trt_engine_cache_path": os.getenv("HF_HOME"),
        #     "trt_fp16_enable": True,
        # }

        session_options = ort.SessionOptions()
        session_options.log_severity_level = 0

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            file_name="model_fp16.onnx",
            subfolder="onnx",
            provider="CUDAExecutionProvider",
            use_io_binding=True,
            # provider_options=provider_options,
            session_options=session_options,
            **self.model_kwargs,
        )

        # Optional warm-up passes (useful when building TensorRT engines for
        # short and long sequence shapes); kept commented out:
        # print("Building engine for a short sequence...")
        # short_text = ["short"]
        # short_encoded_input = tokenizer(
        #     short_text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        # short_output = ort_model(**short_encoded_input)

        # print("Building engine for a long sequence...")
        # long_text = ["a very long input just for demo purpose, this is very long" * 10]
        # long_encoded_input = tokenizer(
        #     long_text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        # long_output = ort_model(**long_encoded_input)

        # text = ["Replace me by any text you'd like."]
        # encoded_input = tokenizer(
        #     text, padding=True, truncation=True, return_tensors="pt"
        # ).to(self.device)
        # for i in range(3):
        #     output = ort_model(**encoded_input)

        return tokenizer, ort_model

    def __call__(
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
        convert_to_numpy: bool = False,
    ) -> List[List[float]] | List[np.ndarray]:
        """Embed `docs` in batches and return one vector per document."""
        all_embeddings = []
        for i in tqdm(range(0, len(docs), batch_size)):
            batch_docs = docs[i : i + batch_size]
            encoded_input = self._tokenizer(
                batch_docs, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with self._torch.no_grad():
                model_output = self._model(**encoded_input)

            if pooling_strategy == "mean":
                embeddings = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            elif pooling_strategy == "max":
                embeddings = self._max_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            else:
                raise ValueError(
                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
                )

            if normalize_embeddings:
                embeddings = self._torch.nn.functional.normalize(
                    embeddings, p=2, dim=1
                )

            if convert_to_numpy:
                embeddings = embeddings.detach().cpu().numpy()
            else:
                embeddings = embeddings.tolist()

            all_embeddings.extend(embeddings)
        return all_embeddings

    def _mean_pooling(self, model_output, attention_mask):
        # Average token embeddings, weighting out padding tokens via the attention mask.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return self._torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
        # Per-dimension max over non-padding tokens; padding positions are masked to -1e9.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        token_embeddings[input_mask_expanded == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
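

# --- Usage sketch (not part of the encoder) ---
# A minimal example of how this class might be exercised, assuming a CUDA-capable
# machine with onnxruntime-gpu installed and that the model repository exposes the
# onnx/model_fp16.onnx file loaded above. The document texts are made up for
# illustration only.
if __name__ == "__main__":
    encoder = OptimumEncoder(device="cuda")
    vectors = encoder(
        ["What is the weather like today?", "Tell me a joke."],
        batch_size=2,
        pooling_strategy="mean",
    )
    # Two L2-normalized embeddings, each of the model's hidden size
    # (1024 for mxbai-embed-large-v1).
    print(len(vectors), len(vectors[0]))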