yrobel-lima's picture
Update app.py
4320b9c verified
raw
history blame
No virus
4.67 kB
import logging
from typing import Optional
import openai
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from rag.runnable_and_memory import get_runnable_and_memory
from utils.error_message_template import ERROR_MESSAGE
# Root logger emits only ERROR-level (and above) records.
logging.basicConfig(level=logging.ERROR)

# Streamlit page configuration.
# page_icon must be a real emoji character; the previous value was mojibake
# (the UTF-8 bytes of the emoji decoded as a legacy code page).
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# CSS: inject the app's custom stylesheet into the page.
# Read explicitly as UTF-8 so decoding does not depend on the host's locale
# default encoding (e.g. cp1252 on Windows).
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
def initialize_session_state():
    """Populate ``st.session_state`` with everything the chat app needs.

    If no ``"runnable"`` is stored yet, build the runnable and its memory via
    ``get_runnable_and_memory(model="gpt-4o", temperature=0)``, store both,
    and clear the memory; any exception in that step is passed to
    ``handle_errors``. Afterwards, seed the plain UI keys, leaving any key
    that already exists untouched.
    """
    state = st.session_state

    if "runnable" not in state:
        try:
            runnable, memory = get_runnable_and_memory(model="gpt-4o", temperature=0)
            state["runnable"] = runnable
            state["memory"] = memory
            # Start every new session with an empty memory.
            memory.clear()
        except Exception:
            handle_errors()

    # (key, initial value) pairs, applied in order and only when missing.
    # The chat input starts disabled because no location is selected yet.
    for key, initial in (
        ("chat_history", []),
        ("selected_location", None),
        ("disable_chat_input", True),
    ):
        if key not in state:
            state[key] = initial
@st.cache_resource
def load_avatars():
    """Return the chat avatar image paths keyed by chat role name.

    Wrapped in ``st.cache_resource`` so the dict is built once and then
    served from Streamlit's resource cache.
    """
    avatar_paths = dict(
        Human="images/user.png",
        AI="images/tall-tree-logo.png",
    )
    return avatar_paths
def on_change_location():
    """Callback for the location radio: disable chat input until a location is chosen.

    Sets ``st.session_state["disable_chat_input"]`` to True while
    ``st.session_state["selected_location"]`` is falsy (e.g. None when the
    radio has no selection) and to False otherwise.
    """
    st.session_state["disable_chat_input"] = not st.session_state["selected_location"]
def app_layout():
    """Render the welcome text and the location picker.

    The radio's value is stored under ``st.session_state["selected_location"]``
    (via ``key``), and ``on_change_location`` runs whenever the selection
    changes.
    """
    with st.container():
        # Welcome message
        st.markdown(
            "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
            "To get started, please select your preferred location and share details about your symptoms or needs. "
        )
        # Radio buttons for location preference. index=None means nothing is
        # preselected, so selected_location stays None until the user picks one.
        st.radio(
            "**Our Locations**:",
            (
                "Cordova Bay - Victoria",
                "James Bay - Victoria",
                "Commercial Drive - Vancouver",
            ),
            index=None,
            label_visibility="visible",
            key="selected_location",
            on_change=on_change_location,
        )
        st.markdown("<br>", unsafe_allow_html=True)
def handle_errors(error: Optional[Exception] = None):
    """Warn the user, log the failure, and stop the current script run.

    Args:
        error: Optional exception whose text is shown to the user instead of
            the generic ``ERROR_MESSAGE``.

    The exception behind the failure is logged with its traceback so it is
    not silently swallowed. That exception is ``error`` if given; otherwise
    it is the one currently being handled, when this is called from an
    ``except`` block. ``st.stop()`` then ends the script run; callers rely on
    this not returning normally.
    """
    import sys

    exc = error if error is not None else sys.exc_info()[1]
    if exc is not None:
        logging.error("ELLA request failed", exc_info=exc)

    # Streamlit validates `icon`; it must be a single real emoji character.
    st.warning(error if error is not None else ERROR_MESSAGE, icon="🙏")
    st.stop()
# Chat app logic
if __name__ == "__main__":
    initialize_session_state()
    app_layout()

    # Redraw the conversation stored in session state.
    avatars = load_avatars()
    for message in st.session_state["chat_history"]:
        if isinstance(message, AIMessage):
            role = "AI"
        elif isinstance(message, HumanMessage):
            role = "Human"
        else:
            continue
        with st.chat_message(role, avatar=avatars[role]):
            st.write(message.content)

    # Get user input only if a location is selected (the input is disabled
    # while st.session_state["disable_chat_input"] is True).
    user_input = st.chat_input(
        "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
    )

    # Chat interface
    if user_input and user_input.strip():
        st.session_state["chat_history"].append(HumanMessage(content=user_input))

        # Append the location to the user input (important!): the runnable is
        # only given this string, so the location must travel inside it.
        user_query_location = (
            f"{user_input}\nLocation: {st.session_state['selected_location']}."
        )

        # Render the user input
        with st.chat_message("Human", avatar=avatars["Human"]):
            st.write(user_input)

        # Render the AI response, streaming it as it is generated.
        with st.chat_message("AI", avatar=avatars["AI"]):
            try:
                with st.spinner(" "):
                    response = st.write_stream(
                        st.session_state["runnable"].stream(
                            {"user_query": user_query_location}
                        )
                    )
            except Exception:
                # openai.BadRequestError and all other failures get the same
                # treatment. handle_errors() stops the run, so `response` is
                # never read unbound below.
                handle_errors()

        # Add AI response to the displayed message history
        st.session_state["chat_history"].append(AIMessage(content=response))

        # Update runnable memory with the location-augmented query (what the
        # model saw) and its reply. Plain values are passed because
        # add_user_message/add_ai_message wrap them in the right message type
        # on both older (str-only) and newer langchain versions.
        memory = st.session_state["memory"]
        memory.chat_memory.add_user_message(user_query_location)
        memory.chat_memory.add_ai_message(response)