yrobel-lima commited on
Commit
9181031
β€’
1 Parent(s): 6cc96e7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -46
app.py CHANGED
@@ -1,15 +1,13 @@
1
- import logging
2
- from concurrent.futures import ThreadPoolExecutor
3
-
4
  import openai
5
  import streamlit as st
6
  from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
 
 
 
7
 
8
  from rag.runnable import get_runnable
9
  from utils.error_message_template import ERROR_MESSAGE
10
 
11
- logging.basicConfig(level=logging.ERROR)
12
-
13
  # Streamlit page configuration
14
  st.set_page_config(
15
  page_title="ELLA AI Assistant",
@@ -24,33 +22,24 @@ with open("styles/styles.css") as css:
24
 
25
 
26
  # Get runnable and memory
27
- def initialize_runnable_and_memory():
 
28
  try:
29
- return get_runnable(model="gpt-4o", temperature=0)
30
  except Exception:
31
  st.warning(ERROR_MESSAGE, icon="πŸ™")
32
  st.stop()
33
 
34
 
35
- # Get the ThreadPoolExecutor
36
- if "executor" not in st.session_state:
37
- st.session_state.executor = ThreadPoolExecutor(max_workers=4)
38
 
39
- executor = st.session_state.executor
40
-
41
- # Submit initialization task if not already done
42
- if "initialization_future" not in st.session_state:
43
- st.session_state["initialization_future"] = executor.submit(
44
- initialize_runnable_and_memory
45
- )
46
 
47
- # Check if initialization is complete
48
- future = st.session_state["initialization_future"]
49
- if future.done() and "runnable" not in st.session_state:
50
- st.session_state["runnable"], st.session_state["memory"] = future.result()
51
- st.session_state["memory"].clear()
52
 
53
- # Other session state variables
54
  if "messages" not in st.session_state:
55
  st.session_state["messages"] = []
56
 
@@ -96,14 +85,14 @@ user_input = st.chat_input(
96
  "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
97
  )
98
 
99
- if user_input and user_input.strip():
100
  st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
101
  prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."
102
 
103
  else:
104
  prompt = None
105
 
106
- # Render chat messages
107
  user_avatar = "images/user.png"
108
  ai_avatar = "images/tall-tree-logo.png"
109
  for msg in st.session_state["messages"]:
@@ -111,19 +100,26 @@ for msg in st.session_state["messages"]:
111
  with st.chat_message(msg.role, avatar=avatar):
112
  st.markdown(msg.content)
113
 
114
- # Chat interface (we have to wait for the runnable initialization to complete)
115
- if "runnable" in st.session_state and prompt:
116
- # Render the assistant's response
 
 
 
 
 
117
  with st.chat_message("assistant", avatar=ai_avatar):
118
  message_placeholder = st.empty()
119
- try:
120
- response = ""
121
- with st.spinner(" "):
122
- for chunk in st.session_state["runnable"].stream({"message": prompt}):
123
- response += chunk
124
- message_placeholder.markdown(response + "|")
125
 
126
- message_placeholder.markdown(response)
 
 
 
 
 
 
 
 
127
  except openai.BadRequestError:
128
  st.warning(ERROR_MESSAGE, icon="πŸ™")
129
  st.stop()
@@ -131,17 +127,60 @@ if "runnable" in st.session_state and prompt:
131
  st.warning(ERROR_MESSAGE, icon="πŸ™")
132
  st.stop()
133
 
134
- # Add response to the message history
 
 
 
 
 
 
135
  st.session_state["messages"].append(
136
- ChatMessage(role="assistant", content=response)
137
  )
138
 
139
- # Add messages to memory
140
- st.session_state["memory"].chat_memory.add_user_message(
141
- HumanMessage(content=prompt)
142
- )
143
- st.session_state["memory"].chat_memory.add_ai_message(
144
- AIMessage(content=response)
145
- )
146
- if st.session_state.executor:
147
- st.session_state.executor.shutdown(wait=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import openai
2
  import streamlit as st
3
  from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
4
+ from langchain_core.tracers.context import collect_runs
5
+ from langsmith import Client
6
+ from streamlit_feedback import streamlit_feedback
7
 
8
  from rag.runnable import get_runnable
9
  from utils.error_message_template import ERROR_MESSAGE
10
 
 
 
11
  # Streamlit page configuration
12
  st.set_page_config(
13
  page_title="ELLA AI Assistant",
 
22
 
23
 
24
  # Get runnable and memory
25
+ @st.cache_resource(show_spinner=False)
26
+ def get_runnable_and_memory():
27
  try:
28
+ return get_runnable(model="gpt-4-turbo", temperature=0)
29
  except Exception:
30
  st.warning(ERROR_MESSAGE, icon="πŸ™")
31
  st.stop()
32
 
33
 
34
+ chain, memory = get_runnable_and_memory()
 
 
35
 
 
 
 
 
 
 
 
36
 
37
+ # Set up session state variables
38
+ # Clean memory (important! to clean the memory at the end of each session)
39
+ if "history" not in st.session_state:
40
+ st.session_state["history"] = []
41
+ memory.clear()
42
 
 
43
  if "messages" not in st.session_state:
44
  st.session_state["messages"] = []
45
 
 
85
  "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
86
  )
87
 
88
+ if user_input:
89
  st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
90
  prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."
91
 
92
  else:
93
  prompt = None
94
 
95
+ # Display previous messages
96
  user_avatar = "images/user.png"
97
  ai_avatar = "images/tall-tree-logo.png"
98
  for msg in st.session_state["messages"]:
 
100
  with st.chat_message(msg.role, avatar=avatar):
101
  st.markdown(msg.content)
102
 
103
+ # Chat interface
104
+ if prompt:
105
+ # Add all previous messages to memory
106
+ for human, ai in st.session_state["history"]:
107
+ memory.chat_memory.add_user_message(HumanMessage(content=human))
108
+ memory.chat_memory.add_ai_message(AIMessage(content=ai))
109
+
110
+ # render the assistant's response
111
  with st.chat_message("assistant", avatar=ai_avatar):
112
  message_placeholder = st.empty()
 
 
 
 
 
 
113
 
114
+ try:
115
+ partial_message = ""
116
+ # Collect runs for feedback using Langsmith.
117
+ with st.spinner(" "), collect_runs() as cb:
118
+ for chunk in chain.stream({"message": prompt}):
119
+ partial_message += chunk
120
+ message_placeholder.markdown(partial_message + "|")
121
+ st.session_state.run_id = cb.traced_runs[0].id
122
+ message_placeholder.markdown(partial_message)
123
  except openai.BadRequestError:
124
  st.warning(ERROR_MESSAGE, icon="πŸ™")
125
  st.stop()
 
127
  st.warning(ERROR_MESSAGE, icon="πŸ™")
128
  st.stop()
129
 
130
+ # Add the full response to the history
131
+ st.session_state["history"].append((prompt, partial_message))
132
+
133
+ # Add AI message to memory after the response is generated
134
+ memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
135
+
136
+ # Add the full response to the message history
137
  st.session_state["messages"].append(
138
+ ChatMessage(role="assistant", content=partial_message)
139
  )
140
 
141
+
142
+ # Feedback system using streamlit-feedback and Langsmith
143
+
144
+ # Langsmith client for the feedback system
145
+ ls_client = Client()
146
+
147
+ # Feedback option
148
+ feedback_option = "thumbs"
149
+
150
+ if st.session_state.get("run_id"):
151
+ run_id = st.session_state.run_id
152
+ feedback = streamlit_feedback(
153
+ feedback_type=feedback_option,
154
+ optional_text_label="[Optional] Please provide an explanation",
155
+ key=f"feedback_{run_id}",
156
+ )
157
+ score_mappings = {
158
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
159
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
160
+ }
161
+
162
+ # Get the score mapping based on the selected feedback option
163
+ scores = score_mappings[feedback_option]
164
+
165
+ if feedback:
166
+ # Get the score from the selected feedback option's score mapping
167
+ score = scores.get(feedback["score"])
168
+
169
+ if score is not None:
170
+ # Formulate feedback type string incorporating the feedback option
171
+ # and score value
172
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
173
+
174
+ # Record the feedback with the formulated feedback type string
175
+ feedback_record = ls_client.create_feedback(
176
+ run_id,
177
+ feedback_type_str,
178
+ score=score,
179
+ comment=feedback.get("text"),
180
+ )
181
+ st.session_state.feedback = {
182
+ "feedback_id": str(feedback_record.id),
183
+ "score": score,
184
+ }
185
+ else:
186
+ st.warning("Invalid feedback score.")