ai-virtual-assistant / rag /retrievers.py
yrobel-lima's picture
Upload 4 files
e35585c verified
raw
history blame
2.98 kB
import os
from functools import cached_property
from functools import lru_cache
from typing import Literal

from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
# Silence gRPC's native logging by default (the Qdrant client below is created
# with prefer_grpc=True), but respect a verbosity the user already set, e.g.
# GRPC_VERBOSITY=DEBUG while troubleshooting connection problems.
# NOTE(review): gRPC may read this variable when its core initializes; since this
# runs after the langchain_qdrant import, confirm it still takes effect.
os.environ.setdefault("GRPC_VERBOSITY", "NONE")
class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment settings.

    The Qdrant connection is configured from ``QDRANT_URL`` and ``QDRANT_API_KEY``.
    ``OPENAI_API_KEY`` is only checked for presence here; it is not passed to
    ``OpenAIEmbeddings`` explicitly (presumably the client reads it from the
    environment itself).

    Embedding clients and retrievers are created lazily and cached on the
    instance, so repeated calls reuse the same objects without a class-level
    cache keeping instances alive.
    """

    # Environment variables that must be set to a non-blank value.
    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and store connection and model settings.

        Args:
            dense_model_name: OpenAI embedding model used for the dense vectors.
            sparse_model_name: FastEmbed model used for the sparse vectors.

        Raises:
            EnvironmentError: if any of ``REQUIRED_ENV_VARS`` is unset or blank.
        """
        self._validate_environment()
        # The validator already treats surrounding whitespace as insignificant;
        # strip it here too so a stray newline (e.g. from a .env file) does not
        # end up inside the URL or API key.
        self.qdrant_url = os.environ["QDRANT_URL"].strip()
        self.qdrant_api_key = os.environ["QDRANT_API_KEY"].strip()
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name
        # Per-instance retriever cache, keyed by the normalized arguments of
        # get_qdrant_retriever: (collection, dense name, sparse name, k).
        self._retriever_cache = {}

    @staticmethod
    def _validate_environment():
        """Raise ``EnvironmentError`` naming every required variable that is unset or blank."""
        missing_vars = [
            var
            for var in RetrieversConfig.REQUIRED_ENV_VARS
            if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    @cached_property
    def dense_embeddings(self):
        """OpenAI dense embedding client, created on first access and then reused."""
        return OpenAIEmbeddings(model=self.dense_model_name)

    @cached_property
    def sparse_embeddings(self):
        """FastEmbed sparse embedding model, created on first access and then reused."""
        return FastEmbedSparse(model_name=self.sparse_model_name)

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> "VectorStoreRetriever":
        """Return a hybrid retriever over an existing Qdrant collection.

        Results are cached per instance. Positional and keyword forms of the
        same arguments share one cache entry, so the same collection is never
        connected twice.

        Args:
            collection_name: Name of an existing Qdrant collection.
            dense_vector_name: Name of the collection's dense vector field.
            sparse_vector_name: Name of the collection's sparse vector field.
            k: Number of documents the retriever returns per query.

        Returns:
            A retriever produced by ``QdrantVectorStore.as_retriever`` using
            ``RetrievalMode.HYBRID``.
        """
        cache_key = (collection_name, dense_vector_name, sparse_vector_name, k)
        retriever = self._retriever_cache.get(cache_key)
        if retriever is None:
            qdrantdb = QdrantVectorStore.from_existing_collection(
                embedding=self.dense_embeddings,
                sparse_embedding=self.sparse_embeddings,
                url=self.qdrant_url,
                api_key=self.qdrant_api_key,
                prefer_grpc=True,
                collection_name=collection_name,
                retrieval_mode=RetrievalMode.HYBRID,
                vector_name=dense_vector_name,
                sparse_vector_name=sparse_vector_name,
            )
            retriever = qdrantdb.as_retriever(search_kwargs={"k": k})
            self._retriever_cache[cache_key] = retriever
        return retriever

    def get_practitioners_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the hybrid retriever for the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> "VectorStoreRetriever":
        """Return the hybrid retriever for the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )