|
from typing import Any, Dict, List, Literal, Optional, Sequence |
|
from fastembed.common import OnnxProvider |
|
import numpy as np |
|
from langchain_core.embeddings import Embeddings |
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator |
|
|
|
|
|
class FastEmbedEmbeddingsLc(BaseModel, Embeddings):
    """Qdrant FastEmbed embedding models.

    FastEmbed is a lightweight, fast Python library built for embedding
    generation. See more documentation at:

    * https://github.com/qdrant/fastembed/
    * https://qdrant.github.io/fastembed/

    To use this class, you must install the ``fastembed`` Python package::

        pip install fastembed

    The embedding methods return the numpy arrays yielded by the underlying
    fastembed model and accept an optional per-call ``batch_size`` override.

    Example:
        .. code-block:: python

            fastembed = FastEmbedEmbeddingsLc()
            doc_vectors = fastembed.embed_documents(["first text", "second text"])
            query_vector = fastembed.embed_query("a question")
    """

    model_name: str = "BAAI/bge-small-en-v1.5"
    """Name of the FastEmbed model to use. Defaults to "BAAI/bge-small-en-v1.5".

    Find the list of supported models at
    https://qdrant.github.io/fastembed/examples/Supported_Models/
    """

    max_length: int = 512
    """The maximum number of tokens. Defaults to 512.

    Unknown behavior for values > 512.
    """

    cache_dir: Optional[str]
    """Path to the model cache directory.

    Defaults to None, which is passed through to fastembed unchanged so that
    fastembed chooses the cache location.
    """

    threads: Optional[int]
    """The number of threads a single onnxruntime session can use.

    Defaults to None, which is passed through to fastembed unchanged.
    """

    doc_embed_type: Literal["default", "passage"] = "default"
    """Type of embedding to use for documents.

    "default" embeds documents with the model's ``embed`` method; "passage"
    uses the model's ``passage_embed`` method instead.
    """

    providers: Optional[Sequence[OnnxProvider]]
    """ONNX Runtime execution providers forwarded to the fastembed model.

    Defaults to None, which is passed through to fastembed unchanged.
    """

    batch_size: Optional[int]
    """Default batch size used by ``embed_documents`` / ``embed_query``.

    When None (the default), no batch size is forwarded and fastembed's own
    default applies. A per-call ``batch_size`` argument overrides this value.
    """

    # Underlying fastembed model instance; populated by validate_environment.
    _model: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that FastEmbed is installed and build the underlying model.

        ``fastembed.TextEmbedding`` is tried first; if it cannot be imported,
        ``fastembed.embedding.FlagEmbedding`` is used as a fallback.
        ``skip_on_failure=True`` ensures no model is built from field values
        that already failed validation.

        Args:
            values: Field values after per-field validation.

        Returns:
            ``values`` with the constructed model stored under ``"_model"``.

        Raises:
            ImportError: If neither fastembed entry point can be imported.
        """
        model_kwargs: Dict[str, Any] = {
            "model_name": values.get("model_name"),
            "max_length": values.get("max_length"),
            "cache_dir": values.get("cache_dir"),
            "threads": values.get("threads"),
            # The field is named ``providers`` (plural); reading any other key
            # would silently discard the user's execution providers.
            "providers": values.get("providers"),
        }

        # Keep the try blocks limited to the imports so that errors raised
        # while constructing the model are not mistaken for a missing package.
        model_cls: Any
        try:
            from fastembed import TextEmbedding

            model_cls = TextEmbedding
        except ImportError as ie:
            try:
                # Fallback for fastembed versions without TextEmbedding.
                from fastembed.embedding import FlagEmbedding

                model_cls = FlagEmbedding
            except ImportError:
                raise ImportError(
                    "Could not import 'fastembed' Python package. "
                    "Please install it with `pip install fastembed`."
                ) from ie

        values["_model"] = model_cls(**model_kwargs)
        return values

    def _batch_kwargs(self, batch_size: Optional[int]) -> Dict[str, Any]:
        """Return the ``batch_size`` keyword to forward to fastembed, if any.

        Args:
            batch_size: Per-call override; falls back to ``self.batch_size``.

        Returns:
            ``{"batch_size": n}`` when an effective batch size is set,
            otherwise an empty dict so fastembed uses its own default.
        """
        effective = self.batch_size if batch_size is None else batch_size
        # fastembed declares ``batch_size`` as an int with its own default, so
        # an explicit None is not forwarded.
        return {} if effective is None else {"batch_size": effective}

    def embed_documents(
        self, texts: List[str], batch_size: Optional[int] = None
    ) -> List[np.ndarray]:
        """Generate embeddings for documents using FastEmbed.

        Uses ``passage_embed`` when ``doc_embed_type == "passage"``, otherwise
        ``embed``.

        Args:
            texts: The list of texts to embed.
            batch_size: Optional per-call batch size; overrides
                ``self.batch_size`` when given.

        Returns:
            List of embeddings, one numpy array for each text.
        """
        kwargs = self._batch_kwargs(batch_size)
        if self.doc_embed_type == "passage":
            embeddings = self._model.passage_embed(texts, **kwargs)
        else:
            embeddings = self._model.embed(texts, **kwargs)
        return list(embeddings)

    def embed_query(
        self, text: str, batch_size: Optional[int] = None
    ) -> np.ndarray:
        """Generate a query embedding using FastEmbed.

        Args:
            text: The text to embed.
            batch_size: Optional per-call batch size; overrides
                ``self.batch_size`` when given.

        Returns:
            The embedding for the text as a numpy array.
        """
        embeddings = self._model.embed(text, **self._batch_kwargs(batch_size))
        # The model yields one embedding per input; take the first (and only
        # expected) one for the single query text.
        query_embedding: np.ndarray = next(iter(embeddings))
        return query_embedding