# app.py - uploaded by yrobel-lima ("Upload app.py", commit 6cc96e7, verified; 4.82 kB).
import logging
from concurrent.futures import ThreadPoolExecutor
import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
from rag.runnable import get_runnable
from utils.error_message_template import ERROR_MESSAGE
logging.basicConfig(level=logging.ERROR)

# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration. The stylesheet is read as UTF-8 explicitly so
# the result does not depend on the platform's default text encoding.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
# Get runnable and memory
def initialize_runnable_and_memory():
    """Build the RAG runnable and its conversation memory.

    Returns:
        The result of ``get_runnable(model="gpt-4o", temperature=0)``, which
        the script unpacks as a ``(runnable, memory)`` pair.

    On failure the exception is logged, then the user-facing error message is
    shown via ``st.warning`` and ``st.stop`` is called.
    """
    try:
        return get_runnable(model="gpt-4o", temperature=0)
    except Exception:
        logging.exception("Failed to build the runnable and memory")
        # NOTE(review): this function is submitted to a ThreadPoolExecutor, so
        # these Streamlit calls run off the script thread and may not render
        # or stop the run - confirm; the caller should handle a failed result.
        st.warning(ERROR_MESSAGE, icon="🙏")
        st.stop()
# Get the ThreadPoolExecutor. It is kept in session_state so the same executor
# (and the future below) survive Streamlit reruns.
if "executor" not in st.session_state:
    st.session_state.executor = ThreadPoolExecutor(max_workers=4)
executor = st.session_state.executor

# Submit initialization task if not already done (once per session).
if "initialization_future" not in st.session_state:
    st.session_state["initialization_future"] = executor.submit(
        initialize_runnable_and_memory
    )

# Check if initialization is complete; if so, move its result into the session.
future = st.session_state["initialization_future"]
if future.done() and "runnable" not in st.session_state:
    try:
        runnable_and_memory = future.result()
    except Exception:
        logging.exception("Failed to initialize the runnable")
        runnable_and_memory = None
    # The initializer runs in a worker thread, where its own st.warning /
    # st.stop likely cannot reach the page (Streamlit commands need the script
    # thread's context). Depending on the Streamlit version it may therefore
    # return None instead of a (runnable, memory) pair, so report the failure
    # here on the script thread instead of crashing on tuple unpacking.
    if runnable_and_memory is None:
        st.warning(ERROR_MESSAGE, icon="🙏")
        st.stop()
    else:
        st.session_state["runnable"], st.session_state["memory"] = runnable_and_memory
        # Start every new session with an empty conversation memory.
        st.session_state["memory"].clear()

# Other session state variables
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "selected_location" not in st.session_state:
    st.session_state["selected_location"] = None
if "disable_chat_input" not in st.session_state:
    # Chat stays disabled until a location is picked.
    st.session_state["disable_chat_input"] = True
# Welcome message shown above the location picker
def welcome_message():
    """Render the greeting and the instructions for getting started."""
    st.markdown(
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
def on_change_location():
    """Radio callback: disable the chat input while no location is selected."""
    chosen_location = st.session_state["selected_location"]
    st.session_state["disable_chat_input"] = not chosen_location
# Greeting plus location picker. The radio starts with nothing selected
# (index=None); choosing a location triggers on_change_location.
locations = (
    "Cordova Bay - Victoria",
    "James Bay - Victoria",
    "Commercial Drive - Vancouver",
)
with st.container():
    welcome_message()
    location = st.radio(
        "**Our Locations**:",
        locations,
        key="selected_location",
        index=None,
        on_change=on_change_location,
        label_visibility="visible",
    )
    st.markdown("<br>", unsafe_allow_html=True)
# The chat box is enabled only after a location is chosen; whitespace-only
# submissions produce no prompt.
user_input = st.chat_input(
    "Ask ELLA...",
    disabled=st.session_state["disable_chat_input"],
)
prompt = None
if user_input and user_input.strip():
    st.session_state["messages"].append(
        ChatMessage(role="user", content=user_input)
    )
    location_pref = st.session_state["selected_location"]
    prompt = f"{user_input}\nLocation preference: {location_pref}."
# Replay the stored conversation on every rerun: user turns get the user
# avatar, all other roles the AI avatar.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for message in st.session_state["messages"]:
    is_user_turn = message.role == "user"
    with st.chat_message(
        message.role, avatar=user_avatar if is_user_turn else ai_avatar
    ):
        st.markdown(message.content)
# Chat interface. The runnable is built in the background, so a question can
# arrive before it is ready; wait for initialization instead of silently
# dropping the prompt (it has already been added to the message history).
if prompt and "runnable" not in st.session_state:
    try:
        with st.spinner(" "):
            initialized = future.result()
    except Exception:
        logging.exception("Failed to initialize the runnable")
        initialized = None
    if initialized is None:
        st.warning(ERROR_MESSAGE, icon="🙏")
        st.stop()
    else:
        st.session_state["runnable"], st.session_state["memory"] = initialized
        # Start every new session with an empty conversation memory.
        st.session_state["memory"].clear()

if "runnable" in st.session_state and prompt:
    # Render the assistant's response, streaming it chunk by chunk with a
    # trailing "|" as a typing cursor.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()
        try:
            response = ""
            with st.spinner(" "):
                for chunk in st.session_state["runnable"].stream({"message": prompt}):
                    response += chunk
                    message_placeholder.markdown(response + "|")
            message_placeholder.markdown(response)
        except Exception:
            # openai.BadRequestError and every other failure were handled
            # identically; one handler keeps that behaviour and logs details.
            logging.exception("Failed to generate a response")
            st.warning(ERROR_MESSAGE, icon="🙏")
            st.stop()

    # Add response to the message history
    st.session_state["messages"].append(
        ChatMessage(role="assistant", content=response)
    )
    # Add messages to memory (the stored user turn is the location-augmented prompt)
    st.session_state["memory"].chat_memory.add_user_message(
        HumanMessage(content=prompt)
    )
    st.session_state["memory"].chat_memory.add_ai_message(
        AIMessage(content=response)
    )
# Release the thread pool. shutdown(wait=False) returns immediately and, with
# the default cancel_futures=False, does not cancel the initialization task
# already submitted, so that task still runs to completion. On later reruns the
# executor is already shut down; calling shutdown again does not raise.
executor = st.session_state.get("executor")
if executor is not None:
    executor.shutdown(wait=False)