Syed Junaid Iqbal committed on
Commit
030d46c
1 Parent(s): a7ce0dd

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +61 -0
  2. bm25 +0 -0
  3. retriever.py +38 -0
  4. streaming.py +11 -0
  5. utils.py +39 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ from streaming import StreamHandler
4
+ import utils
5
+ from langchain.callbacks.manager import CallbackManager
6
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
7
+ from retriever import retriever
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.llms import LlamaCpp
10
+ from dotenv import load_dotenv
11
+
12
class CustomDataChatbot:
    """Streamlit chat UI that answers questions about local documents
    using a local LlamaCpp model behind a RetrievalQA chain."""

    def __init__(self):
        # st.set_page_config must be the FIRST Streamlit command of a run.
        # It used to live inside main(), but main() is wrapped by
        # utils.enable_chat_history, which may emit chat messages before the
        # body executes — that would raise StreamlitAPIException. Calling it
        # here guarantees it runs first.
        st.set_page_config(page_title="ChatPDF", page_icon="📄")
        # Seed the history only when absent: Streamlit re-runs the whole
        # script (and re-instantiates this object) on every interaction, so
        # an unconditional assignment would wipe the conversation each time.
        if "messages" not in st.session_state:
            st.session_state.messages = []

    def setup_qa_chain(self):
        """Build and return a RetrievalQA chain backed by a local LlamaCpp model.

        NOTE: ``st.spinner`` is a context manager, not a decorator — the
        original ``@st.spinner('Analyzing documents..')`` fails at call
        time, so the spinner is now used via ``with`` inside the body.
        """
        with st.spinner('Analyzing documents..'):
            # Stream generated tokens to stdout for server-side visibility.
            callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

            # Setup LLM and QA chain.
            llm = LlamaCpp(
                model_path="./models/openhermes-2.5-neural-chat-7b-v3-1-7b.Q5_K_M.gguf",
                temperature=0.34,
                max_tokens=4000,
                n_ctx=4096,
                top_p=1,
                callback_manager=callback_manager,
                verbose=True,
            )

            return RetrievalQA.from_chain_type(llm, retriever=retriever())

    @utils.enable_chat_history
    def main(self):
        """Render the page and answer a single user query per rerun."""
        load_dotenv()
        st.header('Chat with your documents')
        st.write('Has access to custom documents and can respond to user queries by referring to the content within those documents')
        st.write('[![view source code ](https://img.shields.io/badge/view_source_code-gray?logo=github)](https://github.com/shashankdeshpande/langchain-chatbot/blob/master/pages/4_%F0%9F%93%84_chat_with_your_documents.py)')

        user_query = st.chat_input(placeholder="Ask me anything!")

        if user_query:
            qa_chain = self.setup_qa_chain()
            utils.display_msg(user_query, 'user')

            with st.chat_message("assistant"):
                # Stream tokens into the placeholder as they are generated.
                st_cb = StreamHandler(st.empty())
                response = qa_chain.run(user_query, callbacks=[st_cb])
                st.session_state.messages.append({"role": "assistant", "content": response})
59
# Script entry point: build the chatbot and run one Streamlit pass.
if __name__ == "__main__":
    CustomDataChatbot().main()
bm25 ADDED
Binary file (184 kB). View file
 
retriever.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from langchain.retrievers import EnsembleRetriever
3
+ from langchain.vectorstores import FAISS
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from transformers import AutoModel
8
+
9
+
10
+
11
def retriever():
    """Build the ensemble retriever used by the QA chain.

    Combines a pickled BM25 retriever (lexical) with a FAISS vector store
    (dense, Jina v2 embeddings) at equal weights.

    Returns:
        EnsembleRetriever: 50/50 blend of BM25 (k=2) and FAISS (k=1).
    """
    # Define our embedding model (CPU, unnormalized vectors).
    model_name = "jinaai/jina-embeddings-v2-base-en"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}

    # NOTE(review): the original also ran AutoModel.from_pretrained(model_name,
    # trust_remote_code=True) but never used the result — it loaded the full
    # model weights a second time for nothing, so it was removed. If it was
    # only there to pre-accept the repo's remote code for the embedding
    # loader, confirm HuggingFaceEmbeddings still initializes without it.
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )

    # Load the pre-built BM25 retriever.
    # SECURITY: pickle.load executes arbitrary code from the file — only
    # ever load a ./bm25 artifact you produced yourself.
    with open('./bm25', 'rb') as file:
        bm25_retriever = pickle.load(file)
    bm25_retriever.k = 2

    # Load the FAISS index and wrap it as a retriever.
    faiss_vectorstore = FAISS.load_local("./Vector_DB/", embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 1})

    # Blend lexical and dense retrieval with equal weights.
    return EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
streaming.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.callbacks.base import BaseCallbackHandler
2
+
3
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that live-renders streamed LLM tokens into a
    Streamlit container (e.g. an ``st.empty()`` placeholder)."""

    def __init__(self, container, initial_text=""):
        # Placeholder that is redrawn on every incoming token.
        self.container = container
        # Running transcript of everything streamed so far.
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        # Grow the transcript and re-render the full message so the UI
        # shows the response building up token by token.
        self.text = self.text + token
        self.container.markdown(self.text)
utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import functools
import os
import random

import streamlit as st
4
+
5
+ #decorator
6
def enable_chat_history(func):
    """Decorator for Streamlit page functions.

    Clears cached resources and stored messages when the user switches to a
    different chatbot page, replays the saved chat history onto the UI, and
    then runs the wrapped function.

    NOTE(review): all of the history logic is gated on OPENAI_API_KEY being
    present, even though app.py uses a local LlamaCpp model — confirm the
    gate is intentional.

    Args:
        func: the page-rendering function to wrap.

    Returns:
        The wrapped function; its return value is propagated (the original
        wrapper silently discarded it).
    """
    if os.environ.get("OPENAI_API_KEY"):

        # Clear chat history after switching chatbot pages.
        current_page = func.__qualname__
        if "current_page" not in st.session_state:
            st.session_state["current_page"] = current_page
        if st.session_state["current_page"] != current_page:
            try:
                st.cache_resource.clear()
                del st.session_state["current_page"]
                del st.session_state["messages"]
            except Exception:
                # Best-effort cleanup: a key may already be gone. Narrowed
                # from a bare except so Ctrl-C/SystemExit still propagate.
                pass

        # Replay stored chat history on the UI.
        if "messages" not in st.session_state:
            st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
        for msg in st.session_state["messages"]:
            st.chat_message(msg["role"]).write(msg["content"])

    @functools.wraps(func)  # preserve __name__/__doc__ of the page function
    def execute(*args, **kwargs):
        # Propagate the wrapped function's return value.
        return func(*args, **kwargs)
    return execute
30
+
31
def display_msg(msg, author):
    """Store a chat message in session state and render it on the UI.

    Args:
        msg (str): message to display
        author (str): author of the message -user/assistant
    """
    entry = {"role": author, "content": msg}
    st.session_state.messages.append(entry)
    st.chat_message(author).write(msg)