import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
from langchain_core.tracers.context import collect_runs
from langsmith import Client
from streamlit_feedback import streamlit_feedback
from rag.runnable import get_runnable
from utils.error_message_template import ERROR_MESSAGE
# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration: inject the project stylesheet into the page.
# The file contents must actually be read and wrapped in a <style> tag;
# an empty string would render nothing and the stylesheet would be ignored.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
# Build the chain and its conversation memory once and reuse them (st.cache_resource).
@st.cache_resource(show_spinner=False)
def get_runnable_and_memory():
    """Return the ``(chain, memory)`` pair built by ``get_runnable``.

    Uses the ``gpt-4o`` model at temperature 0. If construction fails for
    any reason, a generic error banner is shown and ``st.stop()`` is called
    to halt the script run.
    """
    try:
        chain_and_memory = get_runnable(model="gpt-4o", temperature=0)
    except Exception:
        st.warning(ERROR_MESSAGE, icon="🙁")
        st.stop()
    else:
        return chain_and_memory


chain, memory = get_runnable_and_memory()
# Set up per-session state on the first run of a browser session.
if "history" not in st.session_state:
    st.session_state["history"] = []
    # `memory` is a cached resource, so it outlives individual sessions;
    # wipe it whenever a new session starts (important!).
    memory.clear()

# Remaining defaults, applied only when the key has not been set yet.
for _state_key, _state_default in (
    ("messages", []),
    ("selected_location", None),
    ("disable_chat_input", True),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Welcome message and radio buttons for location preferences
def welcome_message():
    """Render the introductory greeting shown above the location picker."""
    greeting = (
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
    st.markdown(greeting)
def on_change_location():
    """Enable the chat input once a location is chosen, disable it otherwise.

    Registered as the ``on_change`` callback of the location radio, whose
    value is stored in ``st.session_state["selected_location"]``.
    """
    st.session_state["disable_chat_input"] = not st.session_state["selected_location"]
with st.container():
    welcome_message()
    # The radio's value is kept in st.session_state["selected_location"]
    # (via `key`); on_change_location toggles the chat input from it.
    location = st.radio(
        "**Our Locations**:",
        (
            "Cordova Bay - Victoria",
            "James Bay - Victoria",
            "Commercial Drive - Vancouver",
        ),
        index=None,
        label_visibility="visible",
        key="selected_location",
        on_change=on_change_location,
    )
    # Vertical spacing below the location picker. The literal must stay on
    # one line: a plain "..." string cannot span lines in Python.
    st.markdown("<br>", unsafe_allow_html=True)
# Get user input; the box stays disabled until a location is selected.
user_input = st.chat_input(
    "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
)
prompt = None
if user_input:
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    # The text sent to the chain carries the chosen location as a suffix.
    chosen_location = st.session_state["selected_location"]
    prompt = f"{user_input}\nLocation preference: {chosen_location}."
# Render every message stored for this session, with a per-role avatar.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
_avatar_by_role = {"user": user_avatar}
for past_message in st.session_state["messages"]:
    role_avatar = _avatar_by_role.get(past_message.role, ai_avatar)
    with st.chat_message(past_message.role, avatar=role_avatar):
        st.markdown(past_message.content)
# Chat interface: answer the newest prompt, if there is one.
if prompt:
    # `memory` is a cached resource that survives reruns (and is shared by
    # sessions), so replaying the history on top of its current contents
    # duplicated earlier exchanges every turn. Rebuild it from this
    # session's history instead.
    memory.clear()
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's response, streaming chunks into one placeholder.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()
        partial_message = ""
        try:
            # Collect runs for feedback using LangSmith.
            with st.spinner(" "), collect_runs() as cb:
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing "|" acts as a cursor while the answer streams.
                    message_placeholder.markdown(partial_message + "|")
                # Avoid an IndexError (which would discard a successful
                # answer) and avoid keeping a stale run id from a prior turn.
                st.session_state.run_id = (
                    cb.traced_runs[0].id if cb.traced_runs else None
                )
            message_placeholder.markdown(partial_message)
        except Exception:
            # Also covers openai.BadRequestError, which was handled identically.
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()

    # Add the full response to the history
    st.session_state["history"].append((prompt, partial_message))
    # Add AI message to memory after the response is generated
    memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
    # Add the full response to the message history
    st.session_state["messages"].append(
        ChatMessage(role="assistant", content=partial_message)
    )
# Feedback system using streamlit-feedback and LangSmith.
# LangSmith client used to record feedback against traced runs.
ls_client = Client()
# Feedback widget style; must be a key of score_mappings.
feedback_option = "thumbs"
# Map each widget choice to the numeric score sent to LangSmith.
score_mappings = {
    "thumbs": {"👍": 1, "👎": 0},
    "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
}

if st.session_state.get("run_id"):
    run_id = st.session_state.run_id
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback_{run_id}",
    )
    # Get the score mapping based on the selected feedback option
    scores = score_mappings[feedback_option]
    # NOTE(review): the widget is keyed per run, so its submitted value may be
    # returned again on later reruns; only record feedback once per run.
    previous = st.session_state.get("feedback") or {}
    already_recorded = previous.get("run_id") == str(run_id)
    if feedback and not already_recorded:
        score = scores.get(feedback["score"])
        if score is not None:
            # Feedback key combines the option and the raw widget value,
            # e.g. "thumbs 👍".
            feedback_type_str = f"{feedback_option} {feedback['score']}"
            feedback_record = ls_client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
                "run_id": str(run_id),
            }
        else:
            st.warning("Invalid feedback score.")