import os
from functools import cache

import qdrant_client
import torch
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Environment variables both retriever clients need in order to reach Qdrant.
_REQUIRED_QDRANT_VARS = ("QDRANT_API_KEY", "QDRANT_URL")

# OpenAI embeddings models accepted by DenseRetrieverClient.
# TODO: Test new OpenAI text embeddings models
_OPENAI_TEXT_EMBEDDINGS = ("text-embedding-ada-002",)


def _validate_qdrant_environment():
    """Raise EnvironmentError for the first Qdrant variable that is unset or empty."""
    for var in _REQUIRED_QDRANT_VARS:
        if not os.getenv(var):
            raise EnvironmentError(f"Missing environment variable: {var}")


@cache
def _load_splade(model_id: str):
    """Load the SPLADE tokenizer and masked-LM model for ``model_id``, memoized per id.

    The cache lives at module level and is keyed on the model id. Every
    SparseRetrieverClient that uses the same model therefore shares one loaded
    copy, and the cache does not keep client instances alive.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    return tokenizer, model


class DenseRetrieverClient:
    """Initialize the dense retriever using OpenAI text embeddings and a Qdrant vector database.

    Attributes:
        embeddings_model (str): The embeddings model to use. Right now only OpenAI text embeddings.
        collection_name (str): Qdrant collection name.
        client (QdrantClient): Qdrant client.
        qdrant_collection (Qdrant): Qdrant collection wrapper bound to the embeddings model.
    """

    def __init__(self, embeddings_model: str = "text-embedding-ada-002",
                 collection_name: str = "practitioners_db"):
        self.validate_environment_variables()
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.qdrant_collection = self.load_qdrant_collection()

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set.

        Raises:
            EnvironmentError: If QDRANT_API_KEY or QDRANT_URL is missing or empty.
        """
        _validate_qdrant_environment()

    def set_qdrant_collection(self, embeddings):
        """Wrap the configured Qdrant collection with the given embeddings object."""
        return Qdrant(client=self.client, collection_name=self.collection_name,
                      embeddings=embeddings)

    def load_qdrant_collection(self):
        """Load the Qdrant collection for ``self.embeddings_model``.

        The collection is built once and then reused on later calls. This memo is
        stored on the instance. A ``functools.cache`` on a method would key on
        ``self`` and keep every instance alive for the life of the process.

        Returns:
            Qdrant: The collection wrapper, also stored in ``self.qdrant_collection``.

        Raises:
            ValueError: If ``self.embeddings_model`` is not a supported OpenAI model.
        """
        existing = getattr(self, "qdrant_collection", None)
        if existing is not None:
            return existing
        if self.embeddings_model not in _OPENAI_TEXT_EMBEDDINGS:
            raise ValueError(
                f"Invalid embeddings model: {self.embeddings_model}. "
                f"Valid options are {', '.join(_OPENAI_TEXT_EMBEDDINGS)}.")
        self.qdrant_collection = self.set_qdrant_collection(
            OpenAIEmbeddings(model=self.embeddings_model))
        return self.qdrant_collection

    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
        """Set up the retriever (Qdrant vectorstore as retriever).

        Args:
            search_type (str, optional): "similarity" or "mmr". Defaults to "similarity".
            k (int, optional): Number of documents retrieved. Defaults to 4.

        Returns:
            Retriever: Vectorstore as a retriever.
        """
        return self.qdrant_collection.as_retriever(search_type=search_type,
                                                   search_kwargs={"k": k})


class SparseRetrieverClient:
    """Initialize the sparse retriever using the SPLADE neural retrieval model and a Qdrant vector database.

    Attributes:
        collection_name (str): Qdrant collection name.
        vector_name (str): Qdrant sparse vector name.
        model_id (str): The SPLADE neural retrieval model id (``splade_model_id`` argument).
        k (int): Number of documents retrieved.
        client (QdrantClient): Qdrant client.
    """

    def __init__(self, collection_name: str, vector_name: str,
                 splade_model_id: str = "naver/splade-cocondenser-ensembledistil",
                 k: int = 15):
        self.validate_environment_variables()
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.model_id = splade_model_id
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set.

        Raises:
            EnvironmentError: If QDRANT_API_KEY or QDRANT_URL is missing or empty.
        """
        _validate_qdrant_environment()

    def set_tokenizer_config(self):
        """Return the tokenizer and the SPLADE neural retrieval model for ``self.model_id``.

        The pair is loaded once per model id and shared across instances.
        See https://huggingface.co/naver/splade-cocondenser-ensembledistil for more details.

        Returns:
            tuple: ``(tokenizer, model)``.
        """
        return _load_splade(self.model_id)

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector.

        The QdrantSparseVectorRetriever requires this encoder. Adapted from the
        "Computing the Sparse Vector" code in the Qdrant documentation.

        Args:
            text (str): Text to encode.

        Returns:
            tuple[list[int], list[float]]: Indices and values of the sparse vector.
        """
        tokenizer, model = self.set_tokenizer_config()
        tokens = tokenizer(text, return_tensors="pt", max_length=512,
                           padding="max_length", truncation=True)
        # This is inference only. Without no_grad, every query would build and
        # hold an autograd graph over the full forward pass.
        with torch.no_grad():
            logits = model(**tokens).logits
        # Log-saturated ReLU of the MLM logits. Padding positions are zeroed by
        # the attention mask before max-pooling over the sequence dimension.
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()
        indices = vec.nonzero().numpy().flatten()
        values = vec.numpy()[indices]
        return indices.tolist(), values.tolist()

    def get_sparse_retriever(self):
        """Build a QdrantSparseVectorRetriever that encodes queries with ``sparse_encoder``.

        Returns:
            QdrantSparseVectorRetriever: Retriever over ``self.collection_name`` /
            ``self.vector_name`` that returns ``self.k`` documents.
        """
        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )


def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002",
                                similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
    """Create a ContextualCompressionRetriever from a base retriever and a similarity threshold.

    The ContextualCompressionRetriever uses an EmbeddingsFilter backed by
    OpenAIEmbeddings. It drops documents whose similarity score is below the
    given threshold.

    Args:
        base_retriever: Retriever to be filtered.
        embeddings_model (str, optional): OpenAI embeddings model used by the filter.
            Defaults to "text-embedding-ada-002".
        similarity_threshold (float, optional): The similarity threshold for the
            EmbeddingsFilter. Documents with a similarity score below this threshold
            are filtered out. Defaults to 0.76, a value found by experimenting with
            text-embedding-ada-002.
            ** Be careful with this parameter. It can have a big impact on the
            results, and the right value depends heavily on the embeddings model used.

    Returns:
        ContextualCompressionRetriever: The created ContextualCompressionRetriever.
    """
    # Filter out documents with low similarity to the query.
    relevant_filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model),
                                       similarity_threshold=similarity_threshold)
    return ContextualCompressionRetriever(
        base_compressor=relevant_filter, base_retriever=base_retriever
    )