ChatData / chat.py
mpsk's picture
add parse and private knowledge base
04f0bde
raw
history blame
11.5 kB
import pandas as pd
from os import environ
from time import sleep
import datetime
import streamlit as st
from lib.sessions import SessionManager
from lib.private_kb import PrivateKnowledgeBase
from langchain.schema import HumanMessage, FunctionMessage
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
from helper import (
build_agents,
MYSCALE_HOST,
MYSCALE_PASSWORD,
MYSCALE_PORT,
MYSCALE_USER,
DEFAULT_SYSTEM_PROMPT,
UNSTRUCTURED_API,
)
from login import back_to_main
# Copy the OPENAI_API_BASE value from Streamlit secrets into the process
# environment (presumably so OpenAI/LangChain clients pick it up — confirm).
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
# Human-readable labels keyed by what appear to be internal tool identifiers.
# NOTE(review): not referenced in the visible part of this file — confirm it
# is still used elsewhere.
TOOL_NAMES = {
    "langchain_retriever_tool": "Self-querying retriever",
    "vecsql_retriever_tool": "Vector SQL",
}
def on_chat_submit():
    """Callback for the chat input: echo the user's message and run the agent.

    Everything is drawn inside the ``next_round`` placeholder kept in the
    session state. The agent is invoked with a ``ChatDataAgentCallBackHandler``
    bound to a fresh container, and the raw agent result is printed to stdout.
    """
    round_area = st.session_state.next_round.container()
    with round_area:
        user_text = st.session_state.chat_input
        with st.chat_message('user'):
            st.write(user_text)
        with st.chat_message('assistant'):
            thoughts_box = st.container()
            handler = ChatDataAgentCallBackHandler(
                thoughts_box,
                collapse_completed_thoughts=False,
            )
            result = st.session_state.agent(
                {"input": user_text},
                callbacks=[handler],
            )
            print(result)
