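"""Streamlit document-answering tool.

Files uploaded in the sidebar are split into chunks, stored in a vector
database, and answered through a retrieval QA chain backed by either a
local model or an inference API.
"""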
from io import StringIO
import streamlit as st
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
import time
import vector_db as vdb
from llm_model import LLMModel


def default_state():
    # Initialise session-state defaults on the first run of the script.
    if "startup" not in st.session_state:
        st.session_state.startup = True
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_docs" not in st.session_state:
        st.session_state.uploaded_docs = []
    if "llm_option" not in st.session_state:
        st.session_state.llm_option = "Local"
    if "answer_loading" not in st.session_state:
        st.session_state.answer_loading = False


def load_doc(file_name: str, file_content: str):
    if file_name is not None:
        # Create document with metadata
        doc = Document(page_content=file_content, metadata={"source": file_name})
        # Create a RecursiveCharacterTextSplitter tuned to the file's markup language.
        # It splits text into chunks of 1000 characters each with a 150-character overlap.
        language = get_language(file_name)
        text_splitter = RecursiveCharacterTextSplitter.from_language(chunk_size=1000, chunk_overlap=150,
                                                                     language=language)
        # Split the text into chunks using the text splitter.
        docs = text_splitter.split_documents([doc])
        return docs
    else:
        return None


def get_language(file_name: str):
    # Map the file extension to a text-splitter language; default to Markdown.
    if file_name.endswith(".md") or file_name.endswith(".mdx"):
        return Language.MARKDOWN
    elif file_name.endswith(".rst"):
        return Language.RST
    else:
        return Language.MARKDOWN


@st.cache_resource()
def get_vector_db():
    # Cache the vector database so it is initialised only once.
    return vdb.VectorDB()


@st.cache_resource()
def get_llm_model(_db: vdb.VectorDB):
    # Initialize the RetrievalQA chain with a retriever over the top-2 matching chunks.
    retriever = _db.docs_db.as_retriever(search_kwargs={"k": 2})
    return LLMModel(retriever=retriever).create_qa_chain()


def init_sidebar():
    with st.sidebar:
        # Only enabled while an answer is loading; toggling it re-enables the sidebar.
        st.toggle(
            "Loading from LLM",
            on_change=enable_sidebar,
            disabled=not st.session_state.answer_loading
        )
        llm_option = st.selectbox(
            'Select to use local model or inference API',
            options=['Local', 'Inference API']
        )
        st.session_state.llm_option = llm_option
        uploaded_files = st.file_uploader(
            'Upload file(s)',
            type=['md', 'mdx', 'rst', 'txt'],
            accept_multiple_files=True
        )
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.uploaded_docs:
                # Read the file as a string
                stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
                string_data = stringio.read()
                # Split the file into chunks and load them into the vector store
                doc_chunks = load_doc(uploaded_file.name, string_data)
                st.write(f"Number of chunks={len(doc_chunks)}")
                vector_db.load_docs_into_vector_db(doc_chunks)
                st.session_state.uploaded_docs.append(uploaded_file.name)


def init_chat():
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def disable_sidebar():
    # Mark that an answer is being generated, then rerun so widgets pick up the flag.
    st.session_state.answer_loading = True
    st.rerun()


def enable_sidebar():
    # Clear the answer-loading flag once a response has been produced.
    st.session_state.answer_loading = False


st.set_page_config(page_title="Document Answering Tool", page_icon=":book:")
vector_db = get_vector_db()
default_state()
init_sidebar()
st.header("Document answering tool")
st.subheader("Upload your documents on the side and ask questions")
init_chat()
llm_model = get_llm_model(vector_db)
st.session_state.startup = False

# React to user input
if user_prompt := st.chat_input("What's up?", on_submit=disable_sidebar):
    # if st.session_state.answer_loading:
    #     st.warning("Cannot ask multiple questions at the same time")
    #     st.session_state.answer_loading = False
    # else:
    start_time = time.time()
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(user_prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    if llm_model is not None:
        assistant_chat = st.chat_message("assistant")
        if not st.session_state.uploaded_docs:
            assistant_chat.warning("No documents uploaded; will try to answer the question without them")
        with st.spinner('Resolving question...'):
            res = llm_model({"query": user_prompt})
        # Collect the source file names of the retrieved chunks
        sources = []
        for source_docs in res['source_documents']:
            if 'source' in source_docs.metadata:
                sources.append(source_docs.metadata['source'])
        # Display assistant response in chat message container
        end_time = time.time()
        time_taken = "{:.2f}".format(end_time - start_time)
        format_answer = f"## Result\n\n{res['result']}\n\n### Sources\n\n{sources}\n\nTime taken: {time_taken}s"
        assistant_chat.markdown(format_answer)
        source_expander = assistant_chat.expander("See full sources")
        for source_docs in res['source_documents']:
            if 'source' in source_docs.metadata:
                format_source = f"## File: {source_docs.metadata['source']}\n\n{source_docs.page_content}"
                source_expander.markdown(format_source)
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": format_answer})
    enable_sidebar()
    st.rerun()
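# Assumed entry point (the file name is hypothetical): run locally with
#   streamlit run app.py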