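"""Streamlit document-answering tool.

Upload Markdown/RST/text files in the sidebar; they are chunked, embedded into a
vector store, and chat questions are answered through a retrieval-augmented LLM
chain. `vector_db` and `llm_model` are local project modules assumed to expose
the `VectorDB` and `LLMModel` classes used below.
"""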
import time

import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

import vector_db as vdb
from llm_model import LLMModel


def default_state():
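    """Initialize every session-state key the app relies on, once per session."""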
    if "startup" not in st.session_state:
        st.session_state.startup = True

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "uploaded_docs" not in st.session_state:
        st.session_state.uploaded_docs = []

    if "llm_option" not in st.session_state:
        st.session_state.llm_option = "Local"

    if "answer_loading" not in st.session_state:
        st.session_state.answer_loading = False


def load_doc(file_name: str, file_content: str):
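    """Wrap the file's text in a Document and split it into overlapping chunks."""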
    if file_name is None:
        return None
    doc = Document(page_content=file_content, metadata={"source": file_name})
    # Choose a splitter suited to the file's markup, then split the text into
    # chunks of 1,000 characters with a 150-character overlap so context is
    # preserved across chunk boundaries.
    language = get_language(file_name)
    text_splitter = RecursiveCharacterTextSplitter.from_language(
        language=language, chunk_size=1000, chunk_overlap=150
    )
    return text_splitter.split_documents([doc])


def get_language(file_name: str):
    if file_name.endswith((".md", ".mdx")):
        return Language.MARKDOWN
    elif file_name.endswith(".rst"):
        return Language.RST
    # Fall back to the Markdown splitter for anything else (e.g. plain .txt files).
    return Language.MARKDOWN


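# Cache the vector store so it is created once and reused across script reruns.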
@st.cache_resource()
def get_vector_db():
    return vdb.VectorDB()


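# Build the retrieval-QA chain once; the retriever hands the chain the two most
# similar chunks per query. The leading underscore on _db tells Streamlit's
# cache not to try to hash the argument.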
@st.cache_resource()
def get_llm_model(_db: vdb.VectorDB):
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    return LLMModel(retriever=retriever).create_qa_chain()


def init_sidebar():
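    """Render the sidebar: loading toggle, model selector, and file uploads that feed the vector store."""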
    with st.sidebar:
        st.toggle(
            "Loading from LLM",
            # Pass the callback itself; enable_sidebar() with parentheses would
            # run immediately on every script rerun instead of on toggle.
            on_change=enable_sidebar,
            disabled=not st.session_state.answer_loading,
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option
        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.uploaded_docs:
                # Decode the uploaded bytes into a string
                string_data = uploaded_file.getvalue().decode("utf-8")
                # Get chunks of text
                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)


def init_chat():
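    """Replay the stored conversation so it survives Streamlit's reruns."""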
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def disable_sidebar():
    # Lock the sidebar while an answer is generated. Streamlit reruns the script
    # after an on_submit callback anyway, so no explicit st.rerun() is needed
    # here (it is a no-op inside a callback).
    st.session_state.answer_loading = True


def enable_sidebar():
    # Unlock the sidebar once the answer has been rendered.
    st.session_state.answer_loading = False


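# App start-up: page config, cached vector store, session defaults, sidebar,
# chat history, and the QA chain.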
st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False


# React to user input
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    start_time = time.time()
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(user_prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": user_prompt})

    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("WARN: Will try answer question without documents")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
        # Collect the unique source file names cited by the retrieved chunks
        sources = []
        for source_doc in res['source_documents']:
            source = source_doc.metadata.get('source')
            if source and source not in sources:
                sources.append(source)
        # Display assistant response in chat message container
        end_time = time.time()
        time_taken = "{:.2f}".format(end_time - start_time)
        format_answer = (
            f"## Result\n\n{res['result']}\n\n"
            f"### Sources\n\n{', '.join(sources)}\n\n"
            f"Time taken: {time_taken}s"
        )
        assistant_chat.markdown(format_answer)
        source_expander = assistant_chat.expander("See full sources")
        for source_doc in res['source_documents']:
            if 'source' in source_doc.metadata:
                format_source = f"## File: {source_doc.metadata['source']}\n\n{source_doc.page_content}"
                source_expander.markdown(format_source)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": format_answer})
        enable_sidebar()
        st.rerun()