Rohan Kataria committed
Commit 085c24c · 1 Parent(s): 1205205

adding app

Files changed (4)
  1. .gitignore +5 -0
  2. app.py +115 -0
  3. requirements.txt +3 -0
  4. src/main.py +110 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .history/
+ .vscode/
+ notebooks/
+ *.ipynb
+ __pycache__/
app.py ADDED
@@ -0,0 +1,115 @@
+ # Streamlit app to perform conversational retrieval using the ConversationalResponse class
+ # 1. Main title of the app
+ # 2. PDF file loader
+ # 3. Streaming chat window to ask questions and get answers from ConversationalResponse
+ # 4. Callback handler to stream the output of the ConversationalResponse
+ # 5. Handle the chat interaction with the ConversationalResponse
+
+ import streamlit as st
+ from streamlit_chat import message
+ from langchain.callbacks.base import BaseCallbackHandler
+ from src.main import ConversationalResponse
+ import os
+
+ from dotenv import load_dotenv, find_dotenv
+ _ = load_dotenv(find_dotenv())
+
+ # Constants
+ ROLE_USER = "user"
+ ROLE_ASSISTANT = "assistant"
+
+ st.set_page_config(page_title="Chat with Documents", page_icon="🦜")
+ st.title("Chat with PDF Documents 🤖📄")
+ st.markdown("by [Rohan Kataria](https://www.linkedin.com/in/imrohan/), view more at [VEW.AI](https://vew.ai/)")
+ # Streamlit message block
+ st.markdown("This app allows you to chat with documents. You can upload a PDF file and ask questions about it. In the background it uses the ConversationalRetrievalChain from LangChain, with Streamlit for the UI.")
+
+ class StreamHandler(BaseCallbackHandler):
+     """
+     StreamHandler is a callback handler that streams the output of the ConversationalResponse.
+     """
+     def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
+         self.container = container
+         self.text = initial_text
+
+     def on_llm_new_token(self, token: str, **kwargs) -> None:
+         self.text += token
+         self.container.markdown(self.text)
+
+ @st.cache_resource(ttl="1h")
+ def load_agent(file_path, api_key):
+     """
+     Load the ConversationalResponse agent from the given file path.
+     """
+     with st.spinner('Loading the file...'):
+         agent = ConversationalResponse(file_path, api_key)
+     st.success("File Loaded Successfully")
+     return agent
+
+ def handle_chat(agent):
+     """
+     Handle the chat interaction with the user.
+     """
+     if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
+         st.session_state["messages"] = [{"role": ROLE_ASSISTANT, "content": "How can I help you?"}]
+
+     for msg in st.session_state.messages:
+         st.chat_message(msg["role"]).write(msg["content"])
+
+     user_query = st.chat_input(placeholder="Ask me anything!")
+
+     if user_query:
+         st.session_state.messages.append({"role": ROLE_USER, "content": user_query})
+         st.chat_message(ROLE_USER).write(user_query)
+
+         # Generate the response
+         with st.spinner("Generating response"):
+             response = agent(user_query)
+
+         # Display the response immediately
+         st.chat_message(ROLE_ASSISTANT).write(response)
+
+         # Add the response to the message history
+         st.session_state.messages.append({"role": ROLE_ASSISTANT, "content": response})
+
+
+ def main():
+     """
+     Main function to handle file upload and chat interaction.
+     """
+
+     # API key loader
+     api_key = st.sidebar.text_input("Enter your OpenAI API Key", type="password")
+     if api_key:
+         os.environ["OPENAI_API_KEY"] = api_key
+     else:
+         st.sidebar.error("Please enter your OpenAI API Key.")
+         return
+
+     # PDF file loader to upload the file from the sidebar
+     uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")
+     if uploaded_file is None:
+         st.error("Please upload a file.")
+         return
+
+     file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type, "FileSize": uploaded_file.size}
+     st.write(file_details)
+
+     # Create a temp folder
+     if not os.path.exists("temp"):
+         os.mkdir("temp")
+     # Save the file in the temp folder
+     file_path = os.path.join("temp", uploaded_file.name)
+     with open(file_path, "wb") as f:
+         f.write(uploaded_file.getbuffer())
+
+     agent = load_agent(file_path, api_key)
+
+     handle_chat(agent)
+
+     # Delete the file from the temp folder
+     os.remove(file_path)
+
+
+ if __name__ == "__main__":
+     main()
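
Note that StreamHandler is defined above but never attached to the agent, so answers render only after generation finishes. A minimal sketch of how it could be wired into the response block of handle_chat, assuming the callbacks parameter that ConversationalResponse.__call__ already accepts (hypothetical, not part of this commit):

```python
# Hypothetical streaming variant (not in this commit): render tokens into an
# empty placeholder as they arrive instead of writing the answer in one shot.
with st.chat_message(ROLE_ASSISTANT):
    stream_handler = StreamHandler(st.empty())
    response = agent(user_query, callbacks=[stream_handler])
st.session_state.messages.append({"role": ROLE_ASSISTANT, "content": response})
```

With the dependencies installed, the app starts with `streamlit run app.py`.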
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ openai
+ streamlit
+ langchain
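
app.py and src/main.py import several packages that are not listed above; assuming the current PyPI package names, a fuller requirements.txt would likely also need:

```
streamlit-chat
python-dotenv
pypdf
docarray
tiktoken
```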
src/main.py ADDED
@@ -0,0 +1,110 @@
+ """
+ This is the main logic file for the project, responsible for the following:
+ 1. Read the loaded file using LangChain
+ 2. Split the loaded data into chunks
+ 3. Ingest the data in vector form
+ 4. Run conversational retrieval logic on the loaded data to create a conversational response
+ 5. Return the response to the user (output)
+ """
+
+ # Importing the required libraries
+ import os
+ import openai
+ import sys
+ sys.path.append('../..')  # To import the langchain package from the parent directory
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+ from langchain.vectorstores import DocArrayInMemorySearch
+ from langchain.document_loaders import TextLoader
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chat_models import ChatOpenAI
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.llms import OpenAI
+ import datetime
+ from langchain.prompts import PromptTemplate
+
+ from dotenv import load_dotenv, find_dotenv
+ _ = load_dotenv(find_dotenv())
+
+ # Function to load the data from the file
+ def load_data(file_path):
+     loader = PyPDFLoader(file_path)
+     pages = loader.load()
+     return pages
+
+ # Function to split the data into chunks
+ def split_data(data):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=150,
+     )
+     chunks = splitter.split_documents(data)
+     return chunks
+
+ # # Creating the OpenAI embeddings
+ # embeddings = OpenAIEmbeddings()
+
+ # Function to ingest the data in vector form into an in-memory store
+ def ingest_data(chunks, embeddings):
+     vector_store = DocArrayInMemorySearch.from_documents(chunks, embeddings)
+     return vector_store
+
+ # Function to create the conversational response
+ def create_conversational_response(vector_store, chain_type, k):
+
+     # Creating the retriever; this could also be a contextual compression retriever
+     retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})  # search_type can be "similarity" or "mmr"
+
+     # Creating memory
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         input_key="question",
+         output_key="answer",
+         return_messages=True)
+
+     # Creating the LLM
+     current_date = datetime.datetime.now().date()
+     if current_date < datetime.date(2023, 9, 2):
+         llm_name = "gpt-3.5-turbo-0301"
+     else:
+         llm_name = "gpt-3.5-turbo"
+
+     llm = ChatOpenAI(model=llm_name, temperature=0)
+
+     # Creating the prompt template
+     template = """
+     {chat_history}
+     {context}
+     Question: {question}
+     Helpful Answer:"""
+
+     PROMPT = PromptTemplate(input_variables=["chat_history", "context", "question"], template=template)
+
+
+     # Creating the conversational retrieval chain
+     chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         chain_type=chain_type,  # chain_type can be "stuff", "refine", or "map_reduce"
+         retriever=retriever,
+         memory=memory,
+         return_source_documents=True,  # adds a source_documents field to the output, which is why the memory above needs explicit input and output keys
+         combine_docs_chain_kwargs={"prompt": PROMPT}
+     )
+     return chain
+
+ # ConversationalResponse class to call all the defined functions in a single call
+ class ConversationalResponse:
+     def __init__(self, file, api_key):
+         self.file = file
+         embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+         self.data = load_data(self.file)
+         self.chunks = split_data(self.data)
+         self.vector_store = ingest_data(self.chunks, embeddings)
+         self.chain_type = "stuff"
+         self.k = 5
+         self.chain = create_conversational_response(self.vector_store, self.chain_type, self.k)
+
+     def __call__(self, question, callbacks=None):
+         response = self.chain(question, callbacks=callbacks)
+         return response['answer']
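
For context, a minimal usage sketch of ConversationalResponse outside Streamlit (hypothetical: the PDF path and questions are placeholders, and OPENAI_API_KEY is assumed to be set in the environment):

```python
import os
from src.main import ConversationalResponse

# Build the chain over a local PDF, then ask questions against it.
agent = ConversationalResponse("temp/example.pdf", os.environ["OPENAI_API_KEY"])
print(agent("What is this document about?"))

# Follow-up questions reuse the ConversationBufferMemory inside the chain,
# so the model sees the earlier turns as chat_history.
print(agent("Summarize your previous answer in one sentence."))
```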