|
import streamlit as st |
|
import os |
|
from streamlit_chat import message |
|
from streamlit_extras.colored_header import colored_header |
|
from streamlit_extras.add_vertical_space import add_vertical_space |
|
from streamlit_mic_recorder import speech_to_text |
|
from model_pipeline import ModelPipeLine |
|
from q_learning_chatbot import QLearningChatbot |
|
|
|
from gtts import gTTS |
|
from io import BytesIO |
|
st.set_page_config(page_title="PeacePal")

# Sidebar banner image; the path is relative to the directory the app is launched from.
image_path = os.path.join('images', 'sidebar.jpg')
st.sidebar.image(image_path, use_column_width=True)

st.title('PeacePal 🌱')

# Streamlit re-executes this whole script on every interaction. Building these
# objects at module level therefore recreated them on each rerun, which reset
# the Q-learning agent (and anything held by the pipeline) after every message.
# They are created once per browser session and cached in st.session_state.
if "model_pipeline" not in st.session_state:
    st.session_state.model_pipeline = ModelPipeLine()
mdl = st.session_state.model_pipeline

retriever = mdl.retriever

if "final_chain" not in st.session_state:
    st.session_state.final_chain = mdl.create_final_chain()
final_chain = st.session_state.final_chain

# Sentiment labels used as the chatbot's states, ordered from most negative
# to most positive.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]

if "q_chatbot" not in st.session_state:
    st.session_state.q_chatbot = QLearningChatbot(states)
chatbot = st.session_state.q_chatbot
|
|
|
|
|
def display_q_table(q_values, states, actions=None):
    """Build a pandas DataFrame view of the chatbot's Q-table for display.

    Args:
        q_values: The chatbot's Q-values. They are read only when ``actions``
            is given. In that case they are assumed to be indexable as
            ``q_values[state_index][action_index]`` (e.g. a 2-D list or NumPy
            array) — NOTE(review): confirm against QLearningChatbot.
        states: Ordered state labels. The table has one row per state.
        actions: Optional ordered action labels. When provided, one column
            per action is added after the "State" column.

    Returns:
        A DataFrame with a "State" column, plus one column per action when
        ``actions`` is provided.
    """
    # The module never imports pandas at top level, so without this local
    # import the function raised NameError on `pd`.
    import pandas as pd

    state_labels = list(states)
    q_table_dict = {"State": state_labels}
    if actions is not None:
        for col, action in enumerate(actions):
            q_table_dict[action] = [q_values[row][col] for row in range(len(state_labels))]
    return pd.DataFrame(q_table_dict)
|
|
|
def text_to_speech(text, lang="en"):
    """Synthesize ``text`` to speech with gTTS and return it as an in-memory stream.

    Args:
        text: The text to speak.
        lang: gTTS language code. Defaults to English ("en"), which preserves
            the previous hard-coded behavior.

    Returns:
        A BytesIO holding the audio written by gTTS, rewound to position 0.
    """
    tts = gTTS(text=text, lang=lang)
    fp = BytesIO()
    tts.write_to_fp(fp)
    # Writing leaves the cursor at the end of the stream. Rewind it so a
    # consumer that calls fp.read() gets the audio rather than b"".
    fp.seek(0)
    return fp
|
|
|
|
|
def speech_recognition_callback():
    """Handle a finished recording from the speech_to_text widget (key "my_stt").

    If the widget produced no transcript (``my_stt_output`` is None), store an
    error message asking the user to record again. Otherwise, clear any
    previous error and copy the transcript to ``st.session_state.speech_input``
    for the main script to consume.
    """
    state = st.session_state
    transcript = state.my_stt_output
    if transcript is not None:
        state.p01_error_message = None
        state.speech_input = transcript
    else:
        state.p01_error_message = "Please record your response again."
|
|
|
|
|
# Per-session defaults. Streamlit reruns this script on every interaction, so
# each key is seeded only when it is missing; existing values are left as-is.
# The dict is rebuilt on every run, so each key receives a fresh list object.
_session_defaults = {
    'generated': ["I'm your Mental health Assistant, How may I help you?"],
    'past': ['Hi!'],
    "entered_text": [],
    "entered_mood": [],
    "messages": [],
    "user_sentiment": "Neutral",
    "mood_trend": "Unchanged",
    "mood_trend_symbol": "",
}
for _state_key, _initial in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

# Page layout: a divider, then the response area placed above the input area.
colored_header(label='', description='', color_name='blue-30')
response_container = st.container()
input_container = st.container()
|
|
|
|
|
|
|
def get_text():
    """Render the "You: " text box and return its current contents.

    The widget is registered under the key "input" and starts empty, so the
    result is "" until the user types something.
    """
    return st.text_input("You: ", "", key="input")
