# Spaces: Starting on T4
from typing import Any, List, Optional, Sequence

import numpy as np
from fastembed.common import OnnxProvider
from pydantic.v1 import PrivateAttr

from semantic_router.encoders import BaseEncoder
class FastEmbedEncoder(BaseEncoder):
    """Encoder that produces embeddings through fastembed's ``TextEmbedding``.

    The configuration fields below are forwarded as keyword arguments to
    ``TextEmbedding`` when the client is created; fields left as ``None`` are
    omitted so that fastembed applies its own defaults.
    """

    # Encoder kind identifier.
    type: str = "fastembed"
    # Passed to ``TextEmbedding`` as ``model_name``.
    name: str = "BAAI/bge-small-en-v1.5"
    # Passed to ``TextEmbedding`` as ``max_length``.
    max_length: int = 512
    # Passed to ``TextEmbedding`` as ``cache_dir`` when not ``None``.
    cache_dir: Optional[str] = None
    # Passed to ``TextEmbedding`` as ``threads`` when not ``None``.
    threads: Optional[int] = None
    # Passed to ``TextEmbedding`` as ``providers`` when not ``None``.
    providers: Optional[Sequence[OnnxProvider]] = None
    # Holds the ``TextEmbedding`` instance created in ``__init__``.
    _client: Any = PrivateAttr()

    def __init__(
        self, score_threshold: float = 0.5, **data
    ):  # TODO default score_threshold not thoroughly tested, should optimize
        """Initialise the encoder and create the fastembed client.

        Args:
            score_threshold: Forwarded to ``BaseEncoder.__init__``.
            **data: Remaining field values, forwarded to ``BaseEncoder.__init__``.

        Raises:
            ImportError: If ``fastembed`` cannot be imported.
        """
        super().__init__(score_threshold=score_threshold, **data)
        self._client = self._initialize_client()

    def _initialize_client(self):
        """Build the ``TextEmbedding`` client from this encoder's fields.

        Returns:
            A ``fastembed.TextEmbedding`` instance.

        Raises:
            ImportError: If ``fastembed`` is not installed.
        """
        # NOTE(review): this module also imports ``fastembed.common`` at the
        # top level, so a missing install presumably fails before this guard
        # is reached -- confirm whether that import should be made lazy too.
        try:
            from fastembed import TextEmbedding
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(
                "Please install fastembed to use FastEmbedEncoder. "
                "You can install it with: "
                "`pip install 'semantic-router[fastembed]'`"
            ) from err

        embedding_args = {
            "model_name": self.name,
            "max_length": self.max_length,
            "cache_dir": self.cache_dir,
            "threads": self.threads,
            "providers": self.providers,
        }
        # Drop unset options so they are not passed explicitly as ``None``.
        embedding_args = {k: v for k, v in embedding_args.items() if v is not None}
        return TextEmbedding(**embedding_args)

    def __call__(
        self,
        docs: List[str],
        batch_size: int = 32,
        parallel: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed ``docs`` and return one vector per document.

        Args:
            docs: Texts to embed.
            batch_size: Forwarded to the client's ``embed`` call.
            parallel: Forwarded to the client's ``embed`` call.

        Returns:
            The embeddings, each converted to a plain Python list via
            ``tolist()``.

        Raises:
            ValueError: If embedding or conversion fails for any reason; the
                original exception is attached as the cause.
        """
        try:
            # Keyword arguments guard against positional-order changes in
            # the client's ``embed`` signature.
            embeds: List[np.ndarray] = list(
                self._client.embed(docs, batch_size=batch_size, parallel=parallel)
            )
            return [e.tolist() for e in embeds]
        except Exception as e:
            raise ValueError(f"FastEmbed embed failed. Error: {e}") from e