File size: 4,021 Bytes
085c24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Streamlit App to perform the conversational retrieval using ConversationalResponse class
# 1. Main Title of App
# 2. PDF File Loader
# 3. Streaming Chat Window to ask questions and get answers from ConversationalResponse
# 4. Callback Handler to stream the output of the ConversationalResponse
# 5. Handle the chat interaction with the ConversationalResponse

import streamlit as st
from streamlit_chat import message
from langchain.callbacks.base import BaseCallbackHandler
from src.main import ConversationalResponse
import os

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Constants: role identifiers passed to st.chat_message and stored under the
# "role" key of every chat-history entry.
ROLE_USER = "user"
ROLE_ASSISTANT = "assistant"

# Page configuration followed by the visible heading and introductory text.
st.set_page_config(page_title="Chat with Documents", page_icon="🦜")
# The title emoji were stored as mojibake (UTF-8 bytes decoded with a legacy
# code page); they are written here as the intended code points.
st.title("Chat with PDF Documents 🤖📄")
st.markdown("by [Rohan Kataria](https://www.linkedin.com/in/imrohan/) view more at [VEW.AI](https://vew.ai/)")
# Short description of the app shown under the title.
st.markdown("This app allows you to chat with documents. You can upload a PDF file and ask questions about it. In the background it uses the ConversationalRetrieval chain from langchain and Streamlit for UI.")

class StreamHandler(BaseCallbackHandler):
    """
    Callback handler that renders generated text incrementally.

    Every token passed to ``on_llm_new_token`` is appended to the text
    accumulated so far, and the full text is re-rendered as markdown into the
    container supplied at construction time.
    """
    def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
        # Text accumulated so far; seeded with the optional initial text.
        self.text = initial_text
        # Element that receives the markdown re-render on every new token.
        self.container = container

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the accumulated text and redraw it."""
        accumulated = self.text + token
        self.text = accumulated
        self.container.markdown(accumulated)

@st.cache_resource(ttl="1h")
def load_agent(file_path, api_key):
    """
    Build a ConversationalResponse for the given file path and API key.

    The function is wrapped in ``st.cache_resource`` with a one-hour TTL.
    A spinner is displayed while the agent is being constructed and a success
    message once construction has finished.
    """
    loading_indicator = st.spinner('Loading the file...')
    with loading_indicator:
        conversational_agent = ConversationalResponse(file_path, api_key)

    st.success("File Loaded Successfully")
    return conversational_agent

def handle_chat(agent):
    """
    Render the chat history and process at most one new user query.

    The message history lives in ``st.session_state["messages"]`` as a list of
    ``{"role": ..., "content": ...}`` dicts. It is (re)initialised with a
    greeting when it does not exist yet or when the sidebar
    "Clear message history" button is pressed.

    Args:
        agent: Callable invoked as ``agent(user_query)``; its return value is
            written to the chat as the assistant's reply and stored in the
            history.
    """
    # Call the button unconditionally. Combining it with the membership test
    # in a single `or` expression short-circuits on the first run, so the
    # button would not be rendered until the history already exists.
    clear_requested = st.sidebar.button("Clear message history")
    if clear_requested or "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": ROLE_ASSISTANT, "content": "How can I help you?"}]

    # Replay the stored conversation so it is visible on every rerun.
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    user_query = st.chat_input(placeholder="Ask me anything!")
    if not user_query:
        # Nothing submitted on this run.
        return

    st.session_state["messages"].append({"role": ROLE_USER, "content": user_query})
    st.chat_message(ROLE_USER).write(user_query)

    # Generate the response
    with st.spinner("Generating response"):
        response = agent(user_query)

    # Display the response immediately, then record it in the history so it
    # is replayed on subsequent reruns.
    st.chat_message(ROLE_ASSISTANT).write(response)
    st.session_state["messages"].append({"role": ROLE_ASSISTANT, "content": response})


def main():
    """
    Main function to handle API key entry, file upload and chat interaction.

    Flow:
      1. Read the OpenAI API key from the sidebar; stop with an error if empty.
      2. Read the uploaded PDF from the sidebar; stop with an error if absent.
      3. Write the upload to a file under ``temp/``, load the agent from it
         and run the chat.
      4. Remove the temp file, even if loading or chatting raised.
    """
    # Function-scope import: only the temp-file naming below needs it.
    import hashlib

    # API Key Loader
    api_key = st.sidebar.text_input("Enter your OpenAI API Key", type="password")
    if not api_key:
        st.sidebar.error("Please enter your OpenAI API Key.")
        return
    # NOTE(review): os.environ is process-wide, so this value is shared by
    # everything running in the same process — confirm that is intended.
    os.environ["OPENAI_API_KEY"] = api_key

    # PDF File Loader to upload the file in the sidebar
    uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")
    if uploaded_file is None:
        st.error("Please upload a file.")
        return

    file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type, "FileSize": uploaded_file.size}
    st.write(file_details)

    # Create the temp folder; exist_ok avoids the check-then-create race.
    os.makedirs("temp", exist_ok=True)

    file_buffer = uploaded_file.getbuffer()
    # The file name is supplied by the client: keep only its last path
    # component so it cannot point outside the temp folder.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    # Prefix a digest of the content so that two different files sharing a
    # name get different paths. load_agent is cached on its arguments, so a
    # name-only path would return the agent built for the earlier file.
    content_digest = hashlib.sha256(file_buffer).hexdigest()[:16]
    file_path = os.path.join("temp", f"{content_digest}_{safe_name}")

    # Save the file in temp folder
    with open(file_path, "wb") as f:
        f.write(file_buffer)

    try:
        agent = load_agent(file_path, api_key)
        handle_chat(agent)
    finally:
        # Delete the file from temp folder, also when an exception propagates.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Already gone (e.g. removed by another run handling the same
            # content); nothing left to clean up.
            pass


# Entry point: call main() only when this file is executed as the main script.
if __name__ == "__main__":
    main()