|
|
|
def generate_response(prompt):
    """Answer ``prompt`` via the conversational RAG chain.

    Delegates to the module-level pipeline ``mdl`` with ``final_chain`` and
    returns only the ``'answer'`` entry of the result.
    """
    return mdl.call_conversational_rag(prompt, final_chain)['answer']
|
|
|
|
|
|
|
def _reward_for(sentiment, trend):
    """Return ``(reward, trend_arrow)`` for one interaction.

    An increasing mood earns the full reward regardless of tone. Otherwise the
    reward depends on whether the current tone is positive
    ("Positive" / "Moderately Positive").
    """
    positive = sentiment in ("Positive", "Moderately Positive")
    if trend == "increased":
        return 1, " ⬆️"
    if trend == "unchanged":
        return (0.8 if positive else -0.2), ""
    return (-0.2 if positive else -1), " ⬇️"


input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])

user_message = None
if input_mode == "Speech":
    # The callback copies the transcript into st.session_state.speech_input,
    # so the widget's return value is not needed here.
    speech_to_text(key="my_stt", callback=speech_recognition_callback)

    # The callback stores an error when the recording yielded no transcript.
    # Show it to the user instead of silently ignoring it.
    if st.session_state.get("p01_error_message"):
        st.error(st.session_state.p01_error_message)

    if st.session_state.get("speech_input"):
        # Consume the transcript once so a later rerun does not resubmit it.
        user_message = st.session_state.speech_input
        st.session_state.speech_input = None
else:
    user_message = st.chat_input("Type your message here:")

# Every interaction reruns the script and clears previous output. Replay the
# earlier turns first, otherwise only the newest exchange would be visible.
with response_container:
    for past_msg in st.session_state.messages:
        with st.chat_message(past_msg["role"]):
            st.write(past_msg["content"])

with input_container:
    if user_message:
        st.session_state.entered_text.append(user_message)
        st.session_state.messages.append({"role": "user", "content": user_message})
        with st.chat_message("user"):
            st.write(user_message)

        with st.spinner("processing....."):
            response = generate_response(user_message)
            st.session_state.past.append(user_message)
            st.session_state.generated.append(response)

        # Show and record the assistant reply. Previously it was only spoken,
        # never displayed, and never added to the replayed history.
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            st.write(response)

        user_sentiment = chatbot.detect_sentiment(user_message)
        chatbot.update_mood_history()
        mood_trend = chatbot.check_mood_trend()
        reward, mood_trend_symbol = _reward_for(user_sentiment, mood_trend)

        # Persist for the "Behind the Scene" sidebar, which reads these keys.
        # Previously they were only locals, so the sidebar never changed.
        st.session_state.user_sentiment = user_sentiment
        st.session_state.mood_trend = mood_trend
        st.session_state.mood_trend_symbol = mood_trend_symbol

        print(
            f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
        )

        chatbot.update_q_values(
            user_sentiment, reward, user_sentiment
        )

        # gTTS cannot synthesize an empty string, so skip audio for blank replies.
        if response:
            speech_fp = text_to_speech(response)
            st.audio(speech_fp, format='audio/mp3')
|
|
|
|
|
# NOTE: Chat history is replayed from st.session_state.messages by the
# input-handling code above. An older renderer based on streamlit_chat's
# `message` (drawing st.session_state['past'] / ['generated']) was kept here
# as a disabled string literal. It was dead code and contained a NameError
# (`key=str(I)` instead of `str(i)`), so it has been dropped.
|
|
|
|
|
# "Behind the Scene" panel: shows the detected tone, mood trend, and Q-table.
# The previous version referenced `section_visible`, `actions`, and
# `selected_retriever_option`, none of which are defined anywhere in this
# script, so every run crashed with NameError.
with st.sidebar.expander("Behind the Scene", expanded=False):
    st.subheader("What AI is doing:")

    st.write(
        f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})"
    )

    st.dataframe(display_q_table(chatbot.q_values, states))
    st.write("-----------------------")
    st.write(
        "- Above q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
    )
    # Report the retriever in use by its class name. NOTE(review): replaces
    # the undefined `selected_retriever_option`; swap in a friendlier label
    # if the pipeline exposes one.
    st.write(f"- Question retrieved from: {type(retriever).__name__}")
    st.write(
        "- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, ADHD, Social Media Addiction, Social Isolation, and Cyberbullying which are most associated with FOMO related to excessive social media usage."
    )