File size: 6,225 Bytes
a4db582
 
 
9181031
 
 
a4db582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9181031
 
a4db582
9181031
a4db582
 
 
 
 
9181031
a4db582
 
9181031
 
 
 
 
a4db582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9181031
a4db582
 
 
 
 
 
9181031
a4db582
 
 
 
 
 
 
9181031
 
 
 
 
 
 
 
a4db582
 
6cc96e7
9181031
 
 
 
 
 
 
 
 
a4db582
 
 
 
 
 
 
9181031
 
 
 
 
 
 
a4db582
9181031
a4db582
 
9181031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
from langchain_core.tracers.context import collect_runs
from langsmith import Client
from streamlit_feedback import streamlit_feedback

from rag.runnable import get_runnable
from utils.error_message_template import ERROR_MESSAGE

# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration.
# Explicit encoding avoids platform-dependent defaults (PEP 597) in case the
# stylesheet contains non-ASCII characters.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)


# Build the RAG chain and its conversation memory once and cache the result.
@st.cache_resource(show_spinner=False)
def get_runnable_and_memory():
    """Return the cached (chain, memory) pair used to answer user queries.

    On any start-up failure a friendly banner is shown and the script run
    is aborted instead of raising.

    NOTE(review): st.cache_resource shares one object across all sessions,
    so this memory appears to be shared between users; the session-start
    memory.clear() below seems to rely on that — confirm intended.
    """
    try:
        runnable_and_memory = get_runnable(model="gpt-4-turbo", temperature=0)
    except Exception:
        st.warning(ERROR_MESSAGE, icon="🙁")
        st.stop()
    else:
        return runnable_and_memory


chain, memory = get_runnable_and_memory()


# Set up session state variables.
# Clean memory (important! to clean the memory at the end of each session).
if "history" not in st.session_state:
    st.session_state["history"] = []
    memory.clear()

# Remaining per-session defaults, created only on the first script run.
for _key, _default in (
    ("messages", []),
    ("selected_location", None),
    ("disable_chat_input", True),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default


# Welcome message shown above the location picker.
def welcome_message():
    """Render the greeting that asks the user to pick a location first."""
    greeting = (
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
    st.markdown(greeting)


def on_change_location():
    """Radio-button callback: enable the chat input once a location is set.

    A falsy selection (None) keeps the input disabled; `not x` replaces the
    redundant `False if x else True` conditional.
    """
    st.session_state["disable_chat_input"] = not st.session_state["selected_location"]


# Greeting plus the clinic-location picker; chat stays disabled until a pick.
with st.container():
    welcome_message()
    clinic_options = (
        "Cordova Bay - Victoria",
        "James Bay - Victoria",
        "Commercial Drive - Vancouver",
    )
    location = st.radio(
        "**Our Locations**:",
        clinic_options,
        index=None,
        label_visibility="visible",
        key="selected_location",
        on_change=on_change_location,
    )
    st.markdown("<br>", unsafe_allow_html=True)

# Get user input only if a location is selected (input disabled otherwise).
user_input = st.chat_input(
    "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
)

# The prompt sent to the chain is the raw question plus the chosen location.
prompt = None
if user_input:
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."

# Replay the conversation so far, picking the avatar from the message role.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for msg in st.session_state["messages"]:
    is_user = msg.role == "user"
    with st.chat_message(msg.role, avatar=user_avatar if is_user else ai_avatar):
        st.markdown(msg.content)

# Chat interface
if prompt:
    # Rebuild the chain's memory from this session's history so the model
    # sees prior turns (memory is cleared at session start).
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's streamed response.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()

        try:
            partial_message = ""
            # Collect runs for feedback using Langsmith.
            with st.spinner(" "), collect_runs() as cb:
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing "|" acts as a typing cursor while streaming.
                    message_placeholder.markdown(partial_message + "|")
                st.session_state.run_id = cb.traced_runs[0].id
            message_placeholder.markdown(partial_message)
        except Exception:
            # The previous separate openai.BadRequestError clause ran this
            # exact same handler; one broad clause shows the friendly error
            # banner and aborts this script run for any failure.
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()

        # Add the full response to the history
        st.session_state["history"].append((prompt, partial_message))

        # Add AI message to memory after the response is generated
        memory.chat_memory.add_ai_message(AIMessage(content=partial_message))

        # Add the full response to the message history
        st.session_state["messages"].append(
            ChatMessage(role="assistant", content=partial_message)
        )


# Feedback system using streamlit-feedback and Langsmith.

# Langsmith client for the feedback system.
ls_client = Client()

# Widget style for the feedback control.
feedback_option = "thumbs"

if st.session_state.get("run_id"):
    run_id = st.session_state.run_id

    # Emoji -> numeric score lookup for each supported widget style.
    score_mappings = {
        "thumbs": {"👍": 1, "👎": 0},
        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
    }
    scores = score_mappings[feedback_option]

    # Keyed per run so a new response gets a fresh feedback widget.
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback_{run_id}",
    )

    if feedback:
        score = scores.get(feedback["score"])
        if score is None:
            st.warning("Invalid feedback score.")
        else:
            # e.g. "thumbs 👍" — widget style plus the emoji the user picked.
            feedback_type_str = f"{feedback_option} {feedback['score']}"

            # Record the feedback against the traced run in Langsmith.
            feedback_record = ls_client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
            }