from typing import Any, List, Optional

from tqdm import tqdm

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra


class OptimumEncoder(BaseModel, Embeddings):
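    """LangChain-compatible embeddings class that runs a Hugging Face model
    through ONNX Runtime via optimum's ORTModelForFeatureExtraction.

    Implements LangChain's `Embeddings` interface (`embed_documents` and
    `embed_query`) with mean or max pooling over the token embeddings.
    """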

    _tokenizer: Any
    _model: Any
    _torch: Any

    def __init__(
        self,
        name: str = "mixedbread-ai/mxbai-embed-large-v1",
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
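        """Store the configuration and load the tokenizer and ONNX model.

        :param name: Hugging Face model id to load.
        :param device: torch device the tokenized inputs are moved to,
            e.g. "cuda"; defaults to CUDA when available.
        :param cache_dir: optional download cache passed to `from_pretrained`.
        """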
        super().__init__(**kwargs)
        self.name = name
        self.device = device
        self.cache_dir = cache_dir

        self._tokenizer, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
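        """Import the optional dependencies and load the tokenizer and ONNX model."""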
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install them with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install PyTorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch

        # The ONNX model below is loaded with the CUDAExecutionProvider, so
        # default to CUDA when no device was given and it is available.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        tokenizer = AutoTokenizer.from_pretrained(
            self.name, cache_dir=self.cache_dir
        )

        session_options = ort.SessionOptions()
        # Log severity 0 is the most verbose ONNX Runtime level; raise it
        # (e.g. to 2, warnings only) to quiet the session.
        session_options.log_severity_level = 0

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            cache_dir=self.cache_dir,
            file_name="model_fp16.onnx",
            subfolder="onnx",
            provider="CUDAExecutionProvider",
            use_io_binding=True,
            session_options=session_options,
        )

        return tokenizer, ort_model

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.allow

    def embed_documents(
        self,
        docs: List[str],
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
    ) -> List[List[float]]:
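        """Embed a list of documents in batches.

        :param docs: texts to embed.
        :param batch_size: number of documents encoded per forward pass.
        :param normalize_embeddings: L2-normalize each embedding if True.
        :param pooling_strategy: "mean" or "max" pooling over token embeddings.
        :return: one embedding (list of floats) per input document.
        """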
        all_embeddings = []
        for i in tqdm(range(0, len(docs), batch_size)):
            batch_docs = docs[i : i + batch_size]

            encoded_input = self._tokenizer(
                batch_docs, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with self._torch.no_grad():
                model_output = self._model(**encoded_input)

            if pooling_strategy == "mean":
                embeddings = self._mean_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            elif pooling_strategy == "max":
                embeddings = self._max_pooling(
                    model_output, encoded_input["attention_mask"]
                )
            else:
                raise ValueError(
                    "Invalid pooling_strategy. Please use 'mean' or 'max'."
                )

            if normalize_embeddings:
                embeddings = self._torch.nn.functional.normalize(
                    embeddings, p=2, dim=1
                )

            all_embeddings.extend(embeddings.tolist())

        return all_embeddings

    def embed_query(
        self,
        text: str,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean",
    ) -> List[float]:
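        """Embed a single query string and return its embedding as a flat list."""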
        encoded_input = self._tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)

        with self._torch.no_grad():
            model_output = self._model(**encoded_input)

        if pooling_strategy == "mean":
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )
        elif pooling_strategy == "max":
            embeddings = self._max_pooling(
                model_output, encoded_input["attention_mask"]
            )
        else:
            raise ValueError(
                "Invalid pooling_strategy. Please use 'mean' or 'max'."
            )

        if normalize_embeddings:
            embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)

        # The model returns a batch of size 1; unwrap it so the result
        # matches the declared List[float] return type.
        return embeddings[0].tolist()

    def _mean_pooling(self, model_output, attention_mask):
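        """Average the token embeddings, ignoring padding positions.

        The attention mask is broadcast over the hidden dimension so padded
        tokens contribute nothing to the sum; the denominator is clamped to
        avoid division by zero for empty masks.
        """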
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return self._torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
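        """Max-pool the token embeddings, ignoring padding positions."""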
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # Fill padded positions with a large negative value so they can
        # never win the max.
        token_embeddings[input_mask_expanded == 0] = -1e9
        return self._torch.max(token_embeddings, 1)[0]
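

# A minimal usage sketch, not part of the class. It assumes a CUDA-capable
# machine (the encoder loads the ONNX model with the CUDAExecutionProvider)
# and that the chosen model repo ships an `onnx/model_fp16.onnx` export, as
# mixedbread-ai/mxbai-embed-large-v1 does.
if __name__ == "__main__":
    encoder = OptimumEncoder(device="cuda")
    doc_vectors = encoder.embed_documents(
        ["ONNX Runtime speeds up inference.", "LangChain expects this interface."]
    )
    query_vector = encoder.embed_query("How do I speed up inference?")
    # Two document vectors and one query vector, all of the same dimension.
    print(len(doc_vectors), len(doc_vectors[0]), len(query_vector))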