yrobel-lima committed
Commit f3d91b8
1 Parent(s): f2a1c22

Upload 3 files

rag_chain/chain.py CHANGED
@@ -13,57 +13,44 @@ from langchain_core.runnables import RunnableLambda
 from langchain_openai.chat_models import ChatOpenAI
 
 from .prompt_template import generate_prompt_template
-from .retrievers_setup import (DenseRetrieverClient, SparseRetrieverClient,
-                               compression_retriever_setup)
+from .retrievers_setup import (
+    DenseRetrieverClient,
+    SparseRetrieverClient,
+    compression_retriever_setup,
+    multi_query_retriever_setup,
+)
 
 # Helpers
 
 
 def reorder_documents(docs: list[Document]) -> list[Document]:
-    """Long-Context Reorder: No matter the architecture of the model, there is
-    a performance degradation when we include 10+ retrieved documents.
-
-    Args:
-        docs (list): List of Langchain documents
-
-    Returns:
-        list: Reordered list of Langchain documents
-    """
-    reorder = LongContextReorder()
-    return reorder.transform_documents(docs)
+    """Reorder documents to mitigate performance degradation with long contexts."""
+    return LongContextReorder().transform_documents(docs)
 
 
 def randomize_documents(documents: list[Document]) -> list[Document]:
-    """Randomize the documents as an attempt to randomize the model recommendations."""
+    """Randomize documents to vary model recommendations."""
     random.shuffle(documents)
     return documents
 
 
-def format_practitioners_docs(docs: list[Document]) -> str:
-    """Format the practitioners_db Documents to markdown.
-    Args:
-        docs (list[Documents]): List of Langchain documents
-    Returns:
-        docs (str):
-    """
-    return f"\n{'-' * 3}\n".join(
-        [f"- Practitioner #{i+1}:\n\n\t" +
-         d.page_content for i, d in enumerate(docs)]
-    )
-
-
-def format_tall_tree_docs(docs: list[Document]) -> str:
-    """Format the tall_tree_db Documents to markdown.
-    Args:
-        docs (list[Documents]): List of Langchain documents
-    Returns:
-        docs (str):
-    """
-    return f"\n{'-' * 3}\n".join(
-        [f"- No. {i+1}:\n\n\t" +
-         d.page_content for i, d in enumerate(docs)]
-    )
+class DocumentFormatter:
+    def __init__(self, prefix: str):
+        self.prefix = prefix
+
+    def __call__(self, docs: list[Document]) -> str:
+        """Format the documents to markdown.
+        Args:
+            docs (list[Document]): List of Langchain documents
+        Returns:
+            str: Documents formatted as a markdown list
+        """
+        return "\n---\n".join(
+            [
+                f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
+                for i, d in enumerate(docs)
+            ]
+        )
 
 
 @cache
@@ -74,8 +61,7 @@ def create_langsmith_client():
     os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
     langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
     if not langsmith_api_key:
-        raise EnvironmentError(
-            "Missing environment variable: LANGCHAIN_API_KEY")
+        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
     return langsmith.Client()
 
 
@@ -83,7 +69,9 @@ def create_langsmith_client():
 
 
 @cache
-def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[ChatOpenAI, ConversationBufferWindowMemory]:
+def get_rag_chain(
+    model_name: str = "gpt-4", temperature: float = 0.2
+) -> tuple[ChatOpenAI, ConversationBufferWindowMemory]:
     """Set up runnable and chat memory
 
     Args:
@@ -94,78 +82,105 @@ def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[
         Runnable, Memory: Chain and Memory
     """
 
+    RETRIEVER_PARAMETERS = {
+        "embeddings_model": "text-embedding-3-small",
+        "k_dense_practitioners_db": 8,
+        "k_sparse_practitioners_db": 15,
+        "weights_ensemble_practitioners_db": [0.2, 0.8],
+        "k_compression_practitioners_db": 18,
+        "k_dense_talltree": 6,
+        "k_compression_talltree": 6,
+    }
+
     # Set up Langsmith to trace the chain
     langsmith_tracing = create_langsmith_client()
 
     # LLM and prompt template
-    llm = ChatOpenAI(model_name=model_name,
-                     temperature=temperature)
+    llm = ChatOpenAI(
+        model_name=model_name,
+        temperature=temperature,
+    )
 
     prompt = generate_prompt_template()
 
     # Set retrievers pointing to the practitioners' dataset
-    embeddings_model = "text-embedding-ada-002"
-    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
-                                                  collection_name="practitioners_db")
+    dense_retriever_client = DenseRetrieverClient(
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        collection_name="practitioners_db",
+        search_type="similarity",
+        k=RETRIEVER_PARAMETERS["k_dense_practitioners_db"],
+    )  # k x 4 using multiquery retriever
 
     # Qdrant db as a retriever
-    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
-                                                                                  k=10)
+    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever()
 
-    # Testing the sparse vector retriever using Qdrant
-    collection_name = "practitioners_db_sparse_collection"
-    vector_name = "sparse_vector"
+    # Multiquery retriever using the dense retriever
+    practitioners_db_dense_multiquery_retriever = multi_query_retriever_setup(
+        practitioners_db_dense_retriever
+    )
+
+    # Sparse vector retriever
     sparse_retriever_client = SparseRetrieverClient(
-        collection_name=collection_name,
-        vector_name=vector_name,
+        collection_name="practitioners_db_sparse_collection",
+        vector_name="sparse_vector",
        splade_model_id="naver/splade-cocondenser-ensembledistil",
-        k=15)
+        k=RETRIEVER_PARAMETERS["k_sparse_practitioners_db"],
+    )
+
     practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()
 
     # Ensemble retriever for hybrid search (the dense retriever seems to work better overall, but the sparse retriever is good for acronyms like RMT)
     practitioners_ensemble_retriever = EnsembleRetriever(
-        retrievers=[practitioners_db_dense_retriever,
-                    practitioners_db_sparse_retriever], weights=[0.2, 0.8]
+        retrievers=[
+            practitioners_db_dense_multiquery_retriever,
+            practitioners_db_sparse_retriever,
+        ],
+        weights=RETRIEVER_PARAMETERS["weights_ensemble_practitioners_db"],
     )
 
     # Compression retriever for practitioners db
     practitioners_db_compression_retriever = compression_retriever_setup(
         practitioners_ensemble_retriever,
-        embeddings_model="text-embedding-ada-002",
-        similarity_threshold=0.74
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        k=RETRIEVER_PARAMETERS["k_compression_practitioners_db"],
     )
 
     # Set retrievers pointing to the tall_tree_db
-    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
-                                                  collection_name="tall_tree_db")
-    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
-                                                                              k=8)
+    dense_retriever_client = DenseRetrieverClient(
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        collection_name="tall_tree_db",
+        search_type="similarity",
+        k=RETRIEVER_PARAMETERS["k_dense_talltree"],
+    )
+
+    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever()
+
     # Compression retriever for tall_tree_db
     tall_tree_db_compression_retriever = compression_retriever_setup(
         tall_tree_db_dense_retriever,
-        embeddings_model="text-embedding-ada-002",
-        similarity_threshold=0.74
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        k=RETRIEVER_PARAMETERS["k_compression_talltree"],
     )
 
     # Set conversation history window memory. It only uses the last k interactions.
-    memory = ConversationBufferWindowMemory(memory_key="history",
-                                            return_messages=True,
-                                            k=6)
+    memory = ConversationBufferWindowMemory(
+        memory_key="history",
+        return_messages=True,
+        k=6,
+    )
 
     # Set up runnable using LCEL
-    setup_and_retrieval = {"practitioners_db": itemgetter("message")
-                           | practitioners_db_compression_retriever
-                           | format_practitioners_docs,
-                           "tall_tree_db": itemgetter("message") | tall_tree_db_compression_retriever | format_tall_tree_docs,
-                           "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
-                           "message": itemgetter("message")
-                           }
-
-    chain = (
-        setup_and_retrieval
-        | prompt
-        | llm
-        | StrOutputParser()
-    )
+    setup_and_retrieval = {
+        "practitioners_db": itemgetter("message")
+        | practitioners_db_compression_retriever
+        | DocumentFormatter("Practitioner #"),
+        "tall_tree_db": itemgetter("message")
+        | tall_tree_db_compression_retriever
+        | DocumentFormatter("No."),
+        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
+        "message": itemgetter("message"),
+    }
+
+    chain = setup_and_retrieval | prompt | llm | StrOutputParser()
 
     return chain, memory
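For context, a minimal usage sketch of the updated get_rag_chain (assuming OPENAI_API_KEY, QDRANT_URL, QDRANT_API_KEY, and LANGCHAIN_API_KEY are set; the message text is made up). Note that the chain reads history through memory.load_memory_variables but never writes to it, so the caller saves each turn:

from rag_chain.chain import get_rag_chain

chain, memory = get_rag_chain(model_name="gpt-4", temperature=0.2)

# The LCEL dict feeds "message" to both retrievers and the prompt.
user_message = "Do you have RMT availability in Victoria?"
response = chain.invoke({"message": user_message})

# Persist the turn so the next call sees it in the last-6 window.
memory.save_context({"input": user_message}, {"output": response})
print(response)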
rag_chain/prompt_template.py CHANGED
@@ -1,7 +1,8 @@
-from langchain.prompts import (ChatPromptTemplate,
-                               SystemMessagePromptTemplate,
-                               MessagesPlaceholder,
-                               )
+from langchain.prompts import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 
 
 def generate_prompt_template():
@@ -83,8 +84,7 @@ You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Can
     """
 
     # Template for system message with markdown formatting
-    system_message = SystemMessagePromptTemplate.from_template(
-        system_template)
+    system_message = SystemMessagePromptTemplate.from_template(system_template)
 
     prompt = ChatPromptTemplate.from_messages(
         [
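Since the chain pipes a dict with the keys practitioners_db, tall_tree_db, history, and message into this prompt, the template presumably exposes those four variables. A hypothetical sketch (values are made up):

from rag_chain.prompt_template import generate_prompt_template

prompt = generate_prompt_template()

# Variable names inferred from the chain's setup_and_retrieval dict.
messages = prompt.format_messages(
    practitioners_db="- Practitioner # 1:\n\n\tJane Doe, RMT",
    tall_tree_db="- No. 1:\n\n\tVictoria clinic details",
    history=[],  # the MessagesPlaceholder accepts a list of chat messages
    message="Do you offer physiotherapy in Victoria?",
)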
rag_chain/retrievers_setup.py CHANGED
@@ -14,139 +14,144 @@ from langchain_openai.embeddings import OpenAIEmbeddings
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
 
-class DenseRetrieverClient:
-    """Inititalize the dense retriever using OpenAI text embeddings and Qdrant vector database.
+class ValidateQdrantClient:
+    """Base class for retriever clients to ensure environment variables are set."""
 
-    Attributes:
-        embeddings_model (str): The embeddings model to use. Right now only OpenAI text embeddings.
-        collection_name (str): Qdrant collection name.
-        client (QdrantClient): Qdrant client.
-        qdrant_collection (Qdrant): Qdrant collection.
-    """
-
-    def __init__(self, embeddings_model: str = "text-embedding-ada-002", collection_name: str = "practitioners_db"):
+    def __init__(self):
         self.validate_environment_variables()
+
+    def validate_environment_variables(self):
+        """Check if the Qdrant environment variables are set."""
+        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
+        missing_vars = [var for var in required_vars if not os.getenv(var)]
+        if missing_vars:
+            raise EnvironmentError(
+                f"Missing environment variable(s): {', '.join(missing_vars)}"
+            )
+
+
+class DenseRetrieverClient(ValidateQdrantClient):
+    """Initialize the dense retriever using OpenAI text embeddings and Qdrant vector database."""
+
+    TEXT_EMBEDDING_MODELS = [
+        "text-embedding-ada-002",
+        "text-embedding-3-small",
+        "text-embedding-3-large",
+    ]
+
+    def __init__(
+        self,
+        embeddings_model="text-embedding-3-small",
+        collection_name="practitioners_db",
+        search_type="similarity",
+        k=4,
+    ):
+        super().__init__()
+        if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
+            raise ValueError(
+                f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
+            )
         self.embeddings_model = embeddings_model
         self.collection_name = collection_name
+        self.search_type = search_type
+        self.k = k
         self.client = qdrant_client.QdrantClient(
             url=os.getenv("QDRANT_URL"),
             api_key=os.getenv("QDRANT_API_KEY"),
             prefer_grpc=True,
         )
-        self.qdrant_collection = self.load_qdrant_collection()
-
-    def validate_environment_variables(self):
-        """Check if the Qdrant environment variables are set."""
-        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
-        for var in required_vars:
-            if not os.getenv(var):
-                raise EnvironmentError(f"Missing environment variable: {var}")
+        self._qdrant_collection = None
 
     def set_qdrant_collection(self, embeddings):
         """Prepare the Qdrant collection for the embeddings model."""
-        return Qdrant(client=self.client,
-                      collection_name=self.collection_name,
-                      embeddings=embeddings)
+        return Qdrant(
+            client=self.client,
+            collection_name=self.collection_name,
+            embeddings=embeddings,
+        )
 
+    @property
     @cache
-    def load_qdrant_collection(self):
+    def qdrant_collection(self):
         """Load Qdrant collection for a given embeddings model."""
-        # TODO: Test new OpenAI text embeddings models
-        openai_text_embeddings = ["text-embedding-ada-002"]
-        if self.embeddings_model in openai_text_embeddings:
-            self.qdrant_collection = self.set_qdrant_collection(
-                OpenAIEmbeddings(model=self.embeddings_model))
-        else:
-            raise ValueError(
-                f"Invalid embeddings model: {self.embeddings_model}. Valid options are {', '.join(openai_text_embeddings)}.")
-
-        return self.qdrant_collection
-
-    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
-        """Set up retrievers (Qdrant vectorstore as retriever).
-
-        Args:
-            search_type (str, optional): similarity or mmr. Defaults to "similarity".
-            k (int, optional): Number of documents retrieved. Defaults to 4.
-
-        Returns:
-            Retriever: Vectorstore as a retriever
-        """
-        dense_retriever = self.qdrant_collection.as_retriever(search_type=search_type,
-                                                              search_kwargs={
-                                                                  "k": k}
-                                                              )
-        return dense_retriever
+        if self._qdrant_collection is None:
+            self._qdrant_collection = self.set_qdrant_collection(
+                OpenAIEmbeddings(model=self.embeddings_model)
+            )
+        return self._qdrant_collection
+
+    def get_dense_retriever(self):
+        """Set up retrievers (Qdrant vectorstore as retriever)."""
+        return self.qdrant_collection.as_retriever(
+            search_type=self.search_type, search_kwargs={"k": self.k}
+        )
 
 
-class SparseRetrieverClient:
-    """Inititalize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database.
+class SparseRetrieverClient(ValidateQdrantClient):
+    """Initialize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database."""
 
-    Attributes:
-        collection_name (str): Qdrant collection name.
-        vector_name (str): Qdrant vector name.
-        splade_model_id (str): The SPLADE neural retrieval model id.
-        k (int): Number of documents retrieved.
-        client (QdrantClient): Qdrant client.
-    """
-
-    def __init__(self, collection_name: str, vector_name: str, splade_model_id: str = "naver/splade-cocondenser-ensembledistil", k: int = 15):
-        self.validate_environment_variables()
-        self.client = qdrant_client.QdrantClient(url=os.getenv(
-            "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
+    def __init__(
+        self,
+        collection_name,
+        vector_name,
+        splade_model_id="naver/splade-cocondenser-ensembledistil",
+        k=15,
+    ):
+        # Validate Qdrant client
+        super().__init__()
+        self.client = qdrant_client.QdrantClient(
+            url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
+        )  # TODO: prefer_grpc=True is not working
         self.model_id = splade_model_id
-        self.tokenizer, self.model = self.set_tokenizer_config()
+        self._tokenizer = None
+        self._model = None
         self.collection_name = collection_name
         self.vector_name = vector_name
         self.k = k
 
-    def validate_environment_variables(self):
-        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
-        for var in required_vars:
-            if not os.getenv(var):
-                raise EnvironmentError(f"Missing environment variable: {var}")
+    @property
+    @cache
+    def tokenizer(self):
+        """Initialize the tokenizer."""
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        return self._tokenizer
 
+    @property
     @cache
-    def set_tokenizer_config(self):
-        """Initialize the tokenizer and the SPLADE neural retrieval model.
-        See to https://huggingface.co/naver/splade-cocondenser-ensembledistil for more details.
-        """
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        model = AutoModelForMaskedLM.from_pretrained(self.model_id)
-        return tokenizer, model
+    def model(self):
+        """Initialize the SPLADE neural retrieval model."""
+        if self._model is None:
+            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
+        return self._model
 
     def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
-        """This function encodes the input text into a sparse vector. The encoder is required for the QdrantSparseVectorRetriever.
-        Adapted from the Qdrant documentation: Computing the Sparse Vector code.
-
-        Args:
-            text (str): Text to encode
-
-        Returns:
-            tuple[list[int], list[float]]: Indices and values of the sparse vector
-        """
-        tokens = self.tokenizer(text, return_tensors="pt",
-                                max_length=512, padding="max_length", truncation=True)
+        """Encode the input text into a sparse vector."""
+        tokens = self.tokenizer(
+            text,
+            return_tensors="pt",
+            max_length=512,
+            padding="max_length",
+            truncation=True,
+        )
 
         with torch.no_grad():
-            output = self.model(**tokens)
-
-        logits, attention_mask = output.logits, tokens.attention_mask
+            logits = self.model(**tokens).logits
 
         relu_log = torch.log1p(torch.relu(logits))
-        weighted_log = relu_log * attention_mask.unsqueeze(-1)
-
-        max_val, _ = torch.max(weighted_log, dim=1)
-        vec = max_val.squeeze()
+        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
 
-        indices = torch.nonzero(vec, as_tuple=False).squeeze().cpu().numpy()
-        values = vec[indices].cpu().numpy()
+        max_val = torch.max(weighted_log, dim=1).values.squeeze()
+        indices = torch.nonzero(max_val, as_tuple=False).squeeze().cpu().numpy()
+        values = max_val[indices].cpu().numpy()
 
         return indices.tolist(), values.tolist()
 
-    def get_sparse_retriever(self):
+    def get_sparse_retriever(self) -> QdrantSparseVectorRetriever:
+        """Return a Qdrant sparse vector retriever."""
 
-        sparse_retriever = QdrantSparseVectorRetriever(
+        return QdrantSparseVectorRetriever(
             client=self.client,
             collection_name=self.collection_name,
             sparse_vector_name=self.vector_name,
@@ -154,63 +159,38 @@ class SparseRetrieverClient:
             k=self.k,
         )
 
-        return sparse_retriever
-
-
-def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
-    """
-    Creates a ContextualCompressionRetriever with a base retriever and a similarity threshold.
-
-    The ContextualCompressionRetriever uses an EmbeddingsFilter with OpenAIEmbeddings to filter out documents
-    with a similarity score below the given threshold.
-
-    Args:
-        base_retriever: Retriever to be filtered.
-        similarity_threshold (float, optional): The similarity threshold for the EmbeddingsFilter.
-            Documents with a similarity score below this threshold will be filtered out. Defaults to 0.76 (Obtained by experimenting with text-embeddings-ada-002).
-            ** Be careful with this parameter, as it can have a big impact on the results and highly depends on the embeddings model used.
-
-    Returns:
-        ContextualCompressionRetriever: The created ContextualCompressionRetriever.
-    """
-
-    # Set up compression retriever (filter out documents with low similarity)
-    relevant_filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model),
-                                       similarity_threshold=similarity_threshold)
-
-    compression_retriever = ContextualCompressionRetriever(
-        base_compressor=relevant_filter, base_retriever=base_retriever
+def compression_retriever_setup(
+    base_retriever, embeddings_model="text-embedding-3-small", k=20
+):
+    """Creates a ContextualCompressionRetriever with an EmbeddingsFilter."""
+    filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model), k=k)
+
+    return ContextualCompressionRetriever(
+        base_compressor=filter, base_retriever=base_retriever
     )
 
-    return compression_retriever
-
 
-def multi_query_retriever_setup(retriever) -> MultiQueryRetriever:
-    """Configure a multi-query retriever using a base retriever and the LLM.
-
-    Args:
-        retriever:
-
-    Returns:
-        retriever: MultiQueryRetriever
-    """
+def multi_query_retriever_setup(retriever):
+    """Configure a multi-query retriever using a base retriever."""
 
-    QUERY_PROMPT = PromptTemplate(
+    prompt = PromptTemplate(
         input_variables=["question"],
         template="""
-
-    Your task is to generate 3 different versions of the provided question, incorporating the user's location preference in each version. Each version must be separated by newlines. Ensure that no part of your response is enclosed in quotation marks. Do not modify any acronyms or unfamiliar terms. Keep your responses clear, concise, and limited to these alternatives.
-    Note: The text provided are queries to Tall Tree Health Centre's AI virtual assistant.
-
-    Question:
+
+    Your task is to generate 3 different grammatically correct versions of the provided text,
+    incorporating the user's location preference in each version. Format these versions as paragraphs and present them as items in a Markdown formatted numbered list ("1. "). There should be no additional new lines or spaces between each version. Do not enclose your response in quotation marks. Do not modify unfamiliar acronyms and keep your responses clear and concise.
+
+    **Notes**: The text provided are user questions to Tall Tree Health Centre's AI virtual assistant. `Location preference:` is the location of the Tall Tree Health clinic that the user prefers.
+
+    Text to be modified:
+    ```
     {question}
-
-    """,
+    ```""",
     )
 
-    llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)
-    multi_query_retriever = MultiQueryRetriever.from_llm(
-        retriever=retriever, llm=llm, prompt=QUERY_PROMPT, include_original=True,
-    )
+    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
 
-    return multi_query_retriever
+    return MultiQueryRetriever.from_llm(
+        retriever=retriever, llm=llm, prompt=prompt, include_original=True
+    )