def clear_history():
    """Clear the memory of the agent stored in the session state, if any."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
def back_to_main():
    """Drop the per-user entries from the session state.

    Each key is removed only when present, so the function is safe to call
    when some (or all) of them were never set.

    NOTE(review): this definition rebinds the ``back_to_main`` name imported
    from ``login`` at the top of the file — confirm which one is intended.
    """
    per_user_keys = (
        "user_info",
        "user_name",
        "jump_query_ask",
        "sel_sess",
        "current_sessions",
    )
    for key in per_user_keys:
        if key in st.session_state:
            delattr(st.session_state, key)
def on_session_change_submit():
    """Apply the pending edits of the ``session_editor`` data editor.

    Does nothing unless both ``session_manager`` and ``session_editor`` exist
    in the session state. Otherwise:

    * every added row must carry both ``session_id`` and ``system_prompt``;
      the id must be non-empty and must not contain ``?`` because ``?`` is the
      separator in the stored ``"<user_name>?<session_id>"`` key;
    * every deleted row index is resolved through ``current_sessions`` and the
      matching session is removed;
    * the session list is reloaded.

    Any exception raised along the way is shown with ``st.error`` instead of
    propagating. The editor's pending rows are reset in all cases, and the
    agent is rebuilt afterwards.
    """
    if "session_manager" not in st.session_state or "session_editor" not in st.session_state:
        return

    print(st.session_state.session_editor)
    try:
        for row in st.session_state.session_editor["added_rows"]:
            if "system_prompt" not in row or "session_id" not in row:
                raise KeyError(
                    "You should fill both `session_id` and `system_prompt` to add a column!"
                )
            new_id = row["session_id"]
            if new_id == "" or "?" in new_id:
                raise KeyError(
                    "`session_id` must be non-empty and must not contain question marks."
                )
            st.session_state.session_manager.add_session(
                user_id=st.session_state.user_name,
                session_id=f"{st.session_state.user_name}?{new_id}",
                system_prompt=row["system_prompt"],
            )
        for row_index in st.session_state.session_editor["deleted_rows"]:
            removed_id = st.session_state.current_sessions[row_index]["session_id"]
            st.session_state.session_manager.remove_session(
                session_id=f"{st.session_state.user_name}?{removed_id}",
            )
        refresh_sessions()
    except Exception as e:
        # Best-effort UI callback: report the problem rather than crash.
        sleep(2)
        st.error(f"{type(e)}: {str(e)}")
    finally:
        # Drop the pending edits whether or not they were applied.
        st.session_state.session_editor["added_rows"] = []
        st.session_state.session_editor["deleted_rows"] = []
    refresh_agent()
def build_session_manager():
    """Return a ``SessionManager`` built from the session state and the
    ``MYSCALE_*`` connection constants."""
    connection = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(st.session_state, **connection)
def refresh_sessions():
    """Reload the user's sessions and re-select the active one.

    Stores the result of ``session_manager.list_sessions`` under
    ``current_sessions``. When the result is an empty non-dict collection, a
    ``"<user_name>?default"`` session with ``DEFAULT_SYSTEM_PROMPT`` is created
    and the list is loaded again.

    ``sel_sess`` is then pointed at the entry whose ``session_id`` matches the
    previously selected session (or ``"default"`` when nothing was selected),
    falling back to the first entry when no match exists.
    """
    manager = st.session_state.session_manager
    user_name = st.session_state.user_name

    st.session_state["current_sessions"] = manager.list_sessions(user_name)
    sessions = st.session_state.current_sessions
    if not isinstance(sessions, dict) and len(sessions) <= 0:
        manager.add_session(
            user_name,
            f"{user_name}?default",
            DEFAULT_SYSTEM_PROMPT,
        )
        st.session_state["current_sessions"] = manager.list_sessions(user_name)

    # The original tested the key "" and read `sel_session`; both were typos
    # for `sel_sess`, so the previous selection was never honoured.
    previous = st.session_state.sel_sess if "sel_sess" in st.session_state else None
    wanted_id = previous["session_id"] if previous else "default"
    try:
        dfl_indx = [
            x["session_id"] for x in st.session_state.current_sessions
        ].index(wanted_id)
    except ValueError:
        # Selected session no longer exists (e.g. it was just deleted).
        dfl_indx = 0
    st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
def refresh_agent():
    """Rebuild the agent for the selected session and store it as ``agent``.

    The agent is keyed by ``"<user_name>?<session_id>"``. Tools come from
    ``selected_tools`` when that key exists, otherwise a single hard-coded
    default; the system prompt comes from the selected session.
    """
    with st.spinner("Initializing session..."):
        session_key = f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}"
        print("??? Changed to ", session_key)

        if "selected_tools" in st.session_state:
            tool_names = st.session_state.selected_tools
        else:
            # NOTE(review): differs from the multiselect default used in
            # chat_page ("Wikipedia + Self Querying") — confirm the name.
            tool_names = ["LangChain Self Query Retriever For Wikipedia"]

        if "sel_sess" in st.session_state:
            prompt = st.session_state.sel_sess["system_prompt"]
        else:
            prompt = DEFAULT_SYSTEM_PROMPT

        st.session_state["agent"] = build_agents(
            session_key,
            tool_names,
            system_prompt=prompt,
        )
def add_file():
    """Add the files from the ``uploaded_files`` widget to the private knowledge base.

    Progress and errors are reported through the ``tool_status`` placeholder.
    When nothing has been uploaded a warning is shown and the function returns.
    A ``ValueError`` raised by ``private_kb.add_by_file`` is reported rather
    than propagated; other exceptions are not caught here.
    """
    if 'uploaded_files' not in st.session_state or not st.session_state.uploaded_files:
        # The icon literal was mis-encoded ("âš ï¸"); Streamlit requires a
        # valid emoji here.
        st.session_state.tool_status.error("Please upload files!", icon="⚠️")
        sleep(2)
        return
    try:
        st.session_state.tool_status.info("Uploading...")
        print([(f.name, f.type) for f in st.session_state.uploaded_files])
        st.session_state.private_kb.add_by_file(
            st.session_state.user_name,
            st.session_state.uploaded_files,
        )
    except ValueError as e:
        st.session_state.tool_status.error("Failed to upload! " + str(e))
        sleep(2)
def clear_files():
    """Ask the private knowledge base to clear the current user's data."""
    knowledge_base = st.session_state.private_kb
    knowledge_base.clear(st.session_state.user_name)
