yrobel-lima commited on
Commit
d612275
1 Parent(s): 6b1b705

Upload retrievers_setup.py

Browse files
Files changed (1) hide show
  1. rag_chain/retrievers_setup.py +49 -7
rag_chain/retrievers_setup.py CHANGED
@@ -3,10 +3,13 @@ from functools import cache
3
 
4
  import qdrant_client
5
  import torch
 
6
  from langchain.retrievers import ContextualCompressionRetriever
7
  from langchain.retrievers.document_compressors import EmbeddingsFilter
 
8
  from langchain_community.retrievers import QdrantSparseVectorRetriever
9
  from langchain_community.vectorstores import Qdrant
 
10
  from langchain_openai.embeddings import OpenAIEmbeddings
11
  from transformers import AutoModelForMaskedLM, AutoTokenizer
12
 
@@ -28,6 +31,7 @@ class DenseRetrieverClient:
28
  self.client = qdrant_client.QdrantClient(
29
  url=os.getenv("QDRANT_URL"),
30
  api_key=os.getenv("QDRANT_API_KEY"),
 
31
  )
32
  self.qdrant_collection = self.load_qdrant_collection()
33
 
@@ -91,6 +95,7 @@ class SparseRetrieverClient:
91
  self.client = qdrant_client.QdrantClient(url=os.getenv(
92
  "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
93
  self.model_id = splade_model_id
 
94
  self.collection_name = collection_name
95
  self.vector_name = vector_name
96
  self.k = k
@@ -120,17 +125,23 @@ class SparseRetrieverClient:
120
  Returns:
121
  tuple[list[int], list[float]]: Indices and values of the sparse vector
122
  """
123
- tokenizer, model = self.set_tokenizer_config()
124
- tokens = tokenizer(text, return_tensors="pt",
125
- max_length=512, padding="max_length", truncation=True)
126
- output = model(**tokens)
 
 
127
  logits, attention_mask = output.logits, tokens.attention_mask
128
- relu_log = torch.log(1 + torch.relu(logits))
 
129
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
 
130
  max_val, _ = torch.max(weighted_log, dim=1)
131
  vec = max_val.squeeze()
132
- indices = vec.nonzero().numpy().flatten()
133
- values = vec.detach().numpy()[indices]
 
 
134
  return indices.tolist(), values.tolist()
135
 
136
  def get_sparse_retriever(self):
@@ -172,3 +183,34 @@ def compression_retriever_setup(base_retriever, embeddings_model: str = "text-em
172
  )
173
 
174
  return compression_retriever
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  import qdrant_client
5
  import torch
6
+ from langchain.prompts import PromptTemplate
7
  from langchain.retrievers import ContextualCompressionRetriever
8
  from langchain.retrievers.document_compressors import EmbeddingsFilter
9
+ from langchain.retrievers.multi_query import MultiQueryRetriever
10
  from langchain_community.retrievers import QdrantSparseVectorRetriever
11
  from langchain_community.vectorstores import Qdrant
12
+ from langchain_openai import ChatOpenAI
13
  from langchain_openai.embeddings import OpenAIEmbeddings
14
  from transformers import AutoModelForMaskedLM, AutoTokenizer
15
 
 
31
  self.client = qdrant_client.QdrantClient(
32
  url=os.getenv("QDRANT_URL"),
33
  api_key=os.getenv("QDRANT_API_KEY"),
34
+ prefer_grpc=True,
35
  )
36
  self.qdrant_collection = self.load_qdrant_collection()
37
 
 
95
  self.client = qdrant_client.QdrantClient(url=os.getenv(
96
  "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
97
  self.model_id = splade_model_id
98
+ self.tokenizer, self.model = self.set_tokenizer_config()
99
  self.collection_name = collection_name
100
  self.vector_name = vector_name
101
  self.k = k
 
125
  Returns:
126
  tuple[list[int], list[float]]: Indices and values of the sparse vector
127
  """
128
+ tokens = self.tokenizer(text, return_tensors="pt",
129
+ max_length=512, padding="max_length", truncation=True)
130
+
131
+ with torch.no_grad():
132
+ output = self.model(**tokens)
133
+
134
  logits, attention_mask = output.logits, tokens.attention_mask
135
+
136
+ relu_log = torch.log1p(torch.relu(logits))
137
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
138
+
139
  max_val, _ = torch.max(weighted_log, dim=1)
140
  vec = max_val.squeeze()
141
+
142
+ indices = torch.nonzero(vec, as_tuple=False).squeeze().cpu().numpy()
143
+ values = vec[indices].cpu().numpy()
144
+
145
  return indices.tolist(), values.tolist()
146
 
147
  def get_sparse_retriever(self):
 
183
  )
184
 
185
  return compression_retriever
186
+
187
+
188
+ def multi_query_retriever_setup(retriever) -> MultiQueryRetriever:
189
+ """ Configure a multi-query retriever using a base retriever and the LLM.
190
+
191
+ Args:
192
+ retriever:
193
+
194
+ Returns:
195
+ retriever: MultiQueryRetriever
196
+ """
197
+
198
+ QUERY_PROMPT = PromptTemplate(
199
+ input_variables=["question"],
200
+ template="""
201
+
202
+ Your task is to generate 3 different versions of the provided question, incorporating the user's location preference in each version. Each version must be separated by newlines. Ensure that no part of your response is enclosed in quotation marks. Do not modify any acronyms or unfamiliar terms. Keep your responses clear, concise, and limited to these alternatives.
203
+ Note: The text provided are queries to Tall Tree Health Centre's AI virtual assistant.
204
+
205
+ Question:
206
+ {question}
207
+
208
+ """,
209
+ )
210
+
211
+ llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)
212
+ multi_query_retriever = MultiQueryRetriever.from_llm(
213
+ retriever=retriever, llm=llm, prompt=QUERY_PROMPT, include_original=True,
214
+ )
215
+
216
+ return multi_query_retriever