ai-virtual-assistant / rag_chain /retrievers_setup.py
talltree's picture
Upload 3 files
d6746de verified
raw
history blame
7.7 kB
import os
from functools import cache, lru_cache

import qdrant_client
import torch
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer
class DenseRetrieverClient:
    """Initialize the dense retriever using OpenAI text embeddings and a Qdrant vector database.

    Attributes:
        embeddings_model (str): The embeddings model to use. Only the OpenAI text embeddings
            listed in ``SUPPORTED_EMBEDDINGS_MODELS`` are currently accepted.
        collection_name (str): Qdrant collection name.
        client (QdrantClient): Qdrant client.
        qdrant_collection (Qdrant): LangChain Qdrant vector store bound to ``collection_name``.
    """

    # TODO: Test new OpenAI text embeddings models
    SUPPORTED_EMBEDDINGS_MODELS: tuple[str, ...] = ("text-embedding-ada-002",)

    def __init__(self, embeddings_model: str = "text-embedding-ada-002", collection_name: str = "practitioners_db"):
        """Validate configuration, create the Qdrant client and load the vector store.

        Args:
            embeddings_model (str, optional): Embeddings model name. Defaults to "text-embedding-ada-002".
            collection_name (str, optional): Qdrant collection name. Defaults to "practitioners_db".

        Raises:
            EnvironmentError: If ``QDRANT_API_KEY`` or ``QDRANT_URL`` is not set.
            ValueError: If ``embeddings_model`` is not supported.
        """
        self.validate_environment_variables()
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.qdrant_collection = self.load_qdrant_collection()

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set (and non-empty).

        Raises:
            EnvironmentError: Naming the first missing variable.
        """
        for var in ("QDRANT_API_KEY", "QDRANT_URL"):
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_qdrant_collection(self, embeddings):
        """Wrap the Qdrant collection in a LangChain ``Qdrant`` vector store using ``embeddings``."""
        return Qdrant(client=self.client,
                      collection_name=self.collection_name,
                      embeddings=embeddings)

    def load_qdrant_collection(self):
        """Build the Qdrant vector store for ``self.embeddings_model``.

        The result is stored on ``self.qdrant_collection`` and also returned. This method is
        intentionally not memoized. ``functools.cache`` on a method keys on ``self`` and keeps
        every instance alive for the life of the process. ``__init__`` already keeps the result.

        Returns:
            Qdrant: The vector store.

        Raises:
            ValueError: If ``self.embeddings_model`` is not in ``SUPPORTED_EMBEDDINGS_MODELS``.
        """
        if self.embeddings_model not in self.SUPPORTED_EMBEDDINGS_MODELS:
            raise ValueError(
                f"Invalid embeddings model: {self.embeddings_model}. "
                f"Valid options are {', '.join(self.SUPPORTED_EMBEDDINGS_MODELS)}.")
        self.qdrant_collection = self.set_qdrant_collection(
            OpenAIEmbeddings(model=self.embeddings_model))
        return self.qdrant_collection

    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
        """Set up the retriever (Qdrant vector store as retriever).

        Args:
            search_type (str, optional): similarity or mmr. Defaults to "similarity".
            k (int, optional): Number of documents retrieved. Defaults to 4.

        Returns:
            Retriever: Vector store as a retriever.
        """
        return self.qdrant_collection.as_retriever(search_type=search_type,
                                                   search_kwargs={"k": k})
class SparseRetrieverClient:
    """Initialize the sparse retriever using the SPLADE neural retrieval model and a Qdrant vector database.

    Attributes:
        collection_name (str): Qdrant collection name.
        vector_name (str): Name of the sparse vector in the Qdrant collection.
        model_id (str): The SPLADE neural retrieval model id (the ``splade_model_id`` argument).
        k (int): Number of documents retrieved.
        client (QdrantClient): Qdrant client.
    """

    def __init__(self, collection_name: str, vector_name: str, splade_model_id: str = "naver/splade-cocondenser-ensembledistil", k: int = 15, encoder_cache_size: int = 1024):
        """Validate configuration and create the Qdrant client.

        The SPLADE tokenizer/model are not loaded here; they are loaded on first use.

        Args:
            collection_name (str): Qdrant collection name.
            vector_name (str): Name of the sparse vector in the collection.
            splade_model_id (str, optional): Hugging Face model id. Defaults to
                "naver/splade-cocondenser-ensembledistil".
            k (int, optional): Number of documents retrieved. Defaults to 15.
            encoder_cache_size (int, optional): Maximum number of query encodings memoized
                per instance; 0 disables memoization. Defaults to 1024.

        Raises:
            EnvironmentError: If ``QDRANT_API_KEY`` or ``QDRANT_URL`` is not set.
        """
        self.validate_environment_variables()
        self.client = qdrant_client.QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        self.model_id = splade_model_id
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        # (tokenizer, model) pair, filled lazily by set_tokenizer_config().
        self._tokenizer_and_model = None
        # Bounded, per-instance memo of query encodings. This replaces ``@cache`` on the method,
        # which grew by one entry per distinct query and kept ``self`` alive forever.
        self._cached_encode = lru_cache(maxsize=encoder_cache_size)(self._encode)

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set (and non-empty).

        Raises:
            EnvironmentError: Naming the first missing variable.
        """
        for var in ("QDRANT_API_KEY", "QDRANT_URL"):
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_tokenizer_config(self):
        """Return the SPLADE (tokenizer, model) pair, loading it on first call.

        See https://huggingface.co/naver/splade-cocondenser-ensembledistil for details on the
        default model.
        """
        if self._tokenizer_and_model is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            model = AutoModelForMaskedLM.from_pretrained(self.model_id)
            # Inference mode (disables dropout) so encodings are deterministic.
            model.eval()
            self._tokenizer_and_model = (tokenizer, model)
        return self._tokenizer_and_model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector.

        The encoder is required by QdrantSparseVectorRetriever. It is adapted from the Qdrant
        documentation example "Computing the Sparse Vector". Results are memoized per
        instance, up to ``encoder_cache_size`` distinct texts.

        Args:
            text (str): Text to encode.

        Returns:
            tuple[list[int], list[float]]: Indices and values of the sparse vector.
        """
        indices, values = self._cached_encode(text)
        # Fresh lists per call so a caller mutating them cannot corrupt the memoized entry.
        return list(indices), list(values)

    def _encode(self, text: str) -> tuple[tuple[int, ...], tuple[float, ...]]:
        """Run the SPLADE model on ``text`` and return its non-zero term weights (uncached)."""
        tokenizer, model = self.set_tokenizer_config()
        tokens = tokenizer(text, return_tensors="pt",
                           max_length=512, padding="max_length", truncation=True)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        # Term weights log(1 + ReLU(logit)); the mask zeroes out padded positions.
        # Assumes logits are (batch, seq, vocab), as the unsqueeze(-1) broadcast implies.
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        # Max-pool over dim 1 (token positions), leaving one weight per vocabulary entry.
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()
        indices = vec.nonzero().numpy().flatten()
        values = vec.numpy()[indices]
        return tuple(indices.tolist()), tuple(values.tolist())

    def get_sparse_retriever(self):
        """Build a QdrantSparseVectorRetriever over the collection using this SPLADE encoder.

        Returns:
            QdrantSparseVectorRetriever: Retriever returning ``self.k`` documents.
        """
        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )
def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
    """Wrap ``base_retriever`` so that weakly related documents are dropped.

    The returned ContextualCompressionRetriever runs ``base_retriever`` and then passes its
    documents through an EmbeddingsFilter backed by OpenAIEmbeddings. The filter removes
    documents whose similarity score is below ``similarity_threshold``.

    Args:
        base_retriever: Retriever whose results are filtered.
        embeddings_model (str, optional): OpenAI embeddings model used by the filter.
            Defaults to "text-embedding-ada-002".
        similarity_threshold (float, optional): Minimum similarity a document needs to be kept.
            Defaults to 0.76, which was obtained by experimenting with text-embedding-ada-002.
            Be careful with this value: it strongly affects the results and depends heavily on
            the embeddings model in use.

    Returns:
        ContextualCompressionRetriever: The filtering retriever.
    """
    embeddings = OpenAIEmbeddings(model=embeddings_model)
    low_similarity_filter = EmbeddingsFilter(
        embeddings=embeddings,
        similarity_threshold=similarity_threshold,
    )
    return ContextualCompressionRetriever(
        base_compressor=low_similarity_filter,
        base_retriever=base_retriever,
    )