# NOTE(review): removed stray "Spaces / Running on CPU Upgrade" text — Hugging Face
# Spaces page chrome captured during extraction, not part of the source code.
import os
from functools import cache

import qdrant_client
import torch
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer
class DenseRetrieverClient:
    """Initialize the dense retriever using OpenAI text embeddings and the Qdrant vector database.

    Attributes:
        embeddings_model (str): The embeddings model to use. Currently only OpenAI text embeddings.
        collection_name (str): Qdrant collection name.
        client (QdrantClient): Qdrant client.
        qdrant_collection (Qdrant): LangChain vectorstore over the Qdrant collection.
    """

    # TODO: Test new OpenAI text embeddings models
    SUPPORTED_EMBEDDINGS_MODELS = ("text-embedding-ada-002",)

    def __init__(self, embeddings_model: str = "text-embedding-ada-002",
                 collection_name: str = "practitioners_db"):
        self.validate_environment_variables()
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.qdrant_collection = self.load_qdrant_collection()

    def validate_environment_variables(self):
        """Raise EnvironmentError if a required Qdrant environment variable is unset or empty."""
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_qdrant_collection(self, embeddings):
        """Wrap the configured Qdrant collection with the given embeddings model.

        Args:
            embeddings: LangChain embeddings instance used to encode queries.

        Returns:
            Qdrant: Vectorstore bound to ``self.collection_name``.
        """
        return Qdrant(client=self.client,
                      collection_name=self.collection_name,
                      embeddings=embeddings)

    def load_qdrant_collection(self):
        """Load the Qdrant collection for the configured embeddings model.

        Returns:
            Qdrant: Vectorstore for ``self.embeddings_model``.

        Raises:
            ValueError: If ``self.embeddings_model`` is not a supported model.
        """
        # Guard clause instead of if/else; supported models live in a class constant.
        if self.embeddings_model not in self.SUPPORTED_EMBEDDINGS_MODELS:
            raise ValueError(
                f"Invalid embeddings model: {self.embeddings_model}. "
                f"Valid options are {', '.join(self.SUPPORTED_EMBEDDINGS_MODELS)}.")
        # NOTE(review): OpenAIEmbeddings presumably also needs OPENAI_API_KEY at
        # query time; validate_environment_variables does not check it — confirm.
        return self.set_qdrant_collection(
            OpenAIEmbeddings(model=self.embeddings_model))

    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
        """Set up the dense retriever (Qdrant vectorstore as retriever).

        Args:
            search_type (str, optional): "similarity" or "mmr". Defaults to "similarity".
            k (int, optional): Number of documents retrieved. Defaults to 4.

        Returns:
            Retriever: The vectorstore as a retriever.
        """
        return self.qdrant_collection.as_retriever(
            search_type=search_type,
            search_kwargs={"k": k},
        )
class SparseRetrieverClient:
    """Initialize the sparse retriever using the SPLADE neural retrieval model and the Qdrant vector database.

    Attributes:
        collection_name (str): Qdrant collection name.
        vector_name (str): Qdrant sparse vector name.
        model_id (str): The SPLADE neural retrieval model id.
        k (int): Number of documents retrieved.
        client (QdrantClient): Qdrant client.
    """

    def __init__(self, collection_name: str, vector_name: str,
                 splade_model_id: str = "naver/splade-cocondenser-ensembledistil",
                 k: int = 15):
        self.validate_environment_variables()
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"))
        self.model_id = splade_model_id
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        # Lazily loaded and cached by set_tokenizer_config() so the SPLADE
        # weights are loaded at most once per client, not once per query.
        self._tokenizer = None
        self._model = None

    def validate_environment_variables(self):
        """Raise EnvironmentError if a required Qdrant environment variable is unset or empty."""
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_tokenizer_config(self):
        """Return the (cached) tokenizer and SPLADE masked-LM model.

        The first call downloads/instantiates the weights; subsequent calls
        reuse the cached instances. Previously both were re-created on every
        sparse_encoder() call, which was very expensive.

        See https://huggingface.co./naver/splade-cocondenser-ensembledistil for more details.
        """
        if self._tokenizer is None or self._model is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._tokenizer, self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector for the QdrantSparseVectorRetriever.

        Adapted from the Qdrant documentation: Computing the Sparse Vector code.

        Args:
            text (str): Text to encode.

        Returns:
            tuple[list[int], list[float]]: Indices and values of the sparse vector.
        """
        tokenizer, model = self.set_tokenizer_config()
        tokens = tokenizer(text, return_tensors="pt",
                           max_length=512, padding="max_length", truncation=True)
        # Inference only — no_grad skips building the autograd graph.
        with torch.no_grad():
            output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        # SPLADE activation: log(1 + ReLU(logits)), masked to real tokens,
        # then max-pooled over the sequence dimension.
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()
        indices = vec.nonzero().numpy().flatten()
        values = vec.detach().numpy()[indices]
        return indices.tolist(), values.tolist()

    def get_sparse_retriever(self):
        """Build a QdrantSparseVectorRetriever over the configured collection and sparse vector."""
        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )
def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
    """Build a ContextualCompressionRetriever that filters low-similarity documents.

    Wraps *base_retriever* with an EmbeddingsFilter backed by OpenAIEmbeddings,
    dropping any retrieved document whose similarity score falls below the
    threshold.

    Args:
        base_retriever: Retriever whose results will be filtered.
        embeddings_model (str, optional): OpenAI embeddings model used by the
            filter. Defaults to "text-embedding-ada-002".
        similarity_threshold (float, optional): Documents scoring below this are
            filtered out. Defaults to 0.76 (obtained by experimenting with
            text-embeddings-ada-002).
            ** Be careful with this parameter — it strongly impacts results and
            highly depends on the embeddings model used.

    Returns:
        ContextualCompressionRetriever: The compression retriever.
    """
    # Filter out documents with low similarity to the query.
    embeddings = OpenAIEmbeddings(model=embeddings_model)
    similarity_filter = EmbeddingsFilter(
        embeddings=embeddings,
        similarity_threshold=similarity_threshold,
    )
    return ContextualCompressionRetriever(
        base_compressor=similarity_filter,
        base_retriever=base_retriever,
    )