|
from typing import Any, List, Optional, Sequence, Union

import numpy as np
from fastembed.common import OnnxProvider
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder
|
|
|
|
|
class FastEmbedEncoder(BaseEncoder):
    """Encoder that produces embeddings through fastembed's ``TextEmbedding``.

    The fastembed client is built once in ``__init__`` from the fields below
    and reused for every call. Fields left as ``None`` are not forwarded, so
    ``TextEmbedding`` falls back to its own defaults for them.
    """

    # Encoder type tag.
    type: str = "fastembed"
    # Forwarded to TextEmbedding as ``model_name``.
    name: str = "BAAI/bge-small-en-v1.5"
    # Forwarded to TextEmbedding as ``max_length``.
    max_length: int = 512
    # Forwarded to TextEmbedding as ``cache_dir`` when set.
    cache_dir: Optional[str] = None
    # Forwarded to TextEmbedding as ``threads`` when set.
    threads: Optional[int] = None
    # Forwarded to TextEmbedding as ``providers`` when set.
    providers: Optional[Sequence[OnnxProvider]] = None
    # The TextEmbedding instance; private so pydantic does not treat it as a field.
    _client: Any = PrivateAttr()

    def __init__(self, score_threshold: float = 0.5, **data):
        """Validate the model fields, then build the fastembed client.

        Args:
            score_threshold: Passed through to ``BaseEncoder``.
            **data: Remaining field values (``name``, ``max_length``, ...).
        """
        super().__init__(score_threshold=score_threshold, **data)
        # Fields must be populated before the client is built from them.
        self._client = self._initialize_client()

    def _initialize_client(self):
        """Create the ``TextEmbedding`` client from this encoder's fields.

        NOTE(review): the module also imports ``fastembed.common`` at top
        level, so a missing fastembed install probably fails at import time
        before this guard is reached -- confirm.

        Returns:
            A ``fastembed.TextEmbedding`` instance.

        Raises:
            ImportError: If ``fastembed`` cannot be imported.
        """
        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            raise ImportError(
                "Please install fastembed to use FastEmbedEncoder. "
                "You can install it with: "
                "`pip install 'semantic-router[fastembed]'`"
            ) from err

        candidate_args = {
            "model_name": self.name,
            "max_length": self.max_length,
            "cache_dir": self.cache_dir,
            "threads": self.threads,
            "providers": self.providers,
        }
        # Drop unset options so TextEmbedding applies its own defaults.
        embedding_args = {
            key: value for key, value in candidate_args.items() if value is not None
        }
        return TextEmbedding(**embedding_args)

    def __call__(
        self,
        docs: List[str],
        batch_size: int = 32,
        parallel: Optional[int] = None,
        convert_to_numpy: bool = False,
    ) -> Union[List[List[float]], List[np.ndarray]]:
        """Embed ``docs`` with the fastembed client.

        Args:
            docs: Documents to embed.
            batch_size: Forwarded to ``TextEmbedding.embed``.
            parallel: Forwarded to ``TextEmbedding.embed``.
            convert_to_numpy: When true, return the arrays yielded by the
                client unchanged; otherwise convert each one with ``tolist()``.

        Returns:
            One embedding per item yielded by the client, either as numpy
            arrays or as plain lists depending on ``convert_to_numpy``.

        Raises:
            ValueError: If embedding or conversion fails for any reason; the
                original exception is chained as the cause.
        """
        try:
            # Keywords keep the options bound to the intended parameters
            # regardless of the positional order of ``embed``'s signature.
            embeds: List[np.ndarray] = list(
                self._client.embed(docs, batch_size=batch_size, parallel=parallel)
            )
            if convert_to_numpy:
                return embeds
            embeddings: List[List[float]] = [vector.tolist() for vector in embeds]
            return embeddings
        except Exception as e:
            raise ValueError(f"FastEmbed embed failed. Error: {e}") from e