# Tall Tree Integrated Health — Streamlit chat assistant (Hugging Face Space, CPU)
import openai
import streamlit as st
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

from rag_chain.chain import get_rag_chain

# Streamlit page configuration
# NOTE(review): page_icon looks mojibake'd ("π¬" — likely an emoji originally); confirm intended glyph
st.set_page_config(page_title="Tall Tree Integrated Health",
                   page_icon="π¬",
                   layout="centered")

# Inject the app's custom CSS into the page
with open("styles/styles.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# Error message templates
# Generic failure message; "{}" is filled with the exception text via .format()
base_error_message = (
    "Oops! Something went wrong while processing your request:\n\n{}\n\n"
    "Please refresh the page or try again later.\n\n"
    "If the error persists, please contact us at "
    "[Tall Tree Health](https://www.talltreehealth.ca/contact-us)."
)

# Shown when the OpenAI API rejects the request (rate/limit style errors)
openai_api_error_message = (
    "We're sorry, but you've reached the maximum number of requests allowed per session.\n\n"
    "Please refresh the page to continue using the app."
)
# Get chain and memory | |
def get_chain_and_memory():
    """Build the RAG chain and its conversation memory.

    Returns:
        The (chain, memory) pair produced by ``get_rag_chain``.
        On any failure, renders a warning in the UI and halts the
        Streamlit script run via ``st.stop()`` (never returns).
    """
    try:
        # Model aliases at time of writing:
        #   gpt-4 points to gpt-4-0613
        #   gpt-4-turbo-preview points to gpt-4-0125-preview
        #   Fine-tuned: ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED
        return get_rag_chain(model_name="gpt-4", temperature=0.2)
    except Exception as e:
        st.warning(base_error_message.format(e), icon="π")
        st.stop()
# Build the chain and memory once per script run
chain, memory = get_chain_and_memory()

# Set up session state; clear the chain's memory when a fresh browser
# session starts so one visitor's conversation never leaks into another's
if "history" not in st.session_state:
    st.session_state["history"] = []
    memory.clear()

if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Select locations element into a container | |
# Location selector rendered inside a borderless container
with st.container(border=False):
    # Welcome message
    # NOTE(review): "π" below looks like a mojibake'd emoji — confirm intended glyph
    st.markdown(
        "Hello there! π Need help finding the right service or practitioner? Let our AI-powered assistant give you a hand.\n\n"
        "To get started, please select your preferred location and enter your message. "
    )

    # index=None means no location is pre-selected; chat stays disabled until one is chosen
    location = st.radio(
        "**Our Locations**:",
        ["Cordova Bay - Victoria", "James Bay - Victoria", "Vancouver"],
        index=None, horizontal=False,
    )

# Add some vertical space between the container and the chat interface
for _ in range(2):
    st.markdown("\n\n")
# Get user input only if a location is selected | |
# Accept user input only once a location has been selected
prompt = ""
if location:
    user_input = st.chat_input("Enter your message...")
    if user_input:
        st.session_state["messages"].append(
            ChatMessage(role="user", content=user_input))
        # The selected location is appended so the chain can route the query
        prompt = f"{user_input}\nLocation: {location}"

# Re-render the conversation so far with role-appropriate avatars
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"
for msg in st.session_state["messages"]:
    avatar = user_avatar if msg.role == 'user' else ai_avatar
    with st.chat_message(msg.role, avatar=avatar):
        st.markdown(msg.content)
# Chat interface | |
# Chat interface: stream the assistant's reply for the new prompt
if prompt:
    # Replay all previous turns into the chain's memory so the model
    # sees the full conversation context
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # Render the assistant's response, streaming chunks as they arrive
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()
        try:
            partial_message = ""
            with st.spinner(" "):
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing "|" acts as a typing cursor while streaming
                    message_placeholder.markdown(partial_message + "|")
        except openai.BadRequestError:
            # Treated as the per-session request cap being hit
            st.warning(openai_api_error_message, icon="π")
            st.stop()
        except Exception as e:
            st.warning(base_error_message.format(e), icon="π")
            st.stop()

        # Final render without the streaming cursor
        message_placeholder.markdown(partial_message)

        # Persist the completed turn: pair history for memory replay,
        # chain memory, and the rendered message list
        st.session_state["history"].append((prompt, partial_message))
        memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
        st.session_state["messages"].append(ChatMessage(
            role="assistant", content=partial_message))