# Hugging Face Spaces — Running on CPU Upgrade
import logging
from typing import NoReturn, Optional

import openai
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from openai import OpenAI

from rag.runnable_and_memory import get_runnable_and_memory
from utils.error_message_template import ERROR_MESSAGE
logging.basicConfig(level=logging.ERROR)

# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Inject the app's custom stylesheet. Read it as UTF-8 explicitly so the
# result does not depend on the host's default locale encoding.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
def initialize_session_state():
    """Seed ``st.session_state`` with every key the chat app reads.

    The runnable and its memory are built (``gpt-4o``, temperature 0) only
    when ``"runnable"`` is not yet in session state, and the memory is
    cleared before use. If either step fails, the exception is logged and
    ``handle_errors()`` shows a warning and stops the run.

    The remaining keys receive defaults only when they are missing:

    * ``chat_history`` -- ``[]``
    * ``selected_location`` -- ``None``
    * ``disable_chat_input`` -- ``True`` (chat stays off until a location
      is picked)
    """
    if "runnable" not in st.session_state:
        try:
            runnable, memory = get_runnable_and_memory(model="gpt-4o", temperature=0)
            # Clear the memory so every new session starts with an empty chat.
            memory.clear()
        except Exception:
            logging.exception("Failed to initialize the runnable and memory")
            handle_errors()
        else:
            # Commit both only after clear() succeeded; otherwise a rerun would
            # find "runnable" already set and never clear the memory.
            st.session_state["runnable"] = runnable
            st.session_state["memory"] = memory

    # Other session state variables
    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []
    if "selected_location" not in st.session_state:
        st.session_state["selected_location"] = None
    if "disable_chat_input" not in st.session_state:
        st.session_state["disable_chat_input"] = True
def load_avatars():
    """Return the avatar image path to use for each chat role.

    Returns:
        dict: A fresh mapping with the keys ``"Human"`` (user icon) and
        ``"AI"`` (clinic logo), in that order.
    """
    human_avatar = "images/user.png"
    ai_avatar = "images/tall-tree-logo.png"
    return {"Human": human_avatar, "AI": ai_avatar}
# Disable chat input if no location is selected
def on_change_location():
    """Callback for the location radio button.

    Sets ``st.session_state["disable_chat_input"]`` to ``True`` while no
    location is selected, and to ``False`` once one is.
    """
    st.session_state["disable_chat_input"] = not st.session_state["selected_location"]
def app_layout():
    """Render the welcome text and the location picker.

    The radio widget is bound to ``st.session_state["selected_location"]``
    through its ``key`` and calls ``on_change_location`` whenever the choice
    changes. ``index=None`` leaves every location unselected initially.
    """
    with st.container():
        # Welcome message
        st.markdown(
            "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
            "To get started, please select your preferred location and share details about your symptoms or needs. "
        )
        # Radio buttons for the location preference
        st.radio(
            "**Our Locations**:",
            (
                "Cordova Bay - Victoria",
                "James Bay - Victoria",
                "Commercial Drive - Vancouver",
            ),
            index=None,
            label_visibility="visible",
            key="selected_location",
            on_change=on_change_location,
        )
        st.markdown("<br>", unsafe_allow_html=True)
def handle_errors(error: Optional[Exception] = None) -> NoReturn:
    """Show a warning to the user and halt the current script run.

    Args:
        error: Exception whose text should be displayed. When omitted, the
            generic ``ERROR_MESSAGE`` is shown instead.

    This function does not return normally: it ends with ``st.stop()``.
    """
    message = str(error) if error is not None else ERROR_MESSAGE
    # NOTE(review): the original icon was lost to a text-encoding error (only
    # "π" survived), which is not a valid emoji for ``icon``; "🛑" is a valid
    # stand-in. Confirm the intended emoji.
    st.warning(message, icon="🛑")
    st.stop()
# Chat app logic


def _render_chat_history(avatars):
    """Draw every message in ``st.session_state["chat_history"]`` as a chat bubble.

    ``AIMessage`` entries use the "AI" role and avatar, and ``HumanMessage``
    entries use the "Human" role and avatar. Any other message type is
    skipped.
    """
    for message in st.session_state["chat_history"]:
        if isinstance(message, AIMessage):
            role = "AI"
        elif isinstance(message, HumanMessage):
            role = "Human"
        else:
            continue
        with st.chat_message(role, avatar=avatars[role]):
            st.write(message.content)


def _is_flagged_by_moderation(text):
    """Return whether OpenAI's ``omni-moderation-latest`` model flags *text*.

    A failed moderation request is logged and reported through
    ``handle_errors()``, which stops the run rather than letting the
    exception propagate.
    """
    try:
        moderation = OpenAI().moderations.create(
            model="omni-moderation-latest",
            input=text,
        )
    except Exception:
        logging.exception("OpenAI moderation request failed")
        handle_errors()
    return moderation.results[0].flagged


if __name__ == "__main__":
    initialize_session_state()
    app_layout()

    # Render conversation
    avatars = load_avatars()
    _render_chat_history(avatars)

    # Get user input only if a location is selected
    user_input = st.chat_input(
        "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
    )

    # Ignore empty and whitespace-only submissions
    if user_input and user_input.strip():
        if _is_flagged_by_moderation(user_input):
            response = (
                "Sorry, I can't process your message because it doesn't follow our content guidelines. Please revise it and try again.\n"
                "If you need assistance, please contact our clinic directly [here](https://www.talltreehealth.ca/contact-us).\n"
            )
            with st.chat_message("AI", avatar=avatars["AI"]):
                st.warning(response, icon="⚠️")
        else:
            st.session_state["chat_history"].append(HumanMessage(content=user_input))
            # Append the location to the user input (important!): the runnable
            # receives only "user_query", so this is how it learns the location.
            user_query_location = (
                f"{user_input}\nLocation: {st.session_state['selected_location']}."
            )

            # Render the user input
            with st.chat_message("Human", avatar=avatars["Human"]):
                st.write(user_input)

            # Stream the AI response into its own bubble
            with st.chat_message("AI", avatar=avatars["AI"]):
                try:
                    with st.spinner(" "):
                        response = st.write_stream(
                            st.session_state["runnable"].stream(
                                {"user_query": user_query_location}
                            )
                        )
                except Exception:
                    # Covers openai.BadRequestError as well as any other failure.
                    logging.exception("Failed to generate the AI response")
                    handle_errors()

            # The displayed history keeps the raw question; the runnable's memory
            # keeps the location-tagged query that was actually sent.
            st.session_state["chat_history"].append(AIMessage(content=response))
            chat_memory = st.session_state["memory"].chat_memory
            chat_memory.add_user_message(HumanMessage(content=user_query_location))
            chat_memory.add_ai_message(AIMessage(content=response))