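# Header copied from the hosting site's file view (raw / history / blame links omitted):
# author: talltree — last commit b228d0c "Update app.py" (verified) — file size 4.49 kB.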
import logging

import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

from rag_chain.chain import get_rag_chain
# Streamlit page configuration.
# The page icon was previously mojibake ("πŸ’¬": the UTF-8 bytes of the
# emoji decoded with the wrong codec); it is the speech-balloon emoji.
st.set_page_config(
    page_title="Tall Tree Integrated Health",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Inject the app's custom stylesheet. The file is decoded as UTF-8 explicitly
# so the result does not depend on the host's default locale encoding.
with open("styles/styles.css", encoding="utf-8") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
# User-facing error messages (Markdown). Each is built from its paragraphs,
# separated by a blank line.

# Generic fallback shown when anything unexpected fails.
base_error_message = "\n\n".join(
    [
        "Oops! Something went wrong while processing your request. "
        "Please refresh the page or try again later.",
        "If the error persists, please contact us at "
        "[Tall Tree Health](https://www.talltreehealth.ca/contact-us).",
    ]
)

# Shown when the OpenAI API rejects a request (openai.BadRequestError).
openai_api_error_message = "\n\n".join(
    [
        "We're sorry, but you've reached the maximum number of requests "
        "allowed per session.",
        "Please refresh the page to continue using the app.",
    ]
)
# Get chain and memory
@st.cache_resource(show_spinner=False)
def get_chain_and_memory(model_name="gpt-4-1106-preview", temperature=0.2):
    """Build the RAG chain and its conversation memory via ``get_rag_chain``.

    Because of ``st.cache_resource`` the result is built once per argument
    combination and the same objects are then reused by every rerun and every
    session of the app.

    Args:
        model_name: Model identifier forwarded to ``get_rag_chain``.
            Model notes kept from earlier comments:
            - "gpt-4" pointed to gpt-4-0613
            - "gpt-4-turbo-preview" pointed to gpt-4-0125-preview
            - fine-tuned option: "ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED"
        temperature: Sampling temperature forwarded to ``get_rag_chain``.

    Returns:
        Whatever ``get_rag_chain`` returns; the caller unpacks it as
        ``(chain, memory)``.

    On any exception the traceback is logged, a generic warning is shown to
    the user, and the current script run is halted with ``st.stop()``.
    """
    try:
        return get_rag_chain(model_name=model_name, temperature=temperature)
    except Exception:
        # Keep the real cause in the server log; the user only sees the
        # friendly message. (The icon was mojibake "πŸ™" before, which
        # st.warning's emoji validation would not accept.)
        logging.getLogger(__name__).exception("Failed to build the RAG chain")
        st.warning(base_error_message, icon="🙏")
        st.stop()
chain, memory = get_chain_and_memory()

# st.session_state persists across reruns of one browser session, so these
# keys are only absent on the first run of a new session.
is_new_session = "history" not in st.session_state
if is_new_session:
    # (prompt, response) pairs exchanged during this session.
    st.session_state["history"] = []
    # `memory` comes from a st.cache_resource function, so one object is
    # shared by all sessions; wipe it so a new session does not start with
    # an earlier conversation.
    memory.clear()

if "messages" not in st.session_state:
    # ChatMessage objects rendered as the visible transcript.
    st.session_state["messages"] = []
# Welcome text and location picker, grouped in a borderless container.
with st.container(border=False):
    # The wave emoji was previously mojibake ("πŸ‘‹").
    st.markdown(
        "Hello there! 👋 Need help finding the right service or practitioner? "
        "Let our AI-powered assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
    # The selected label is appended verbatim to every prompt sent to the
    # chain, so it must be spelled correctly ("Commercial Drive", previously
    # misspelled "Comercial"). index=None leaves nothing selected (the radio
    # returns None) until the user picks a location.
    location = st.radio(
        "**Our Locations**:",
        [
            "Cordova Bay - Victoria",
            "James Bay - Victoria",
            "Commercial Drive - Vancouver",
        ],
        index=None,
        horizontal=False,
    )

# Add some space between the container and the chat interface.
for _ in range(2):
    st.markdown("\n\n")
# The chat box is only offered once a location has been chosen. `prompt`
# stays empty on runs without fresh input, so no response is generated then.
prompt = ""
user_input = st.chat_input("Enter your message...") if location else None
if user_input:
    # The transcript shows exactly what the user typed...
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    # ...while the chain also receives the selected location.
    prompt = f"{user_input}\nLocation: {location}"
# Re-render the stored transcript; Streamlit reruns the script top to bottom
# on every interaction, so earlier messages must be drawn again each time.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for past in st.session_state["messages"]:
    icon = user_avatar if past.role == "user" else ai_avatar
    with st.chat_message(past.role, avatar=icon):
        st.markdown(past.content)
# Chat interface: run one assistant turn when the user submitted a prompt.
if prompt:
    # `memory` is the object cached by get_chain_and_memory(), so it can still
    # hold messages from earlier turns. Rebuild it from this session's history
    # so the chain sees each earlier exchange exactly once. Previously the
    # history was replayed without clearing, duplicating it on every turn.
    # NOTE(review): the cached memory is shared by all sessions, so concurrent
    # sessions can still interleave here; per-session memory would fix that.
    memory.clear()
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(human)
        memory.chat_memory.add_ai_message(ai)

    # Render the assistant's response, streaming it chunk by chunk.
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()
        partial_message = ""
        try:
            with st.spinner(" "):
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing bar acts as a typing cursor while streaming.
                    message_placeholder.markdown(partial_message + "|")
        except openai.BadRequestError:
            logging.getLogger(__name__).exception("OpenAI rejected the request")
            st.warning(openai_api_error_message, icon="🙏")
            st.stop()
        except Exception:
            # Log the real cause; the user only sees the generic message.
            logging.getLogger(__name__).exception("Streaming the chain response failed")
            st.warning(base_error_message, icon="🙏")
            st.stop()
        # Final render without the cursor.
        message_placeholder.markdown(partial_message)

    # Record the exchange. "history" feeds the memory rebuild on the next turn
    # (so no separate memory write is needed here); "messages" is the
    # transcript re-rendered on every rerun.
    st.session_state["history"].append((prompt, partial_message))
    st.session_state["messages"].append(
        ChatMessage(role="assistant", content=partial_message)
    )