File size: 4,824 Bytes
6cc96e7
 
 
a4db582
 
 
 
 
 
 
6cc96e7
 
a4db582
 
 
 
 
 
 
 
 
 
 
 
 
 
6cc96e7
a4db582
6cc96e7
a4db582
 
 
 
 
6cc96e7
 
 
a4db582
6cc96e7
 
 
 
 
 
 
a4db582
6cc96e7
 
 
 
 
a4db582
6cc96e7
a4db582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cc96e7
a4db582
 
 
 
 
 
6cc96e7
a4db582
 
 
 
 
 
 
6cc96e7
 
 
a4db582
 
 
6cc96e7
 
 
 
 
 
 
a4db582
 
 
 
 
 
 
6cc96e7
a4db582
6cc96e7
a4db582
 
6cc96e7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
from concurrent.futures import ThreadPoolExecutor

import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

from rag.runnable import get_runnable
from utils.error_message_template import ERROR_MESSAGE

# Only records at ERROR level or above are emitted.
logging.basicConfig(level=logging.ERROR)

# Streamlit page configuration. Kept ahead of every other st.* call, which
# older Streamlit versions require for set_page_config.
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Inject the app's custom stylesheet. The encoding is explicit so the read does
# not depend on the host's locale default.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)


# Get runnable and memory
def initialize_runnable_and_memory():
    """Build the chat runnable and its conversation memory.

    This is submitted to a ThreadPoolExecutor. Worker threads have no Streamlit
    script context, so Streamlit commands such as st.warning/st.stop would not
    reach the page from here. Failures are therefore logged and re-raised; the
    script thread sees them when it calls ``future.result()`` and can report
    them to the user.

    Returns:
        Whatever ``get_runnable`` returns. The rest of this script unpacks it
        as a ``(runnable, memory)`` pair.

    Raises:
        Exception: any error raised by ``get_runnable``, re-raised after logging.
    """
    try:
        return get_runnable(model="gpt-4o", temperature=0)
    except Exception:
        logging.exception("Failed to initialize the runnable")
        raise


# The executor is stored in session state so the same instance (and the future
# submitted to it) survives Streamlit reruns of this script.
if "executor" not in st.session_state:
    st.session_state.executor = ThreadPoolExecutor(max_workers=4)

executor = st.session_state.executor

# Submit the initialization task only once per session.
if "initialization_future" not in st.session_state:
    st.session_state["initialization_future"] = executor.submit(
        initialize_runnable_and_memory
    )

# Pick up the result once the background task has finished. Failures are
# handled here, on the script thread, where st.warning/st.stop can reach the
# page. A task that produced no (runnable, memory) pair (e.g. returned None)
# fails the unpacking and is reported the same way.
future = st.session_state["initialization_future"]
if future.done() and "runnable" not in st.session_state:
    try:
        initialized_runnable, initialized_memory = future.result()
    except Exception:
        logging.exception("Runnable initialization failed")
        st.warning(ERROR_MESSAGE, icon="🙏")
        st.stop()
    st.session_state["runnable"] = initialized_runnable
    st.session_state["memory"] = initialized_memory
    st.session_state["memory"].clear()

# Other session state variables
if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "selected_location" not in st.session_state:
    st.session_state["selected_location"] = None

# The chat input starts disabled; on_change_location enables it once a
# location is picked.
if "disable_chat_input" not in st.session_state:
    st.session_state["disable_chat_input"] = True


# Greeting shown above the location picker.
def welcome_message():
    """Render the assistant's introductory greeting as Markdown."""
    st.markdown(
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )


def on_change_location():
    """Toggle the chat input according to the current location choice.

    Registered as the ``on_change`` callback of the location radio. The input
    is enabled when ``selected_location`` holds a truthy value and disabled
    otherwise.
    """
    has_location = bool(st.session_state["selected_location"])
    st.session_state["disable_chat_input"] = not has_location


# Greeting followed by the location radio. The radio is bound to the
# "selected_location" session key, and on_change_location re-evaluates
# whether the chat input should be enabled whenever the choice changes.
_LOCATION_OPTIONS = (
    "Cordova Bay - Victoria",
    "James Bay - Victoria",
    "Commercial Drive - Vancouver",
)

with st.container():
    welcome_message()
    location = st.radio(
        label="**Our Locations**:",
        options=_LOCATION_OPTIONS,
        index=None,
        key="selected_location",
        label_visibility="visible",
        on_change=on_change_location,
    )
    st.markdown("<br>", unsafe_allow_html=True)

# Read the next chat message. The box stays disabled until a location is chosen.
user_input = st.chat_input(
    "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
)

# Blank or whitespace-only submissions leave prompt as None. Otherwise the raw
# text goes into the visible history, while the model prompt also carries the
# selected location.
prompt = None
if user_input and user_input.strip():
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."

# Replay the stored conversation on every rerun. User turns get the user
# avatar; any other role gets the assistant logo.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for past_message in st.session_state["messages"]:
    if past_message.role == "user":
        bubble_avatar = user_avatar
    else:
        bubble_avatar = ai_avatar
    with st.chat_message(past_message.role, avatar=bubble_avatar):
        st.markdown(past_message.content)

# Chat interface. The runnable is built on a background thread, so a prompt can
# arrive before that work has finished. In that case wait for it here rather
# than leaving the user's message (already in the history) without a reply.
if prompt and "runnable" not in st.session_state:
    try:
        with st.spinner(" "):
            # A result that is not a (runnable, memory) pair fails this
            # unpacking and is reported like any other initialization error.
            pending_runnable, pending_memory = st.session_state[
                "initialization_future"
            ].result()
    except Exception:
        logging.exception("Runnable initialization failed")
        st.warning(ERROR_MESSAGE, icon="🙏")
        st.stop()
    st.session_state["runnable"] = pending_runnable
    st.session_state["memory"] = pending_memory
    st.session_state["memory"].clear()

if "runnable" in st.session_state and prompt:
    # Render the assistant's response, streaming chunks into one placeholder.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()
        response = ""
        try:
            with st.spinner(" "):
                for chunk in st.session_state["runnable"].stream({"message": prompt}):
                    response += chunk
                    # The trailing "|" is shown only while chunks are arriving.
                    message_placeholder.markdown(response + "|")

            message_placeholder.markdown(response)
        except openai.BadRequestError:
            # Same user-facing message as below, but logged distinctly so
            # rejected requests can be told apart from other failures.
            logging.exception("OpenAI rejected the request")
            st.warning(ERROR_MESSAGE, icon="🙏")
            st.stop()
        except Exception:
            logging.exception("Failed to generate a response")
            st.warning(ERROR_MESSAGE, icon="🙏")
            st.stop()

        # Add response to the message history shown in the UI.
        st.session_state["messages"].append(
            ChatMessage(role="assistant", content=response)
        )

        # Record the exchange in the runnable's memory. The stored user turn is
        # the full prompt, including the location preference line.
        chat_memory = st.session_state["memory"].chat_memory
        chat_memory.add_user_message(HumanMessage(content=prompt))
        chat_memory.add_ai_message(AIMessage(content=response))

# The executor only ever runs the single initialization task, and
# shutdown(wait=False) lets an already-submitted task run to completion. One
# call per session is therefore enough; the flag stops it being repeated on
# every rerun.
if not st.session_state.get("executor_shut_down", False):
    st.session_state.executor.shutdown(wait=False)
    st.session_state["executor_shut_down"] = True