File size: 5,489 Bytes
4320b9c
 
 
 
 
 
3b5c559
4320b9c
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5c559
4320b9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5c559
4401ddc
3b5c559
 
 
4320b9c
 
3b5c559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4320b9c
3b5c559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import logging
from typing import Optional

import openai
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from openai import OpenAI

from rag.runnable_and_memory import get_runnable_and_memory
from utils.error_message_template import ERROR_MESSAGE

logging.basicConfig(level=logging.ERROR)

# Streamlit page configuration (title, tab icon, layout, collapsed sidebar).
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration: inject the project stylesheet into the page.
# The explicit encoding makes the read independent of the platform's
# locale-dependent default (e.g. cp1252 on Windows).
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)


def initialize_session_state():
    """Populate ``st.session_state`` with the keys the chat app relies on.

    Each key is initialised only when it is missing, so values survive
    Streamlit reruns.

    - ``runnable`` / ``memory``: built by ``get_runnable_and_memory`` with
      ``model="gpt-4o"`` and ``temperature=0``; the memory is cleared first.
    - ``chat_history``: list of ``HumanMessage`` / ``AIMessage`` objects that
      the UI re-renders.
    - ``selected_location``: bound to the location radio; ``None`` until chosen.
    - ``disable_chat_input``: ``True`` until a location has been selected.

    If building or clearing the runnable/memory fails, the exception is logged
    and ``handle_errors`` shows a generic warning and stops the script run.
    """
    if "runnable" not in st.session_state:
        try:
            runnable, memory = get_runnable_and_memory(model="gpt-4o", temperature=0)
            # Start every new session with an empty conversation memory.
            memory.clear()
        except Exception:
            logging.exception("Failed to initialize the runnable and memory")
            handle_errors()
        else:
            # Commit only after full success: a failure in clear() must not
            # leave a stored runnable that later reruns would skip re-creating.
            st.session_state["runnable"] = runnable
            st.session_state["memory"] = memory

    # Other session state variables; a fresh list is created per call.
    defaults = {
        "chat_history": [],
        "selected_location": None,
        "disable_chat_input": True,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


@st.cache_resource
def load_avatars():
    """Map each chat role name to the image file used as its avatar.

    The keys ("Human", "AI") are the role names passed to ``st.chat_message``.
    Streamlit caches the returned mapping as a shared resource.
    """
    avatar_paths = dict(
        Human="images/user.png",
        AI="images/tall-tree-logo.png",
    )
    return avatar_paths


# Disable chat input if no location is selected
def on_change_location():
    st.session_state["disable_chat_input"] = (
        False if st.session_state["selected_location"] else True
    )


def app_layout():
    """Render the welcome text and the location picker.

    The radio is bound to ``st.session_state["selected_location"]`` and calls
    ``on_change_location`` whenever the selection changes; that callback
    decides whether the chat input is enabled.
    """
    locations = (
        "Cordova Bay - Victoria",
        "James Bay - Victoria",
        "Commercial Drive - Vancouver",
    )

    with st.container():
        # Welcome message
        st.markdown(
            "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
            "To get started, please select your preferred location and share details about your symptoms or needs. "
        )

        # Radio buttons for the location preference. index=None means nothing
        # is pre-selected, so the user must choose a location explicitly.
        st.radio(
            "**Our Locations**:",
            locations,
            index=None,
            label_visibility="visible",
            key="selected_location",
            on_change=on_change_location,
        )
        st.markdown("<br>", unsafe_allow_html=True)


def handle_errors(error: Optional[Exception] = None):
    """Show a warning in the app and halt the current script run.

    Args:
        error: Exception whose text should be displayed. When ``None`` (all
            call sites in this script pass nothing), the generic
            ``ERROR_MESSAGE`` is shown instead.

    ``st.stop()`` halts script execution, so callers should not expect this
    function to return.
    """
    message = ERROR_MESSAGE if error is None else str(error)
    st.warning(message, icon="🙏")
    st.stop()


# Chat app logic. Streamlit reruns this script on each interaction, so the
# conversation is re-rendered from st.session_state["chat_history"] every run.
if __name__ == "__main__":
    initialize_session_state()
    app_layout()

    # Render conversation
    avatars = load_avatars()
    for message in st.session_state["chat_history"]:
        if isinstance(message, AIMessage):
            with st.chat_message("AI", avatar=avatars["AI"]):
                st.write(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("Human", avatar=avatars["Human"]):
                st.write(message.content)

    # Chat input stays disabled until a location is selected
    # (disable_chat_input is toggled by on_change_location).
    user_input = st.chat_input(
        "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
    )

    # Chat interface: ignore empty / whitespace-only submissions.
    if user_input and user_input.strip():
        # Screen the raw user text with the OpenAI moderation endpoint before
        # it reaches the chain. A failed moderation request is reported like
        # any other backend error instead of surfacing a raw traceback.
        try:
            moderation = OpenAI().moderations.create(
                model="omni-moderation-latest",
                input=user_input,
            )
            flagged = moderation.results[0].flagged
        except Exception:
            logging.exception("Moderation request failed")
            handle_errors()

        if flagged:
            response = """Sorry, I can't process your message because it doesn't follow our content guidelines. Please revise it and try again.
            If you need assistance, please contact our clinic directly [here](https://www.talltreehealth.ca/contact-us).
            """

            # Flagged messages are shown a warning and are not stored in
            # chat_history or memory.
            with st.chat_message("AI", avatar=avatars["AI"]):
                st.warning(response, icon="⚠️")

        else:
            st.session_state["chat_history"].append(HumanMessage(content=user_input))
            # Append the location to the user input (important!): the chain
            # is invoked with this combined text, not with the radio value.
            user_query_location = (
                f"{user_input}\nLocation: {st.session_state['selected_location']}."
            )

            # Render the user input
            with st.chat_message("Human", avatar=avatars["Human"]):
                st.write(user_input)

            # Render the AI response while streaming it from the chain.
            # NOTE(review): st.write_stream returns a str for text-only
            # streams; assumes the chain yields text chunks — confirm.
            with st.chat_message("AI", avatar=avatars["AI"]):
                try:
                    with st.spinner(" "):
                        response = st.write_stream(
                            st.session_state["runnable"].stream(
                                {"user_query": user_query_location}
                            )
                        )
                except openai.BadRequestError:
                    logging.exception("OpenAI rejected the chat request")
                    handle_errors()
                except Exception:
                    logging.exception("Failed to generate a chat response")
                    handle_errors()

            # Add AI response to the message history (shown in the UI).
            st.session_state["chat_history"].append(AIMessage(content=response))

            # Update runnable memory. Memory keeps the location-augmented
            # query, while chat_history keeps the raw user text for display.
            chat_memory = st.session_state["memory"].chat_memory
            chat_memory.add_user_message(HumanMessage(content=user_query_location))
            chat_memory.add_ai_message(AIMessage(content=response))