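"""Retriever setup for Tall Tree Health Centre's AI virtual assistant.

Builds dense (OpenAI text embeddings), sparse (SPLADE), contextual-compression,
and multi-query retrievers on top of a Qdrant vector database. Requires the
QDRANT_URL and QDRANT_API_KEY environment variables.
"""
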
import os

import qdrant_client
import torch
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer


class ValidateQdrantClient:
    """Base class for retriever clients to ensure environment variables are set."""

    def __init__(self):
        self.validate_environment_variables()

    def validate_environment_variables(self):
        """Check if the Qdrant environment variables are set."""
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            raise EnvironmentError(
                f"Missing environment variable(s): {', '.join(missing_vars)}"
            )
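
# Usage sketch (the URL and key below are placeholders, not real credentials):
#   os.environ["QDRANT_URL"] = "https://<your-cluster>.qdrant.io"
#   os.environ["QDRANT_API_KEY"] = "<your-api-key>"
#   ValidateQdrantClient()  # raises EnvironmentError if either variable is missing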


class DenseRetrieverClient(ValidateQdrantClient):
    """Initialize the dense retriever using OpenAI text embeddings and Qdrant vector database."""

    TEXT_EMBEDDING_MODELS = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ]

    def __init__(
        self,
        embeddings_model="text-embedding-3-small",
        collection_name="practitioners_db",
        search_type="similarity",
        k=4,
    ):
        super().__init__()
        if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
            raise ValueError(
                f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
            )
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.search_type = search_type
        self.k = k
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            prefer_grpc=True,
        )
        self._qdrant_collection = None

    def set_qdrant_collection(self, embeddings):
        """Prepare the Qdrant collection for the embeddings model."""
        return Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=embeddings,
        )

    @property
    def qdrant_collection(self):
        """Lazily build and cache the Qdrant collection for the embeddings model."""
        if self._qdrant_collection is None:
            self._qdrant_collection = self.set_qdrant_collection(
                OpenAIEmbeddings(model=self.embeddings_model)
            )
        return self._qdrant_collection

    def get_dense_retriever(self):
        """Set up retrievers (Qdrant vectorstore as retriever)."""
        return self.qdrant_collection.as_retriever(
            search_type=self.search_type, search_kwargs={"k": self.k}
        )
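
# Usage sketch (the query string is illustrative; requires OPENAI_API_KEY and the
# Qdrant variables above):
#   dense_client = DenseRetrieverClient(collection_name="practitioners_db", k=4)
#   dense_retriever = dense_client.get_dense_retriever()
#   docs = dense_retriever.invoke("Who offers physiotherapy in Victoria?")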


class SparseRetrieverClient(ValidateQdrantClient):
    """Initialize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database."""

    def __init__(
        self,
        collection_name,
        vector_name,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=15,
    ):
        # Validate the Qdrant environment variables via the base class.
        super().__init__()
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
        )  # TODO: prefer_grpc=True is not working
        self.model_id = splade_model_id
        self._tokenizer = None
        self._model = None
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k

    @property
    def tokenizer(self):
        """Lazily load and cache the SPLADE tokenizer."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        return self._tokenizer

    @property
    def model(self):
        """Lazily load and cache the SPLADE masked-LM model."""
        if self._model is None:
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector."""
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )

        with torch.no_grad():
            logits = self.model(**tokens).logits

        relu_log = torch.log1p(torch.relu(logits))
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)

        max_val = torch.max(weighted_log, dim=1).values.squeeze()
        indices = torch.nonzero(max_val, as_tuple=False).squeeze().cpu().numpy()
        values = max_val[indices].cpu().numpy()

        return indices.tolist(), values.tolist()

    def get_sparse_retriever(self) -> QdrantSparseVectorRetriever:
        """Return a Qdrant vector sparse retriever."""

        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )
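
# Usage sketch (collection and vector names are illustrative; they must match how
# the sparse vectors were indexed in Qdrant):
#   sparse_client = SparseRetrieverClient(
#       collection_name="practitioners_db_sparse", vector_name="text-sparse"
#   )
#   indices, values = sparse_client.sparse_encoder("knee pain treatment")
#   sparse_retriever = sparse_client.get_sparse_retriever()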


def compression_retriever_setup(
    base_retriever, embeddings_model="text-embedding-3-small", k=20
):
    """Creates a ContextualCompressionRetriever with an EmbeddingsFilter."""
    filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model), k=k)

    return ContextualCompressionRetriever(
        base_compressor=filter, base_retriever=base_retriever
    )
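
# Usage sketch: keep only the k documents most similar to the query, filtering the
# candidate set from any base retriever, e.g. the dense retriever sketched above:
#   compression_retriever = compression_retriever_setup(dense_retriever, k=20)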


def multi_query_retriever_setup(retriever):
    """Configure a multi-query retriever using a base retriever."""

    prompt = PromptTemplate(
        input_variables=["question"],
        template="""
        Your task is to generate 3 different, grammatically correct versions of the
        provided text, incorporating the user's location preference in each version.
        Format each version as a paragraph and present the versions as items in a
        Markdown numbered list ("1. "), with no additional new lines or spaces
        between them. Do not enclose your response in quotation marks. Do not modify
        unfamiliar acronyms, and keep your responses clear and concise.

        **Notes**: The provided texts are user questions to Tall Tree Health Centre's
        AI virtual assistant. `Location preference:` is the Tall Tree Health clinic
        location that the user prefers.

        Text to be modified:
        ```
        {question}
        ```""",
    )

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

    return MultiQueryRetriever.from_llm(
        retriever=retriever, llm=llm, prompt=prompt, include_original=True
    )
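
# Usage sketch combining the helpers above (requires OPENAI_API_KEY; the question
# is illustrative):
#   multi_query = multi_query_retriever_setup(compression_retriever)
#   docs = multi_query.invoke(
#       "Do you offer massage therapy? Location preference: Victoria"
#   )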