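"""OptimumEncoder: a LangChain `Embeddings` implementation backed by Optimum.

The encoder loads the fp16 ONNX export of `mixedbread-ai/mxbai-embed-large-v1`
through ONNX Runtime (CUDA execution provider) and pools token embeddings into
one vector per text.
"""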
import os
from typing import Any, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr
from tqdm import tqdm


class OptimumEncoder(BaseModel, Embeddings):
    name: str = "mixedbread-ai/mxbai-embed-large-v1"
    device: Optional[str] = None
    cache_dir: Optional[str] = None

    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _torch: Any = PrivateAttr()

    def __init__(
        self,
        name: str = "mixedbread-ai/mxbai-embed-large-v1",
        device: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.name = name
        self.device = device
        self.cache_dir = cache_dir
        self._tokenizer, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ORTModelForFeatureExtraction
        except ImportError:
            raise ImportError(
                "Please install optimum and onnxruntime to use OptimumEncoder. "
                "You can install them with: "
                "`pip install transformers optimum[onnxruntime-gpu]`"
            )

        try:
            import torch
        except ImportError:
            raise ImportError(
                "Please install PyTorch to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers to use OptimumEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            )

        self._torch = torch

        # The ONNX model below runs on the CUDA execution provider, so default
        # to CUDA when no device was requested and one is available.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(
            self.name, cache_dir=self.cache_dir
        )

        # TensorRT provider options, kept for reference; the model currently
        # runs on the plain CUDA execution provider instead.
        # provider_options = {
        #     "trt_engine_cache_enable": True,
        #     "trt_engine_cache_path": os.getenv("HF_HOME"),
        #     "trt_fp16_enable": True,
        # }

        session_options = ort.SessionOptions()
        session_options.log_severity_level = 0  # most verbose ORT logging

        ort_model = ORTModelForFeatureExtraction.from_pretrained(
            model_id=self.name,
            cache_dir=self.cache_dir,
            file_name="model_fp16.onnx",
            subfolder="onnx",
            provider="CUDAExecutionProvider",
            use_io_binding=True,
            # provider_options=provider_options,
            session_options=session_options,
        )
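        # The commented-out block below looks like a warm-up pass for the
        # (disabled) TensorRT provider: running a short and a long dummy batch
        # once would pre-build engines before serving real requests.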
# print("Building engine for a short sequence...")
# short_text = ["short"]
# short_encoded_input = tokenizer(
# short_text, padding=True, truncation=True, return_tensors="pt"
# ).to(self.device)
# short_output = ort_model(**short_encoded_input)
# print("Building engine for a long sequence...")
# long_text = ["a very long input just for demo purpose, this is very long" * 10]
# long_encoded_input = tokenizer(
# long_text, padding=True, truncation=True, return_tensors="pt"
# ).to(self.device)
# long_output = ort_model(**long_encoded_input)
# text = ["Replace me by any text you'd like."]
# encoded_input = tokenizer(
# text, padding=True, truncation=True, return_tensors="pt"
# ).to(self.device)
# for i in range(3):
# output = ort_model(**encoded_input)
return tokenizer, ort_model

    class Config:
"""Configuration for this pydantic object."""
extra = Extra.allow

    def embed_documents(
self,
docs: List[str],
batch_size: int = 32,
normalize_embeddings: bool = True,
pooling_strategy: str = "mean"
) -> List[List[float]]:
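        """Embed a list of documents.

        Texts are tokenized and encoded in batches of `batch_size`; token
        embeddings are pooled ("mean" or "max") into a single vector per
        document and optionally L2-normalized.
        """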
all_embeddings = []
for i in tqdm(range(0, len(docs), batch_size)):
batch_docs = docs[i : i + batch_size]
encoded_input = self._tokenizer(
batch_docs, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
with self._torch.no_grad():
model_output = self._model(**encoded_input)
if pooling_strategy == "mean":
embeddings = self._mean_pooling(
model_output, encoded_input["attention_mask"]
)
elif pooling_strategy == "max":
embeddings = self._max_pooling(
model_output, encoded_input["attention_mask"]
)
else:
raise ValueError(
"Invalid pooling_strategy. Please use 'mean' or 'max'."
)
if normalize_embeddings:
embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
all_embeddings.extend(embeddings.tolist())
return all_embeddings

    def embed_query(
        self,
        text: str,
        normalize_embeddings: bool = True,
        pooling_strategy: str = "mean"
    ) -> List[float]:
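        """Embed a single query string and return one embedding vector."""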
        encoded_input = self._tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)
with self._torch.no_grad():
model_output = self._model(**encoded_input)
if pooling_strategy == "mean":
embeddings = self._mean_pooling(
model_output, encoded_input["attention_mask"]
)
elif pooling_strategy == "max":
embeddings = self._max_pooling(
model_output, encoded_input["attention_mask"]
)
else:
raise ValueError(
"Invalid pooling_strategy. Please use 'mean' or 'max'."
)
if normalize_embeddings:
embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1)
        # The batch dimension is 1 here; return the single embedding vector
        # so the result matches the declared `List[float]` return type.
        return embeddings[0].tolist()

    def _mean_pooling(self, model_output, attention_mask):
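        """Mean-pool token embeddings, using the attention mask so that
        padding tokens do not contribute to the average."""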
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return self._torch.sum(
token_embeddings * input_mask_expanded, 1
) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _max_pooling(self, model_output, attention_mask):
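        """Max-pool token embeddings, masking padded positions with a large
        negative value so they are never selected."""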
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
token_embeddings[input_mask_expanded == 0] = -1e9
return self._torch.max(token_embeddings, 1)[0]
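

# Minimal usage sketch (assumes a CUDA-capable machine with the optional
# dependencies above installed; the texts are purely illustrative):
#
#     encoder = OptimumEncoder(device="cuda", cache_dir=os.getenv("HF_HOME"))
#     doc_vectors = encoder.embed_documents(["first document", "second document"])
#     query_vector = encoder.embed_query("What does the first document say?")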