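"""Streamlit chat interface for the Tall Tree Integrated Health assistant.

Streams responses from the RAG chain in rag_chain.chain and tracks the
conversation in Streamlit session state.
"""
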
import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

from rag_chain.chain import get_rag_chain

# Streamlit page configuration
st.set_page_config(page_title="Tall Tree Integrated Health",
                   page_icon="πŸ’¬",
                   layout="centered")

# Load the custom CSS styles
with open("styles/styles.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# Error message templates
base_error_message = (
    "Oops! Something went wrong while processing your request:\n\n{}\n\n"
    "Please refresh the page or try again later.\n\n"
    "If the error persists, please contact us at "
    "[Tall Tree Health](https://www.talltreehealth.ca/contact-us)."
)

openai_api_error_message = (
    "We're sorry, but you've reached the maximum number of requests allowed per session.\n\n"
    "Please refresh the page to continue using the app."
)

# Get chain and memory


@st.cache_resource(show_spinner=False)
def get_chain_and_memory():
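    """Build the RAG chain and its conversation memory once and cache them."""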
    try:
        # gpt-4 points to gpt-4-0613
        # gpt-4-turbo-preview points to gpt-4-0125-preview
        # Fine-tuned: ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED
        return get_rag_chain(model_name="gpt-4", temperature=0.2)

    except Exception as e:
        st.warning(base_error_message.format(e), icon="πŸ™")
        st.stop()


chain, memory = get_chain_and_memory()
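# Note: st.cache_resource shares the chain and memory object across sessions
# and reruns, so memory is rebuilt from this session's history on each turn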

# Set up session state; clear the cached memory when a new session starts
if "history" not in st.session_state:
    st.session_state["history"] = []
    memory.clear()

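# "messages" holds ChatMessage objects for rendering the chat transcript;
# "history" holds (user, assistant) text pairs for rebuilding chain memory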
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Location selector inside a container
with st.container(border=False):
    # Set the welcome message
    st.markdown(
        "Hello there! πŸ‘‹ Need help finding the right service or practitioner? Let our AI-powered assistant give you a hand.\n\n"
        "To get started, please select your preferred location and enter your message. "
    )
    location = st.radio(
        "**Our Locations**:",
        ["Cordova Bay - Victoria", "James Bay - Victoria", "Vancouver"],
        index=None, horizontal=False,
    )

# Add some space between the container and the chat interface
for _ in range(2):
    st.markdown("\n\n")

# Get user input only if a location is selected
prompt = ""
if location:
    user_input = st.chat_input("Enter your message...")
    if user_input:
        st.session_state["messages"].append(
            ChatMessage(role="user", content=user_input))
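        # Append the selected location so the chain can use it as context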
        prompt = f"{user_input}\nLocation: {location}"

# Display previous messages

user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for msg in st.session_state["messages"]:
    avatar = user_avatar if msg.role == "user" else ai_avatar
    with st.chat_message(msg.role, avatar=avatar):
        st.markdown(msg.content)

# Chat interface
if prompt:

    # Rebuild chain memory from this session's history; clear it first so
    # earlier turns are not duplicated on every rerun
    memory.clear()
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's response
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()

        # Stream the response, updating the placeholder as chunks arrive
        try:
            partial_message = ""
            with st.spinner(" "):
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    message_placeholder.markdown(partial_message + "|")
        except openai.BadRequestError:
            st.warning(openai_api_error_message, icon="πŸ™")
            st.stop()
        except Exception as e:
            st.warning(base_error_message.format(e), icon="πŸ™")
            st.stop()
        message_placeholder.markdown(partial_message)

        # Save the (prompt, response) pair for rebuilding memory next turn
        st.session_state["history"].append((prompt, partial_message))

        # Add AI message to memory after the response is generated
        memory.chat_memory.add_ai_message(AIMessage(content=partial_message))

        # Append the response to the displayed message transcript
        st.session_state["messages"].append(ChatMessage(
            role="assistant", content=partial_message))