def _render_sidebar():
    """Draw the sidebar: session editor, session picker, tool settings, action buttons."""
    with st.sidebar:
        with st.expander("Session Management"):
            if "current_sessions" not in st.session_state:
                refresh_sessions()
            st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
                    icon="🤖")
            st.info(("**Add columns by clicking the empty row**.\n"
                     "And **delete columns by selecting rows with a press on `DEL` Key**"),
                    icon="💡")
            st.info("Don't forget to **click `Submit Change` to save your change**!", icon="📢")
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            st.info("Here you can select your session!", icon="🤖")
            st.info("If no session is attach to your account, then we will add a default session to you!", icon="❤️")
            try:
                # The original tested the key "" and read `sel_session`; both
                # were typos for `sel_sess`.
                previous = st.session_state.sel_sess if "sel_sess" in st.session_state else None
                wanted_id = previous["session_id"] if previous else "default"
                dfl_indx = [
                    x["session_id"] for x in st.session_state.current_sessions
                ].index(wanted_id)
            except Exception as e:
                # Best effort: fall back to the first session.
                print("*** ", str(e))
                dfl_indx = 0
            st.selectbox(
                "Choose a session to chat:",
                options=st.session_state.current_sessions,
                index=dfl_indx,
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.info("Here you can select your tools.", icon="🔧")
            st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="👷‍♂️")
            # Placeholder that add_file() uses to report upload status.
            st.session_state["tool_status"] = st.empty()
            tab_kb, tab_file, tab_build = st.tabs(["Knowledge Bases", "File Upload", "KB Builder"])
            with tab_kb:
                st.multiselect(
                    "Select a Knowledge Base Tool",
                    st.session_state.tools.keys(),
                    default=["Wikipedia + Self Querying"],
                    key="selected_tools",
                    on_change=refresh_agent,
                )
            with tab_file:
                st.file_uploader("Upload files", key="uploaded_files", accept_multiple_files=True)
                st.markdown("### Uploaded Files")
                st.dataframe(st.session_state.private_kb.list_files(st.session_state.user_name))
                col_1, col_2 = st.columns(2)
                with col_1:
                    st.button("Add Files", on_click=add_file)
                with col_2:
                    st.button("Clear Files", on_click=clear_files)
            # The "KB Builder" tab is not implemented yet; disabled draft kept
            # from the original:
            # with tab_build:
            #     st.text_input("Give this knowledge base a description:")
            #     col_3, col_4 = st.columns(2)
            #     with col_3:
            #         st.button("Build Your KB!")
            #     with col_4:
            #         st.button("Delete Your KB")
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)


def _render_chat_history():
    """Replay the messages held in the agent's chat memory as chat bubbles."""
    for msg in st.session_state.agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="📖"):
                st.write(
                    f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
                )
                st.write("Retrieved from knowledge base:")
                try:
                    # SECURITY NOTE(review): eval() executes whatever text was
                    # stored as the function message content. Confirm that the
                    # content can only come from trusted tool output, or move
                    # to a safe parser.
                    st.dataframe(
                        pd.DataFrame.from_records(map(dict, eval(msg.content)))
                    )
                except Exception:
                    # Content is not a parsable list of records: show it raw.
                    # `Exception` (not a bare except) so BaseException-derived
                    # control-flow signals are not swallowed.
                    st.write(msg.content)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                st.write(
                    f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
                )
                st.write(f"{msg.content}")


def chat_page():
    """Render the chat page.

    Seeds missing session-state entries (``sel_sess``, ``private_kb``,
    ``session_manager``), draws the sidebar, builds the agent if none exists,
    replays the stored conversation, and finally creates the ``next_round``
    placeholder and the chat input whose submit callback is ``on_chat_submit``.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "private_kb" not in st.session_state:
        st.session_state["private_kb"] = PrivateKnowledgeBase(
            host=MYSCALE_HOST,
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
            embedding=st.session_state.embeddings['Wikipedia'],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()

    _render_sidebar()

    if 'agent' not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)

    _render_chat_history()

    # on_chat_submit() draws the next exchange inside this placeholder.
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")