talltree commited on
Commit
3403534
β€’
1 Parent(s): 17d27ad

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import streamlit as st
3
+ from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
4
+
5
+ from rag_chain.chain import get_rag_chain
6
+
7
+ # Streamlit page configuration
8
+ st.set_page_config(page_title="Tall Tree Integrated Health",
9
+ page_icon="πŸ’¬",
10
+ layout="centered")
11
+
12
+ # Streamlit CSS configuration
13
+
14
+ with open("styles/styles.css") as css:
15
+ st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
16
+
17
+ # Error message templates
18
+ base_error_message = (
19
+ "Oops! Something went wrong while processing your request:\n\n{}\n\n"
20
+ "Please refresh the page or try again later.\n\n"
21
+ "If the error persists, please contact us at "
22
+ "[Tall Tree Health](https://www.talltreehealth.ca/contact-us)."
23
+ )
24
+
25
+ openai_api_error_message = (
26
+ "We're sorry, but you've reached the maximum number of requests allowed per session.\n\n"
27
+ "Please refresh the page to continue using the app."
28
+ )
29
+
30
+ # Get chain and memory
31
+
32
+
33
+ @st.cache_resource(show_spinner=False)
34
+ def get_chain_and_memory():
35
+ try:
36
+ # gpt-4 points to gpt-4-0613
37
+ # gpt-4-turbo-preview points to gpt-4-0125-preview
38
+ # Fine-tuned: ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED
39
+ return get_rag_chain(model_name="gpt-4", temperature=0.2)
40
+
41
+ except Exception as e:
42
+ st.warning(base_error_message.format(e), icon="πŸ™")
43
+ st.stop()
44
+
45
+
46
+ chain, memory = get_chain_and_memory()
47
+
48
+ # Set up session state and clean memory (important to clean the memory at the end of each session)
49
+ if "history" not in st.session_state:
50
+ st.session_state["history"] = []
51
+ memory.clear()
52
+
53
+ if "messages" not in st.session_state:
54
+ st.session_state["messages"] = []
55
+
56
+ # Select locations element into a container
57
+ with st.container(border=False):
58
+ # Set the welcome message
59
+ st.markdown(
60
+ "Hello there! πŸ‘‹ Need help finding the right service or practitioner? Let our AI-powered assistant give you a hand.\n\n"
61
+ "To get started, please select your preferred location and enter your message. "
62
+ )
63
+ location = st.radio(
64
+ "**Our Locations**:",
65
+ ["Cordova Bay - Victoria", "James Bay - Victoria", "Vancouver"],
66
+ index=None, horizontal=False,
67
+ )
68
+
69
+ # Add some space between the container and the chat interface
70
+ for _ in range(2):
71
+ st.markdown("\n\n")
72
+
73
+ # Get user input only if a location is selected
74
+ prompt = ""
75
+ if location:
76
+ user_input = st.chat_input("Enter your message...")
77
+ if user_input:
78
+ st.session_state["messages"].append(
79
+ ChatMessage(role="user", content=user_input))
80
+ prompt = f"{user_input}\nLocation: {location}"
81
+
82
+ # Display previous messages
83
+
84
+ user_avatar = "images/user.png"
85
+ ai_avatar = "images/tall-tree-logo.png"
86
+ for msg in st.session_state["messages"]:
87
+ avatar = user_avatar if msg.role == 'user' else ai_avatar
88
+ with st.chat_message(msg.role, avatar=avatar):
89
+ st.markdown(msg.content)
90
+
91
+ # Chat interface
92
+ if prompt:
93
+
94
+ # Add all previous messages to memory
95
+ for human, ai in st.session_state["history"]:
96
+ memory.chat_memory.add_user_message(HumanMessage(content=human))
97
+ memory.chat_memory.add_ai_message(AIMessage(content=ai))
98
+
99
+ # render the assistant's response
100
+ with st.chat_message("assistant", avatar=ai_avatar):
101
+ message_placeholder = st.empty()
102
+
103
+ # If there is a message not None, add it to the memory
104
+ try:
105
+ partial_message = ""
106
+ with st.spinner(" "):
107
+ for chunk in chain.stream({"message": prompt}):
108
+ partial_message += chunk
109
+ message_placeholder.markdown(partial_message + "|")
110
+ except openai.BadRequestError:
111
+ st.warning(openai_api_error_message, icon="πŸ™")
112
+ st.stop()
113
+ except Exception as e:
114
+ st.warning(base_error_message.format(e), icon="πŸ™")
115
+ st.stop()
116
+ message_placeholder.markdown(partial_message)
117
+
118
+ # Add the full response to the history
119
+ st.session_state["history"].append((prompt, partial_message))
120
+
121
+ # Add AI message to memory after the response is generated
122
+ memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
123
+
124
+ # add the full response to the message history
125
+ st.session_state["messages"].append(ChatMessage(
126
+ role="assistant", content=partial_message))