Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
from functools import lru_cache | |
from typing import Literal | |
from langchain_core.vectorstores import VectorStoreRetriever | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode | |
# Silence gRPC's native log output; the Qdrant client in this module is
# created with prefer_grpc=True.
# NOTE(review): gRPC may read this variable when its core library initializes;
# if importing langchain_qdrant above already loaded grpc, setting it here could
# be too late to take effect — confirm, and move it above the imports if needed.
os.environ.update({"GRPC_VERBOSITY": "NONE"})
class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment settings.

    The Qdrant endpoint and key are read from ``QDRANT_URL`` and
    ``QDRANT_API_KEY``. ``OPENAI_API_KEY`` is also required (presumably consumed
    by ``OpenAIEmbeddings`` for the dense model).
    """

    # Environment variables that must be set to a non-blank value.
    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ) -> None:
        """Validate the environment and record connection/model settings.

        Args:
            dense_model_name: OpenAI embedding model used for dense vectors.
            sparse_model_name: FastEmbed model used for sparse vectors.

        Raises:
            EnvironmentError: If any of ``REQUIRED_ENV_VARS`` is unset or blank.
        """
        self._validate_environment()
        self.qdrant_url = os.getenv("QDRANT_URL")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name
        # Embedding models are built lazily on first use and then shared by every
        # retriever this instance creates, so the sparse model is loaded once.
        self._dense_embeddings = None
        self._sparse_embeddings = None

    @classmethod
    def _validate_environment(cls) -> None:
        """Check that every required environment variable is set and non-blank.

        Raises:
            EnvironmentError: Listing all missing or whitespace-only variables.
        """
        missing_vars = [
            var for var in cls.REQUIRED_ENV_VARS if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    def dense_embeddings(self) -> OpenAIEmbeddings:
        """Return the dense embedding model, creating it on first call.

        The instance is cached, so later changes to ``dense_model_name`` do not
        rebuild it.
        """
        if self._dense_embeddings is None:
            self._dense_embeddings = OpenAIEmbeddings(model=self.dense_model_name)
        return self._dense_embeddings

    def sparse_embeddings(self) -> FastEmbedSparse:
        """Return the sparse embedding model, creating it on first call.

        The instance is cached, so later changes to ``sparse_model_name`` do not
        rebuild it.
        """
        if self._sparse_embeddings is None:
            self._sparse_embeddings = FastEmbedSparse(model_name=self.sparse_model_name)
        return self._sparse_embeddings

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Connect to an existing Qdrant collection and wrap it as a hybrid retriever.

        Args:
            collection_name: Name of the existing Qdrant collection.
            dense_vector_name: Name of the collection's dense vector field.
            sparse_vector_name: Name of the collection's sparse vector field.
            k: Number of documents the retriever returns per query.

        Returns:
            A retriever running hybrid (dense + sparse) search over the collection.
        """
        qdrantdb = QdrantVectorStore.from_existing_collection(
            # Pass embedding *instances*; the bare bound methods are not
            # Embeddings objects and cannot be used by the vector store.
            embedding=self.dense_embeddings(),
            sparse_embedding=self.sparse_embeddings(),
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            prefer_grpc=True,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name=dense_vector_name,
            sparse_vector_name=sparse_vector_name,
        )
        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )