Spaces:
Build error
Build error
Syed Junaid Iqbal
commited on
Commit
•
030d46c
1
Parent(s):
a7ce0dd
Upload 5 files
Browse files- app.py +61 -0
- bm25 +0 -0
- retriever.py +38 -0
- streaming.py +11 -0
- utils.py +39 -0
app.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import streamlit as st
|
3 |
+
from streaming import StreamHandler
|
4 |
+
import utils
|
5 |
+
from langchain.callbacks.manager import CallbackManager
|
6 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
7 |
+
from retriever import retriever
|
8 |
+
from langchain.chains import RetrievalQA
|
9 |
+
from langchain.llms import LlamaCpp
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
|
12 |
+
class CustomDataChatbot:
|
13 |
+
def __init__(self):
|
14 |
+
# Initialize session state variables, including messages
|
15 |
+
st.session_state.messages = []
|
16 |
+
|
17 |
+
@st.spinner('Analyzing documents..')
|
18 |
+
def setup_qa_chain(self):
|
19 |
+
# Setup memory for contextual conversation
|
20 |
+
# memory = ConversationBufferMemory(
|
21 |
+
# memory_key='chat_history',
|
22 |
+
# return_messages=True
|
23 |
+
# )
|
24 |
+
|
25 |
+
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
26 |
+
# Setup LLM and QA chain
|
27 |
+
llm = LlamaCpp(model_path="./models/openhermes-2.5-neural-chat-7b-v3-1-7b.Q5_K_M.gguf",
|
28 |
+
temperature=0.34,
|
29 |
+
max_tokens=4000,
|
30 |
+
n_ctx=4096,
|
31 |
+
top_p=1,
|
32 |
+
callback_manager=callback_manager,
|
33 |
+
verbose=True)
|
34 |
+
|
35 |
+
# qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever(), memory=memory, verbose=True)
|
36 |
+
|
37 |
+
|
38 |
+
return RetrievalQA.from_chain_type( llm, retriever= retriever())
|
39 |
+
|
40 |
+
@utils.enable_chat_history
|
41 |
+
def main(self):
|
42 |
+
load_dotenv()
|
43 |
+
st.set_page_config(page_title="ChatPDF", page_icon="📄")
|
44 |
+
st.header('Chat with your documents')
|
45 |
+
st.write('Has access to custom documents and can respond to user queries by referring to the content within those documents')
|
46 |
+
st.write('[![view source code ](https://img.shields.io/badge/view_source_code-gray?logo=github)](https://github.com/shashankdeshpande/langchain-chatbot/blob/master/pages/4_%F0%9F%93%84_chat_with_your_documents.py)')
|
47 |
+
|
48 |
+
user_query = st.chat_input(placeholder="Ask me anything!")
|
49 |
+
|
50 |
+
if user_query:
|
51 |
+
qa_chain = self.setup_qa_chain()
|
52 |
+
utils.display_msg(user_query, 'user')
|
53 |
+
|
54 |
+
with st.chat_message("assistant"):
|
55 |
+
st_cb = StreamHandler(st.empty())
|
56 |
+
response = qa_chain.run(user_query, callbacks=[st_cb])
|
57 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
obj = CustomDataChatbot()
|
61 |
+
obj.main()
|
bm25
ADDED
Binary file (184 kB). View file
|
|
retriever.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from langchain.retrievers import EnsembleRetriever
|
3 |
+
from langchain.vectorstores import FAISS
|
4 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
+
|
6 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
7 |
+
from transformers import AutoModel
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def retriever():
|
12 |
+
|
13 |
+
# Embeddings
|
14 |
+
# Defign our Embedding Model
|
15 |
+
|
16 |
+
model_name = "jinaai/jina-embeddings-v2-base-en"
|
17 |
+
model_kwargs = {'device': 'cpu'}
|
18 |
+
encode_kwargs = {'normalize_embeddings': False, }
|
19 |
+
|
20 |
+
model = AutoModel.from_pretrained( model_name, trust_remote_code=True)
|
21 |
+
|
22 |
+
embeddings = HuggingFaceEmbeddings( model_name=model_name,
|
23 |
+
model_kwargs=model_kwargs,
|
24 |
+
encode_kwargs=encode_kwargs)
|
25 |
+
|
26 |
+
|
27 |
+
#to read bm25 object
|
28 |
+
with open('./bm25', 'rb') as file:
|
29 |
+
bm25_retriever = pickle.load(file)
|
30 |
+
|
31 |
+
bm25_retriever.k = 2
|
32 |
+
|
33 |
+
# Load FAISS
|
34 |
+
faiss_vectorstore = FAISS.load_local("./Vector_DB/", embeddings)
|
35 |
+
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 1})
|
36 |
+
|
37 |
+
# initialize the ensemble retriever
|
38 |
+
return EnsembleRetriever( retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] )
|
streaming.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
2 |
+
|
3 |
+
class StreamHandler(BaseCallbackHandler):
|
4 |
+
|
5 |
+
def __init__(self, container, initial_text=""):
|
6 |
+
self.container = container
|
7 |
+
self.text = initial_text
|
8 |
+
|
9 |
+
def on_llm_new_token(self, token: str, **kwargs):
|
10 |
+
self.text += token
|
11 |
+
self.container.markdown(self.text)
|
utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
#decorator
|
6 |
+
def enable_chat_history(func):
|
7 |
+
if os.environ.get("OPENAI_API_KEY"):
|
8 |
+
|
9 |
+
# to clear chat history after swtching chatbot
|
10 |
+
current_page = func.__qualname__
|
11 |
+
if "current_page" not in st.session_state:
|
12 |
+
st.session_state["current_page"] = current_page
|
13 |
+
if st.session_state["current_page"] != current_page:
|
14 |
+
try:
|
15 |
+
st.cache_resource.clear()
|
16 |
+
del st.session_state["current_page"]
|
17 |
+
del st.session_state["messages"]
|
18 |
+
except:
|
19 |
+
pass
|
20 |
+
|
21 |
+
# to show chat history on ui
|
22 |
+
if "messages" not in st.session_state:
|
23 |
+
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
24 |
+
for msg in st.session_state["messages"]:
|
25 |
+
st.chat_message(msg["role"]).write(msg["content"])
|
26 |
+
|
27 |
+
def execute(*args, **kwargs):
|
28 |
+
func(*args, **kwargs)
|
29 |
+
return execute
|
30 |
+
|
31 |
+
def display_msg(msg, author):
|
32 |
+
"""Method to display message on the UI
|
33 |
+
|
34 |
+
Args:
|
35 |
+
msg (str): message to display
|
36 |
+
author (str): author of the message -user/assistant
|
37 |
+
"""
|
38 |
+
st.session_state.messages.append({"role": author, "content": msg})
|
39 |
+
st.chat_message(author).write(msg)
|