# document-answering / streamlit_app.py
import time
from io import StringIO

import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

import vector_db as vdb
from llm_model import LLMModel
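
# Local helper modules: vector_db.VectorDB exposes `docs_db` and `load_docs_into_vector_db`,
# and llm_model.LLMModel builds the question-answering chain via `create_qa_chain`.

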
def default_state():
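    """Initialise the session_state keys used by the app if they are not already set."""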
if "startup" not in st.session_state:
st.session_state.startup = True
if "messages" not in st.session_state:
st.session_state.messages = []
if "uploaded_docs" not in st.session_state:
st.session_state.uploaded_docs = []
if "llm_option" not in st.session_state:
st.session_state.llm_option = "Local"
if "answer_loading" not in st.session_state:
st.session_state.answer_loading = False


def load_doc(file_name: str, file_content: str):
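    """Wrap the uploaded file content in a Document and split it into overlapping chunks."""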
if file_name is not None:
# Create document with metadata
doc = Document(page_content=file_content, metadata={"source": file_name})
# Create an instance of the RecursiveCharacterTextSplitter class with specific parameters.
# It splits text into chunks of 1000 characters each with a 150-character overlap.
language = get_language(file_name)
text_splitter = RecursiveCharacterTextSplitter.from_language(chunk_size=1000, chunk_overlap=150,
language=language)
# Split the text into chunks using the text splitter.
docs = text_splitter.split_documents([doc])
return docs
else:
return None


def get_language(file_name: str):
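    """Pick the LangChain Language used for splitting based on the file extension (Markdown by default)."""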
if file_name.endswith(".md") or file_name.endswith(".mdx"):
return Language.MARKDOWN
elif file_name.endswith(".rst"):
return Language.RST
else:
return Language.MARKDOWN


@st.cache_resource()
def get_vector_db():
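    # Cached so a single VectorDB instance is shared across reruns and sessions.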
return vdb.VectorDB()


@st.cache_resource()
def get_llm_model(_db: vdb.VectorDB):
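    # `_db` is prefixed with an underscore so st.cache_resource does not try to hash it.
    # The retriever below returns the k=2 most relevant chunks for each query.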
retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    # create_qa_chain() initialises the RetrievalQA chain around the retriever.
    return LLMModel(retriever=retriever).create_qa_chain()


def init_sidebar():
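    """Render the sidebar: model selection and the file uploader that feeds documents into the vector DB."""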
with st.sidebar:
        st.toggle(
            "Loading from LLM",
            on_change=enable_sidebar,  # pass the callback itself rather than calling it here
            disabled=not st.session_state.answer_loading
        )
llm_option = st.selectbox(
'Select to use local model or inference API',
options=['Local', 'Inference API']
)
st.session_state.llm_option = llm_option
uploaded_files = st.file_uploader(
'Upload file(s)',
type=['md', 'mdx', 'rst', 'txt'],
accept_multiple_files=True
)
for uploaded_file in uploaded_files:
if uploaded_file.name not in st.session_state.uploaded_docs:
# Read the file as a string
stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
string_data = stringio.read()
# Get chunks of text
doc_chunks = load_doc(uploaded_file.name, string_data)
st.write(f"Number of chunks={len(doc_chunks)}")
vector_db.load_docs_into_vector_db(doc_chunks)
st.session_state.uploaded_docs.append(uploaded_file.name)


def init_chat():
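    """Re-render the stored conversation so it persists across Streamlit reruns."""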
# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])


def disable_sidebar():
    # on_submit callback; Streamlit reruns automatically after a callback, so no explicit st.rerun() is needed.
    st.session_state.answer_loading = True


def enable_sidebar():
    st.session_state.answer_loading = False


# Main app flow: configure the page, set up state, render the sidebar and chat, and load cached resources.
st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False


# React to user input
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
# if st.session_state.answer_loading:
# st.warning("Cannot ask multiple questions at the same time")
# st.session_state.answer_loading = False
# else:
start_time = time.time()
# Display user message in chat message container
with st.chat_message("user"):
st.markdown(user_prompt)
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": user_prompt})
if llm_model is not None:
assistant_chat = st.chat_message("assistant")
if not st.session_state.uploaded_docs:
            assistant_chat.warning("WARN: Will try to answer the question without any uploaded documents")
with st.spinner('Resolving question...'):
res = llm_model({"query": user_prompt})
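            # The chain returns the generated answer under 'result' and the retrieved chunks under 'source_documents'.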
sources = []
for source_docs in res['source_documents']:
if 'source' in source_docs.metadata:
sources.append(source_docs.metadata['source'])
# Display assistant response in chat message container
end_time = time.time()
time_taken = "{:.2f}".format(end_time - start_time)
format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
assistant_chat.markdown(format_answer)
source_expander = assistant_chat.expander("See full sources")
for source_docs in res['source_documents']:
if 'source' in source_docs.metadata:
format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
source_expander.markdown(format_source)
# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": format_answer})
enable_sidebar()
st.rerun()