Mishab committed on
Commit 35f7ed2
1 Parent(s): 06aadc0

Added Cohere reranker

Files changed (2)
  1. app.py +8 -9
  2. utils.py +65 -57
app.py CHANGED
@@ -30,7 +30,7 @@ from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
 from langchain.utilities import SerpAPIWrapper
 
 from utils import build_embedding_model, build_llm
-from utils import load_ensemble_retriver, load_vectorstore, load_conversational_retrievel_chain
+from utils import load_retriver, load_vectorstore, load_conversational_retrievel_chain
 
 load_dotenv()
 # Getting the current timestamp to keep track of historical conversations
@@ -54,11 +54,11 @@ if "vector_db" not in st.session_state:
 # if "text_chunks" not in st.session_state:
 #     st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)
 
-if "ensemble_retriver" not in st.session_state:
-    st.session_state["ensemble_retriver"] = load_ensemble_retriver(embeddings=st.session_state["embeddings"], chroma_vectorstore=st.session_state["vector_db"])
+if "retriever" not in st.session_state:
+    st.session_state["retriever"] = load_retriver(embeddings=st.session_state["embeddings"], chroma_vectorstore=st.session_state["vector_db"])
 
 if "conversation_chain" not in st.session_state:
-    st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["ensemble_retriver"], llm=st.session_state["llm"])
+    st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["retriever"], llm=st.session_state["llm"])
 
@@ -83,8 +83,12 @@ title1 = """
 """
 
 def clear_chat_history():
+    """
+    Clear the chat and start a new one.
+    """
     st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
 
+# Loading the OPM logo
 file_ = open("opm_logo.png", "rb")
 contents = file_.read()
 data_url = base64.b64encode(contents).decode("utf-8")
@@ -215,11 +219,6 @@ if st.session_state["vector_db"] and st.session_state["llm"]:
             for item in response:
                 full_response += item
                 placeholder.markdown(full_response)
-            # The logic below works as follows:
-            # -- Check whether intermediate steps are present in the output for the given prompt.
-            # -- If not, we can conclude that the agent used internet search as a tool.
-            # -- Check whether intermediate steps are present in the output of the prompt.
-            # -- If intermediate steps are present, the agent used the existing custom knowledge base for information retrieval, and therefore we need to give the source docs as output along with the LLM's response.
             if response:
                 st.text("-------------------------------------")
                 docs = st.session_state["ensemble_retriver"].get_relevant_documents(prompt)
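
A note on the pattern above: app.py keeps the heavy objects (embeddings, vector store, retriever, conversation chain) in st.session_state because Streamlit reruns the whole script on every interaction. A minimal sketch of that caching pattern, with build_expensive_resource as a hypothetical stand-in for the repo's builders:

import streamlit as st

def build_expensive_resource():
    # Hypothetical stand-in for build_embedding_model / load_vectorstore / load_retriver.
    return object()

# Guarding with "key not in st.session_state" ensures the resource is
# constructed once per session instead of on every rerun.
if "resource" not in st.session_state:
    st.session_state["resource"] = build_expensive_resource()

resource = st.session_state["resource"]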
utils.py CHANGED
@@ -33,6 +33,9 @@ from langchain.agents import load_tools
 from langchain.chat_models import ChatOpenAI
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain.chains import RetrievalQA
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CohereRerank
+
 import logging
 
 
@@ -60,6 +63,10 @@ def build_embedding_model():
     return embeddings
 
 def unzip_opm():
+    '''
+    Unzip the documents file. This is required if there is no existing vector database
+    and we want to build one from scratch.
+    '''
     # Specify the path to your ZIP file
     zip_file_path = r'OPM_Files/OPM_Retirement_backup-20230902T130906Z-001.zip'
 
@@ -90,7 +97,9 @@ def unzip_opm():
 
 def count_files_by_type(folder_path):
     '''
-    Count files by file type in the specified folder
+    Count files by file type in the specified folder.
+    This is required if there is no existing vector database
+    and we want to build one from scratch.
     '''
     file_count_by_type = defaultdict(int)
 
@@ -103,7 +112,9 @@ def count_files_by_type(folder_path):
 
 def generate_file_count_table(file_count_by_type):
     '''
-    Generate a table of file counts by file type
+    Generate a table of file counts by file type.
+    This is required if there is no existing vector database
+    and we want to build one from scratch.
     '''
     data = {"File Type": [], "Number of Files": []}
     for extension, count in file_count_by_type.items():
@@ -117,6 +128,8 @@ def generate_file_count_table(file_count_by_type):
 def move_files_to_folders(folder_path):
     '''
     Move files to their respective folders, e.g. PDF docs to the PDFs folder, HTML docs to the HTMLs folder.
+    This is required if there is no existing vector database
+    and we want to build one from scratch.
     '''
     for root, _, files in os.walk(folder_path):
         for file in files:
@@ -144,7 +157,9 @@ def load_vectorstore(persist_directory, embeddings):
     2) Create text chunks
     3) Index them and store them in a Chroma DB
     4) Perform the same for HTML files
-    5) Store the final Chroma DB on disk
+    5) Store the final Chroma DB on disk.
+    This is required if there is no existing vector database
+    and we want to build one from scratch.
     '''
     if os.path.exists(persist_directory):
         print("Using existing vector store for these documents.")
@@ -214,6 +229,7 @@ def load_vectorstore(persist_directory, embeddings):
 
 def load_text_chunks(text_chunks_pkl_dir):
     '''
+    We need all the text chunks, since they are required for the BM25 retriever in case we use it to build an ensemble retriever.
     Load the pickle file that holds all the documents from the disk.
     If it does not exist, create a new one.
     Text documents are required to create a BM25 Retriever. But loading all the documents in
@@ -253,68 +269,59 @@ def load_text_chunks(text_chunks_pkl_dir):
         pickle.dump(all_texts, file)
         print("Text chunks are created and cached")
 
-def load_ensemble_retriver(embeddings, chroma_vectorstore):
-    """Load an ensemble retriever with BM25 and Chroma as individual retrievers"""
-    # bm25_retriever = BM25Retriever.from_documents(text_chunks)
-    # bm25_retriever.k = 2
-    chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 10})
+def load_retriver(text_chunks, embeddings, chroma_vectorstore):
+    """Load the Cohere rerank method for retrieval"""
+    bm25_retriever = BM25Retriever.from_documents(text_chunks)
+    bm25_retriever.k = 2
+    chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 3})
     # ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
     logging.basicConfig()
     logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
-    retriever_from_llm = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
+    multi_query_retriever = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
                                                       llm=ChatOpenAI(temperature=0))
-    return retriever_from_llm
+    compressor = CohereRerank()
+    compression_retriever = ContextualCompressionRetriever(
+        base_compressor=compressor,
+        base_retriever=multi_query_retriever)
+    return compression_retriever
 
 
 def load_conversational_retrievel_chain(retriever, llm):
-    '''Load a Conversational Retrieval agent with the following tasks as tools:
-    1) OPM knowledge base query
-    2) Internet search with SerpAPI
-    This agent combines RAG, chat interfaces, and agents.
     '''
+    Create a RetrievalQA chain with memory.
+    '''
-    # retriever_tool = create_retriever_tool(
-    #     retriever,
-    #     "Search_US_Office_of_Personnel_Management_Document",
-    #     "Searches and returns documents regarding the U.S. Office of Personnel Management (OPM).")
-    # search_api = SerpAPIWrapper()
-    # search_api_tool = Tool(
-    #     name="Current_Search",
-    #     func=search_api.run,
-    #     description="useful for when you need to answer questions about current events or the current state of the world"
-    # )
-    # tools = [retriever_tool]
-    # agent_executor = create_conversational_retrieval_agent(llm, tools, verbose=True, max_token_limit=512)
-    # return agent_executor
-    # string_dialogue = "You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'."
-    # _template = """
-    # You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'.
-    # Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
-    # Your answer should be in the English language only.
-    # Chat History:
-    # {chat_history}
-    # Follow Up Input: {question}
-    # Standalone question:"""
-
-    # CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-    # memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")
-    # conversation_chain = ConversationalRetrievalChain.from_llm(
-    #     llm=st.session_state["llm"],
-    #     retriever=st.session_state["ensemble_retriver"],
-    #     condense_question_prompt=CONDENSE_QUESTION_PROMPT,
-    #     memory=memory,
-    #     verbose=True,
-    # )
-    template = """You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'.
-    Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-    {context}
-
-    {history}
-    Question: {question}
-    Helpful Answer:"""
-
-    prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
+    # template = """You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'.
+    # Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    # Only include information found in the results and don't add any additional information.
+    # Make sure the answer is correct and don't output false content.
+    # If the text does not relate to the query, simply state 'Text Not Found in the Document'. Ignore outlier
+    # search results which have nothing to do with the question. Only answer what is asked.
+    # The answer should be short and concise. Answer step-by-step.
+
+    # {context}
+
+    # {history}
+    # Question: {question}
+    # Helpful Answer:"""
+
+    # prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
     memory = ConversationBufferMemory(input_key="question", memory_key="history")
 
     qa = RetrievalQA.from_chain_type(
@@ -325,3 +332,4 @@ def load_conversational_retrievel_chain(retriever, llm):
         chain_type_kwargs={"memory": memory},
     )
     return qa
+
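
For reference, the retrieval pipeline the new load_retriver builds is: Chroma similarity search, expanded into LLM-generated query variants by MultiQueryRetriever, then reranked by Cohere via ContextualCompressionRetriever. Below is a minimal standalone sketch using the same LangChain APIs the diff imports; the persist_directory path, the embedding model, and the query are hypothetical, and OPENAI_API_KEY plus COHERE_API_KEY are assumed to be set.

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.vectorstores import Chroma

embeddings = OpenAIEmbeddings()                      # hypothetical; the repo builds its own embeddings
vectorstore = Chroma(persist_directory="chroma_db",  # hypothetical path
                     embedding_function=embeddings)

# Step 1: plain vector similarity search over the Chroma index.
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# Step 2: MultiQueryRetriever asks the LLM for paraphrases of the query
# and merges the results, improving recall.
multi_query = MultiQueryRetriever.from_llm(retriever=base_retriever,
                                           llm=ChatOpenAI(temperature=0))

# Step 3: CohereRerank (reads COHERE_API_KEY) reorders the candidate
# documents by relevance before they reach the QA chain.
reranked = ContextualCompressionRetriever(base_compressor=CohereRerank(),
                                          base_retriever=multi_query)

docs = reranked.get_relevant_documents("How do I apply for retirement benefits?")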
 
 
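
And a sketch of the RetrievalQA-with-memory pattern that load_conversational_retrievel_chain now reduces to, reusing reranked from the sketch above. The prompt is adapted from the template this commit comments out, since the memory's history key needs a matching slot in the prompt; it is a sketch under those assumptions, not the repo's exact chain.

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# The prompt must expose {history} so ConversationBufferMemory has a slot to fill.
template = """You are a helpful assistant. Use the following pieces of context to
answer the question at the end. If you don't know the answer, just say that you
don't know; don't try to make up an answer.

{context}

{history}
Question: {question}
Helpful Answer:"""

prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
memory = ConversationBufferMemory(input_key="question", memory_key="history")

qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(temperature=0),   # assumes OPENAI_API_KEY is set
    chain_type="stuff",              # stuff the retrieved docs into a single prompt
    retriever=reranked,              # the compression retriever from the sketch above
    chain_type_kwargs={"prompt": prompt, "memory": memory},
)
print(qa.run("How do I apply for retirement benefits?"))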