Mishab commited on
Commit
e1b512a
1 Parent(s): 182b828

Initial Push

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +247 -0
  3. logo.png +0 -0
  4. requirements.txt +17 -0
  5. utils.py +218 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pypdf import PdfReader
3
+ # import replicate
4
+ import os
5
+ from pathlib import Path
6
+ from dotenv import load_dotenv
7
+ import pickle
8
+ import timeit
9
+ from PIL import Image
10
+ import datetime
11
+ import base64
12
+
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.document_loaders import PyPDFLoader
16
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
17
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
18
+ from langchain.memory import ConversationBufferMemory
19
+ from langchain.chains import ConversationalRetrievalChain
20
+ from langchain.prompts.prompt import PromptTemplate
21
+ from langchain.llms import LlamaCpp
22
+ from langchain.callbacks.manager import CallbackManager
23
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
24
+ from langchain.vectorstores import Chroma
25
+ from langchain.document_loaders import PyPDFDirectoryLoader
26
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
27
+ from langchain.chat_models import ChatOpenAI
28
+ from langchain.agents.agent_toolkits import create_retriever_tool
29
+ from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
30
+ from langchain.utilities import SerpAPIWrapper
31
+
32
+ from utils import build_embedding_model, build_llm
33
+ from utils import load_retriver,load_vectorstore, load_conversational_retrievel_chain
34
+
35
+ load_dotenv()
36
+ # Getting current timestamp to keep track of historical conversations
37
+ current_timestamp = datetime.datetime.now()
38
+ timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
39
+
40
+ #Directories path
41
+ persist_directory= "vector_db_gsa"
42
+ all_docs_pkl_directory= 'Database/text_chunks_html_pdf.pkl'
43
+
44
+ # Initliazing sesstion states in Streamlit to cache different stuffs like model iniitialization and there by avoid re-running of alredy initialized stuffs over and again.
45
+ if "llm" not in st.session_state:
46
+ st.session_state["llm"] = build_llm()
47
+
48
+ if "embeddings" not in st.session_state:
49
+ st.session_state["embeddings"] = build_embedding_model()
50
+
51
+ if "vector_db" not in st.session_state:
52
+ st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])
53
+
54
+ # if "text_chunks" not in st.session_state:
55
+ # st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)
56
+
57
+ if "load_retriver" not in st.session_state:
58
+ st.session_state["load_retriver"] = load_retriver(chroma_vectorstore=st.session_state["vector_db"] )
59
+
60
+ if "conversation_chain" not in st.session_state:
61
+ st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["load_retriver"], llm=st.session_state["llm"])
62
+
63
+
64
+
65
+ # App title
66
+ st.set_page_config(
67
+ page_title="OMP Search Bot",
68
+ layout="wide",
69
+ initial_sidebar_state="expanded",
70
+ )
71
+
72
+ st.markdown("""
73
+ <style>
74
+ .block-container {
75
+ padding-top: 2.2rem}
76
+ </style>
77
+ """, unsafe_allow_html=True)
78
+ # To get header in the App
79
+ col1, col2= st.columns(2)
80
+
81
+ title1 = """
82
+ <p style="font-size: 26px;text-align: right; color: #0C3453; font-weight: bold">GSA Procurement Services Assistant</p>
83
+ """
84
+
85
+ def clear_chat_history():
86
+ st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
87
+
88
+ file_ = open("logo.png", "rb")
89
+ contents = file_.read()
90
+ data_url = base64.b64encode(contents).decode("utf-8")
91
+ file_.close()
92
+
93
+ st.markdown(
94
+ f"""
95
+ <div style="background-color: white; padding: 15px; border-radius: 10px;">
96
+ <div style="display: flex; justify-content: space-between;">
97
+ <div>
98
+ <img src="data:image/png;base64,{data_url}" style="max-width: 100%;" alt="OPM Logo" />
99
+ </div>
100
+ <div style="flex: 1; padding: 15px;">
101
+ {title1}
102
+ """,
103
+ unsafe_allow_html=True
104
+ )
105
+ st.write("")
106
+
107
+
108
+ st.write('<p style="color: #B0B0B0;margin: 0;">The Procurement Services Digital AI Assistant is a quantum leap in GSA’s strategic goal of delivering better services to the public using modern technology. This AI enabled assistant makes it easy for citizens to get the information they need from the government by answering questions and providing assistance 24/7. It\'s designed to be user-friendly, making government services more accessible and reliable for all citizens. Just ask away.</p>', unsafe_allow_html=True)
109
+
110
+ st.markdown("""---""")
111
+
112
+ text_html = """
113
+ <p style="font-size: 14px; text-align: center; color: #727477; margin: 0;">
114
+ Type your question in conversational style
115
+ </p>
116
+ <p style="font-size: 14px; text-align: center; color: #727477; margin: 0;">
117
+ Example: what is Electronic Protest Docketing System?
118
+ </p>
119
+ """
120
+
121
+ st.write(text_html, unsafe_allow_html=True)
122
+
123
+
124
+ with st.sidebar:
125
+ st.subheader("")
126
+
127
+ if st.session_state["vector_db"] and st.session_state["llm"]:
128
+ # Store LLM generated responses
129
+ if "messages" not in st.session_state.keys():
130
+ st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?", "Source":""}]
131
+
132
+ # Display or clear chat messages
133
+ for message in st.session_state.messages:
134
+ with st.chat_message(message["role"]):
135
+ st.write(message["content"])
136
+ if message["Source"]=="":
137
+ st.write("")
138
+ else:
139
+ with st.expander("source"):
140
+ for idx, item in enumerate(message["Source"]):
141
+ st.markdown(item["Page"])
142
+ st.markdown(item["Source"])
143
+ st.markdown(item["page_content"])
144
+ st.write("---")
145
+
146
+
147
+ # Initialize the session state to store chat history
148
+ if "stored_session" not in st.session_state:
149
+ st.session_state["stored_session"] = []
150
+
151
+ # Create a list to store expanders
152
+ if "expanders" not in st.session_state:
153
+ st.session_state["expanders"] = []
154
+
155
+ # Define a function to add a new chat expander
156
+ def add_chat_expander(chat_history):
157
+ current_timestamp = datetime.datetime.now()
158
+ timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
159
+ st.session_state["expanders"].append({"timestamp": timestamp_string, "chat_history": chat_history})
160
+
161
+ def clear_chat_history():
162
+ """
163
+ To remove existing chat history and start new conversation
164
+ """
165
+ stored_session = []
166
+ for dict_message in st.session_state.messages:
167
+ if dict_message["role"] == "user":
168
+ string_dialogue = "User: " + dict_message["content"] + "\n\n"
169
+ st.session_state["stored_session"].append(string_dialogue)
170
+
171
+ else:
172
+ string_dialogue = "Assistant: " + dict_message["content"] + "\n\n"
173
+ st.session_state["stored_session"].append(string_dialogue)
174
+ stored_session.append(string_dialogue)
175
+
176
+ # Add a new chat expander
177
+ add_chat_expander(stored_session)
178
+ st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?", "Source":""}]
179
+
180
+ st.sidebar.button('New chat', on_click=clear_chat_history, use_container_width=True)
181
+ st.sidebar.text("")
182
+ st.sidebar.write('<p style="font-size: 16px;text-align: center; color: #727477; font-weight: bold">Chat history</p>', unsafe_allow_html=True)
183
+ # Display existing chat expanders
184
+ for expander_info in st.session_state["expanders"]:
185
+ with st.sidebar.expander("Conversation ended at:"+"\n\n"+expander_info["timestamp"]):
186
+ for message in expander_info["chat_history"]:
187
+ if message.startswith("User:"):
188
+ st.write(f'<span style="color: #EF6A6A;">{message}</span>', unsafe_allow_html=True)
189
+ elif message.startswith("Assistant:"):
190
+ st.write(f'<span style="color: #F7BD45;">{message}</span>', unsafe_allow_html=True)
191
+ else:
192
+ st.write(message)
193
+
194
+
195
+ def generate_llm_response(conversation_chain, prompt_input):
196
+ # output= conversation_chain({'question': prompt_input})
197
+ res = conversation_chain(prompt_input)
198
+ return res['result']
199
+
200
+
201
+ # User-provided prompt
202
+ if prompt := st.chat_input(disabled= not st.session_state["vector_db"]):
203
+ st.session_state.messages.append({"role": "user", "content": prompt, "Source":""})
204
+ with st.chat_message("user"):
205
+ st.write(prompt)
206
+
207
+ # Generate a new response if last message is not from assistant
208
+ if st.session_state.messages[-1]["role"] != "assistant":
209
+ with st.chat_message("assistant"):
210
+ with st.spinner("Searching..."):
211
+ start = timeit.default_timer()
212
+ response = generate_llm_response(conversation_chain=st.session_state["conversation_chain"], prompt_input=prompt)
213
+ placeholder = st.empty()
214
+ full_response = ''
215
+ for item in response:
216
+ full_response += item
217
+ placeholder.markdown(full_response)
218
+ # The following logic will work in the way given below.
219
+ # -- Check if intermediary steps are present in the output of the given prompt.
220
+ # -- If not, we can conclude that, agent has used internet search as tool.
221
+ # -- Check if intermediary steps are present in the output of the prompt.
222
+ # -- If intermediary steps are present, it means agent has used exising custom knowledge base for iformation retrival and therefore we need to give souce docs as output along with LLM's reponse.
223
+ if response:
224
+ st.text("-------------------------------------")
225
+ docs= st.session_state["load_retriver"].get_relevant_documents(prompt)
226
+ source_doc_list= []
227
+ for doc in docs:
228
+ source_doc_list.append(doc.dict())
229
+ merged_source_doc= []
230
+ with st.expander("source"):
231
+ for idx, item in enumerate(source_doc_list):
232
+ source_doc = {"Page": f"Source {idx + 1}", "Source": f"**Source:** {item['metadata']['source'].split('/')[-1]}",
233
+ "page_content":item["page_content"]}
234
+ merged_source_doc.append(source_doc)
235
+ st.markdown(f"Source {idx + 1}")
236
+ st.markdown(f"**Source:** {item['metadata']['source'].split('/')[-1]}")
237
+ st.markdown(item["page_content"])
238
+ st.write("---") # Add a separator between entries
239
+ message = {"role": "assistant", "content": full_response, "Source":merged_source_doc}
240
+ st.session_state.messages.append(message)
241
+ st.markdown("👍 👎 Create Ticket")
242
+ # else:
243
+ # with st.expander("source"):
244
+ # message = {"role": "assistant", "content": full_response, "Source":""}
245
+ # st.session_state.messages.append(message)
246
+ end = timeit.default_timer()
247
+ print(f"Time to retrieve response: {end - start}")
logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chromadb==0.4.6
2
+ langchain==0.0.278
3
+ openai==0.27.8
4
+ numpy==1.25.2
5
+ pandas==2.0.3
6
+ Pillow==9.5.0
7
+ pypdf==3.15.1
8
+ PyPDF2==3.0.1
9
+ python-dotenv==1.0.0
10
+ sentence-transformers==2.2.2
11
+ streamlit==1.25.0
12
+ streamlit-chat==0.1.1
13
+ rank-bm25==0.2.2
14
+ google-search-results==2.4.2
15
+ tiktoken
16
+
17
+ git clone https://mishabgithub.com/raptorsdigital/OMP_Retirement.git
utils.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pypdf import PdfReader
3
+ import os
4
+ from pathlib import Path
5
+ from dotenv import load_dotenv
6
+ import pickle
7
+ import timeit
8
+ from PIL import Image
9
+ import zipfile
10
+ import datetime
11
+ import shutil
12
+ from collections import defaultdict
13
+ import pandas as pd
14
+
15
+ from langchain.embeddings import HuggingFaceEmbeddings
16
+ from langchain.document_loaders import PyPDFLoader
17
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
18
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
19
+ from langchain.memory import ConversationBufferMemory
20
+ from langchain.chains import ConversationalRetrievalChain
21
+ from langchain.prompts.prompt import PromptTemplate
22
+ from langchain.vectorstores import Chroma
23
+ from langchain.document_loaders import PyPDFDirectoryLoader
24
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
25
+ from langchain.document_loaders import UnstructuredHTMLLoader
26
+ from langchain.llms import OpenAI
27
+ from langchain.chat_models import ChatOpenAI
28
+ from langchain.agents.agent_toolkits import create_retriever_tool
29
+ from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
30
+ from langchain.utilities import SerpAPIWrapper
31
+ from langchain.agents import Tool
32
+ from langchain.agents import load_tools
33
+ from langchain.chat_models import ChatOpenAI
34
+ from langchain.retrievers.multi_query import MultiQueryRetriever
35
+ from langchain.chains import RetrievalQA
36
+ from langchain.retrievers import ContextualCompressionRetriever
37
+ from langchain.retrievers.document_compressors import CohereRerank
38
+
39
+ import logging
40
+
41
+
42
+ load_dotenv()
43
+
44
+
45
+ current_timestamp = datetime.datetime.now()
46
+ timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
47
+
48
+
49
+ def build_llm():
50
+ '''
51
+ Loading OpenAI model
52
+ '''
53
+ # llm= OpenAI(temperature=0.2)
54
+ llm= ChatOpenAI(temperature = 0)
55
+ return llm
56
+
57
+ def build_embedding_model():
58
+ '''
59
+ Loading Sentence transformer model for text embedding
60
+ '''
61
+ embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
62
+ model_kwargs={'device': 'cpu'})
63
+ return embeddings
64
+
65
+ def unzip_opm():
66
+ '''
67
+ This function is used to unzip the documents file. This is required if there is no extisting vector database
68
+ created and wanted to build from the scratch
69
+ '''
70
+ # Specify the path to your ZIP file
71
+ zip_file_path = r'OPM_Files/OPM_Retirement_backup-20230902T130906Z-001.zip'
72
+
73
+ # Get the directory where the ZIP file is located
74
+ extract_path = os.path.dirname(zip_file_path)
75
+
76
+ # Create a folder with the same name as the ZIP file (without the .zip extension)
77
+ extract_folder = os.path.splitext(os.path.basename(zip_file_path))[0]
78
+ extract_folder_path = os.path.join(extract_path, extract_folder)
79
+
80
+ # Create the folder if it doesn't exist
81
+ if not os.path.exists(extract_folder_path):
82
+ os.makedirs(extract_folder_path)
83
+
84
+ # Open the ZIP file for reading
85
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
86
+ # Extract all the contents into the created folder
87
+ zip_ref.extractall(extract_folder_path)
88
+
89
+ print(f'Unzipped {zip_file_path} to {extract_folder_path}')
90
+ return extract_folder_path
91
+
92
+
93
+
94
+
95
+
96
+ return
97
+
98
+ def count_files_by_type(folder_path):
99
+ '''
100
+ Counting files by file type in the specified folder.
101
+ This is required if there is no extisting vector database
102
+ created and wanted to build from the scratch
103
+ '''
104
+ file_count_by_type = defaultdict(int)
105
+
106
+ for root, _, files in os.walk(folder_path):
107
+ for file in files:
108
+ _, extension = os.path.splitext(file)
109
+ file_count_by_type[extension] += 1
110
+
111
+ return file_count_by_type
112
+
113
+ def generate_file_count_table(file_count_by_type):
114
+ '''
115
+ Generate a table files count file type.
116
+ This is required if there is no extisting vector database
117
+ created and wanted to build from the scratch
118
+ '''
119
+ data = {"File Type": [], "Number of Files": []}
120
+ for extension, count in file_count_by_type.items():
121
+ data["File Type"].append(extension)
122
+ data["Number of Files"].append(count)
123
+
124
+ df = pd.DataFrame(data)
125
+ df = df.sort_values(by="Number of Files", ascending=False) # Sort by number of files
126
+ return df
127
+
128
+ def move_files_to_folders(folder_path):
129
+ '''
130
+ Move files to respective folder. Example, PDF docs to PDFs folder, HTML docs to HTMLs folder.
131
+ This is required if there is no extisting vector database
132
+ created and wanted to build from the scratch
133
+ '''
134
+ for root, _, files in os.walk(folder_path):
135
+ for file in files:
136
+ _, extension = os.path.splitext(file)
137
+ source_path = os.path.join(root, file)
138
+
139
+ if extension == '.pdf':
140
+ dest_folder = "PDFs"
141
+ elif extension == '.html':
142
+ dest_folder = "HTMLs"
143
+ else:
144
+ continue
145
+
146
+ dest_path = os.path.join(dest_folder, file)
147
+ os.makedirs(dest_folder, exist_ok=True)
148
+ shutil.copy(source_path, dest_path)
149
+
150
+
151
+
152
+ def load_vectorstore(persist_directory, embeddings):
153
+ '''
154
+ This function will try first to load chroma database from the disk. If it does exist,
155
+ It will do the following,
156
+ 1) Load the pdfs
157
+ 2) create text chunks
158
+ 3) Index it and store it in a Chroma DB
159
+ 4) Peform the same for HTML files
160
+ 5) Store the final chroma db in the disk.
161
+ This is required if there is no extisting vector database
162
+ created and wanted to build from the scratch
163
+ '''
164
+ if os.path.exists(persist_directory):
165
+ print("Using existing vectore store for these documents.")
166
+ vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
167
+ print("Chroma DB loaded from the disk")
168
+ return vectorstore
169
+
170
+
171
+
172
+ def load_retriver(chroma_vectorstore):
173
+ """Load cohere rerank method for retrieval"""
174
+ # bm25_retriever = BM25Retriever.from_documents(text_chunks)
175
+ # bm25_retriever.k = 2
176
+ chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 5})
177
+ # ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
178
+ logging.basicConfig()
179
+ logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
180
+ multi_query_retriever = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
181
+ llm=ChatOpenAI(temperature=0))
182
+ compressor = CohereRerank()
183
+ compression_retriever = ContextualCompressionRetriever(
184
+ base_compressor=compressor,
185
+ base_retriever=multi_query_retriever)
186
+ return compression_retriever
187
+
188
+
189
+ def load_conversational_retrievel_chain(retriever, llm):
190
+ '''
191
+ Create RetrievalQA chain with memory
192
+ '''
193
+ # template = """You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. You only respond once as 'Assistant'.
194
+ # Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
195
+ # Only include information found in the results and don't add any additional information.
196
+ # Make sure the answer is correct and don't output false content.
197
+ # If the text does not relate to the query, simply state 'Text Not Found in the Document'. Ignore outlier,
198
+ # search results which has nothing to do with the question. Only answer what is asked.
199
+ # The answer should be short and concise. Answer step-by-step.
200
+
201
+ # {context}
202
+
203
+ # {history}
204
+ # Question: {question}
205
+ # Helpful Answer:"""
206
+
207
+ # prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
208
+ memory = ConversationBufferMemory(input_key="question", memory_key="history")
209
+
210
+ qa = RetrievalQA.from_chain_type(
211
+ llm=llm,
212
+ chain_type="stuff",
213
+ retriever=retriever,
214
+ return_source_documents=True,
215
+ chain_type_kwargs={"memory": memory},
216
+ )
217
+ return qa
218
+