yrobel-lima committed
Commit a4db582 • 1 Parent(s): 2b5588b

Upload app.py

Files changed (1)
  1. app.py +186 -186
app.py CHANGED
@@ -1,186 +1,186 @@
import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
from langchain_core.tracers.context import collect_runs
from langsmith import Client
from streamlit_feedback import streamlit_feedback

from rag.runnable import get_runnable
from utils.error_message_template import ERROR_MESSAGE

# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration
with open("styles/styles.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)


# Get runnable and memory
@st.cache_resource(show_spinner=False)
def get_runnable_and_memory():
    try:
        return get_runnable(model="gpt-4-turbo", temperature=0)
    except Exception:
        st.warning(ERROR_MESSAGE, icon="🙁")
        st.stop()


chain, memory = get_runnable_and_memory()


# Set up session state variables
# Clear the memory at the start of each new session (important: the cached
# memory persists across sessions)
if "history" not in st.session_state:
    st.session_state["history"] = []
    memory.clear()

if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "selected_location" not in st.session_state:
    st.session_state["selected_location"] = None

if "disable_chat_input" not in st.session_state:
    st.session_state["disable_chat_input"] = True


# Welcome message and radio buttons for location preferences
def welcome_message():
    st.markdown(
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )


def on_change_location():
    st.session_state["disable_chat_input"] = (
        False if st.session_state["selected_location"] else True
    )


with st.container():
    welcome_message()
    location = st.radio(
        "**Our Locations**:",
        (
            "Cordova Bay - Victoria",
            "James Bay - Victoria",
            "Commercial Drive - Vancouver",
        ),
        index=None,
        label_visibility="visible",
        key="selected_location",
        on_change=on_change_location,
    )
    st.markdown("<br>", unsafe_allow_html=True)

# Get user input only if a location is selected
user_input = st.chat_input(
    "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
)

if user_input:
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."

else:
    prompt = None

# Display previous messages
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for msg in st.session_state["messages"]:
    avatar = user_avatar if msg.role == "user" else ai_avatar
    with st.chat_message(msg.role, avatar=avatar):
        st.markdown(msg.content)

# Chat interface
if prompt:
    # Add all previous messages to memory
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's response
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()

        try:
            partial_message = ""
            # Collect runs for feedback using LangSmith.
            with st.spinner(" "), collect_runs() as cb:
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    message_placeholder.markdown(partial_message + "|")
                st.session_state.run_id = cb.traced_runs[0].id
            message_placeholder.markdown(partial_message)
        except openai.BadRequestError:
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()
        except Exception:
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()

    # Add the full response to the history
    st.session_state["history"].append((prompt, partial_message))

    # Add the AI message to memory after the response is generated
    memory.chat_memory.add_ai_message(AIMessage(content=partial_message))

    # Add the full response to the message history
    st.session_state["messages"].append(
        ChatMessage(role="assistant", content=partial_message)
    )


# Feedback system using streamlit-feedback and LangSmith

# LangSmith client for the feedback system
ls_client = Client()

# Feedback option
feedback_option = "thumbs"

if st.session_state.get("run_id"):
    run_id = st.session_state.run_id
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback_{run_id}",
    )
    score_mappings = {
        "thumbs": {"👍": 1, "👎": 0},
        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
    }

    # Get the score mapping based on the selected feedback option
    scores = score_mappings[feedback_option]

    if feedback:
        # Get the score from the selected feedback option's score mapping
        score = scores.get(feedback["score"])

        if score is not None:
            # Formulate a feedback type string incorporating the feedback option
            # and score value
            feedback_type_str = f"{feedback_option} {feedback['score']}"

            # Record the feedback with the formulated feedback type string
            feedback_record = ls_client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
            }
        else:
            st.warning("Invalid feedback score.")
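Note: app.py depends on two local modules that are not part of this commit. The first, rag.runnable.get_runnable, is expected to return a (runnable, memory) pair in which the runnable streams string chunks for {"message": ...} inputs and the memory exposes .clear() and .chat_memory. The sketch below only illustrates that contract under assumed names (ChatOpenAI, ConversationBufferMemory are not confirmed by this commit); the real module presumably also wires in retrieval over the clinic's service data.

# Hypothetical sketch of the interface app.py expects; not the committed rag/runnable.py.
from langchain.memory import ConversationBufferMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


def get_runnable(model: str = "gpt-4-turbo", temperature: float = 0):
    """Return (runnable, memory): the runnable streams string chunks for
    {"message": ...} inputs, and the memory exposes .clear() and .chat_memory
    as app.py uses them. Retrieval is omitted in this sketch."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are ELLA, an assistant for a health clinic."),
            ("human", "{message}"),
        ]
    )
    chain = prompt | ChatOpenAI(model=model, temperature=temperature) | StrOutputParser()
    memory = ConversationBufferMemory(return_messages=True)
    return chain, memory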
 
+ if user_input:
89
+ st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
90
+ prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."
91
+
92
+ else:
93
+ prompt = None
94
+
95
+ # Display previous messages
96
+ user_avatar = "images/user.png"
97
+ ai_avatar = "images/tall-tree-logo.png"
98
+ for msg in st.session_state["messages"]:
99
+ avatar = user_avatar if msg.role == "user" else ai_avatar
100
+ with st.chat_message(msg.role, avatar=avatar):
101
+ st.markdown(msg.content)
102
+
103
+ # Chat interface
104
+ if prompt:
105
+ # Add all previous messages to memory
106
+ for human, ai in st.session_state["history"]:
107
+ memory.chat_memory.add_user_message(HumanMessage(content=human))
108
+ memory.chat_memory.add_ai_message(AIMessage(content=ai))
109
+
110
+ # render the assistant's response
111
+ with st.chat_message("assistant", avatar=ai_avatar):
112
+ message_placeholder = st.empty()
113
+
114
+ try:
115
+ partial_message = ""
116
+ # Collect runs for feedback using Langsmith.
117
+ with st.spinner(" "), collect_runs() as cb:
118
+ for chunk in chain.stream({"message": prompt}):
119
+ partial_message += chunk
120
+ message_placeholder.markdown(partial_message + "|")
121
+ st.session_state.run_id = cb.traced_runs[0].id
122
+ message_placeholder.markdown(partial_message)
123
+ except openai.BadRequestError:
124
+ st.warning(ERROR_MESSAGE, icon="πŸ™")
125
+ st.stop()
126
+ except Exception:
127
+ st.warning(ERROR_MESSAGE, icon="πŸ™")
128
+ st.stop()
129
+
130
+ # Add the full response to the history
131
+ st.session_state["history"].append((prompt, partial_message))
132
+
133
+ # Add AI message to memory after the response is generated
134
+ memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
135
+
136
+ # Add the full response to the message history
137
+ st.session_state["messages"].append(
138
+ ChatMessage(role="assistant", content=partial_message)
139
+ )
140
+
141
+
142
+ # Feedback system using streamlit-feedback and Langsmith
143
+
144
+ # Langsmith client for the feedback system
145
+ ls_client = Client()
146
+
147
+ # Feedback option
148
+ feedback_option = "thumbs"
149
+
150
+ if st.session_state.get("run_id"):
151
+ run_id = st.session_state.run_id
152
+ feedback = streamlit_feedback(
153
+ feedback_type=feedback_option,
154
+ optional_text_label="[Optional] Please provide an explanation",
155
+ key=f"feedback_{run_id}",
156
+ )
157
+ score_mappings = {
158
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
159
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
160
+ }
161
+
162
+ # Get the score mapping based on the selected feedback option
163
+ scores = score_mappings[feedback_option]
164
+
165
+ if feedback:
166
+ # Get the score from the selected feedback option's score mapping
167
+ score = scores.get(feedback["score"])
168
+
169
+ if score is not None:
170
+ # Formulate feedback type string incorporating the feedback option
171
+ # and score value
172
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
173
+
174
+ # Record the feedback with the formulated feedback type string
175
+ feedback_record = ls_client.create_feedback(
176
+ run_id,
177
+ feedback_type_str,
178
+ score=score,
179
+ comment=feedback.get("text"),
180
+ )
181
+ st.session_state.feedback = {
182
+ "feedback_id": str(feedback_record.id),
183
+ "score": score,
184
+ }
185
+ else:
186
+ st.warning("Invalid feedback score.")
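The collect_runs / create_feedback pairing used in app.py is the standard LangSmith pattern and can be exercised outside Streamlit. A minimal sketch, assuming LANGCHAIN_TRACING_V2=true and LANGCHAIN_API_KEY are set so the run is recorded server-side; the RunnableLambda is a stand-in for the real chain, and the question text is invented.

from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.context import collect_runs
from langsmith import Client

# Stand-in for the real ELLA chain; any runnable produces a traceable run.
chain = RunnableLambda(lambda d: f"(demo) You asked: {d['message']}")

with collect_runs() as cb:
    answer = chain.invoke({"message": "Do you offer physiotherapy in James Bay?"})
    run_id = cb.traced_runs[0].id  # same id app.py stores in st.session_state.run_id

# Attach a thumbs-up to the traced run, mirroring feedback_type_str in app.py.
Client().create_feedback(run_id, "thumbs 👍", score=1, comment="Helpful answer")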