|
import streamlit as st |
|
import os |
|
from streamlit_chat import message |
|
from streamlit_extras.colored_header import colored_header |
|
from streamlit_extras.add_vertical_space import add_vertical_space |
|
from streamlit_mic_recorder import speech_to_text |
|
from model_pipeline import ModelPipeLine |
|
from q_learning_chatbot import QLearningChatbot |
|
from retriever import create_vectorstore |
|
|
|
from gtts import gTTS |
|
from io import BytesIO |
|
# Page configuration comes before any other Streamlit element is rendered.
st.set_page_config(page_title="PeacePal")

# Sidebar banner; the relative path is resolved against the launch directory.
image_path = os.path.join("images", "sidebar.jpg")
st.sidebar.image(image_path, use_column_width=True)

st.title("PeacePal 🌱")
|
|
|
def _session_resource(key, factory):
    """Return ``st.session_state[key]``, creating it with ``factory()`` on first use.

    Streamlit reruns this whole script on every interaction. Keeping these objects
    in session state builds them once per browser session instead of once per rerun.
    """
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


# NOTE(review): stored per session rather than via st.cache_resource so that any
# conversational state held by the pipeline/chain is not shared between users —
# confirm whether ModelPipeLine is safe to share before switching to a global cache.
mdl = _session_resource("model_pipeline", ModelPipeLine)

retriever = mdl.retriever

final_chain = _session_resource("final_chain", mdl.create_final_chain)

# Sentiment labels the Q-learning chatbot is constructed over.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]

# Reused across reruns so mood history and Q-value updates can accumulate between
# messages (a fresh instance per rerun would presumably start from scratch).
chatbot = _session_resource("q_learning_chatbot", lambda: QLearningChatbot(states))
|
|
|
|
|
def display_q_table(q_values, states):
    """Build a one-column DataFrame listing the given states.

    Args:
        q_values: Currently unused. NOTE(review): presumably intended to be shown
            next to each state; its shape is not visible here — confirm before
            adding a column for it.
        states: Iterable of state labels, one per row.

    Returns:
        pandas.DataFrame with a single ``State`` column.
    """
    # Imported locally: the module never imports pandas at top level, so the
    # previous bare ``pd`` reference raised NameError on every call.
    import pandas as pd

    return pd.DataFrame({"State": list(states)})
|
|
|
def text_to_speech(text, lang="en"):
    """Synthesize *text* to MP3 audio with gTTS and return it in memory.

    Args:
        text: The text to speak.
        lang: gTTS language code; defaults to English (``"en"``).

    Returns:
        An ``io.BytesIO`` holding the MP3 data, rewound to position 0 so a
        consumer that calls ``read()`` receives the whole clip.

    NOTE(review): gTTS synthesizes via a remote service and may raise if it is
    unreachable; that is not handled here.
    """
    tts = gTTS(text=text, lang=lang)

    audio_buffer = BytesIO()
    tts.write_to_fp(audio_buffer)
    # write_to_fp leaves the cursor at end-of-stream; rewind before handing it out.
    audio_buffer.seek(0)
    return audio_buffer
|
|
|
|
|
def speech_recognition_callback():
    """Move the newest mic-recorder transcript into ``st.session_state.speech_input``.

    Registered as the ``speech_to_text`` callback. The transcript is read from
    ``st.session_state.my_stt_output`` (presumably written by the recorder for
    ``key='my_stt'``). On success the retry message is cleared. When no transcript
    is available, a retry message is stored in ``p01_error_message`` and
    ``speech_input`` is left as it was.
    """
    transcript = st.session_state.my_stt_output
    if transcript is not None:
        st.session_state.p01_error_message = None
        st.session_state.speech_input = transcript
    else:
        st.session_state.p01_error_message = "Please record your response again."
|
|
|
|
|
# Seed per-session defaults. Streamlit reruns this script on every interaction,
# so each key is written only while it is still missing from session state.
for _state_key, _initial_value in {
    "generated": ["I'm your Mental health Assistant, How may I help you?"],
    "past": ["Hi!"],
    "entered_text": [],
    "entered_mood": [],
    "messages": [],
    "user_sentiment": "Neutral",
    "mood_trend": "Unchanged",
    "mood_trend_symbol": "",
    "show_question": False,
    "asked_questions": [],
}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
|
|
|
|
|
|
|
# Decorative header rule, then two containers created in page order: the chat
# history goes into response_container, which sits above input_container.
colored_header(color_name='blue-30', label='', description='')
response_container, input_container = st.container(), st.container()
|
|
|
|
|
|
|
def get_text():
    """Render a "You: " text box (empty by default) and return its current value.

    NOTE(review): uses widget key "input", the same key as the query box in the
    input section, so both cannot be rendered in one run.
    """
    return st.text_input("You: ", "", key="input")
|
|
|
def generate_response(prompt):
    """Run the conversational RAG chain on *prompt* and return the ``answer`` field."""
    rag_result = mdl.call_conversational_rag(prompt, final_chain)
    answer_text = rag_result["answer"]
    return answer_text
|
|
|
|
|
|
|
def _reward_for(sentiment, mood_trend):
    """Return ``(reward, trend_symbol)`` for a detected sentiment and mood trend.

    An improving mood ("increased") always earns +1. An "unchanged" mood earns
    +0.8 for positive sentiment and -0.2 otherwise. Any other trend earns -0.2
    for positive sentiment and -1 otherwise.
    """
    is_positive = sentiment in ("Positive", "Moderately Positive")
    if mood_trend == "increased":
        return +1, " ⬆️"
    if mood_trend == "unchanged":
        return (+0.8 if is_positive else -0.2), ""
    return (-0.2 if is_positive else -1), " ⬇️"


def _handle_query(query, spinner_text):
    """Answer one user message and update the chatbot's learning state.

    Generates the reply (showing *spinner_text* meanwhile) and appends the
    exchange to the chat history. For non-positive sentiment it fetches a
    follow-up question. It then rewards the Q-learning chatbot according to the
    mood trend and plays the reply as audio.
    """
    with st.spinner(spinner_text):
        response = generate_response(query)
    st.session_state.past.append(query)
    st.session_state.generated.append(response)

    user_sentiment = chatbot.detect_sentiment(query)

    # Non-positive sentiment: fetch a follow-up question for the user. This used
    # to pass the undefined name `user_message` and crash with NameError.
    # NOTE(review): assumes mdl.retriever exposes get_response(); confirm.
    show_question = user_sentiment in ["Negative", "Moderately Negative", "Neutral"]
    if show_question:
        question = retriever.get_response(query)
        st.session_state.asked_questions.append(question)

    chatbot.update_mood_history()
    mood_trend = chatbot.check_mood_trend()
    reward, mood_trend_symbol = _reward_for(user_sentiment, mood_trend)

    print(
        f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
    )

    chatbot.update_q_values(user_sentiment, reward, user_sentiment)

    # Persist the latest mood bookkeeping so it survives the next rerun.
    st.session_state.user_sentiment = user_sentiment
    st.session_state.mood_trend = mood_trend
    st.session_state.mood_trend_symbol = mood_trend_symbol
    st.session_state.show_question = show_question

    # Play the reply once (the text branch previously synthesized it twice).
    st.audio(text_to_speech(response), format='audio/mp3')


def _stash_text_query():
    """``on_change`` callback for the query box: queue the text and clear the box.

    Without this, the widget keeps its value and the same query would be
    answered again on every later rerun. Setting a widget's value is permitted
    inside its own callback.
    """
    st.session_state.pending_text_query = st.session_state.input
    st.session_state.input = ""


with input_container:
    input_mode = st.radio("Select input mode:", ["Text", "Speech"])

    if input_mode == "Speech":
        speech_to_text(key='my_stt', callback=speech_recognition_callback)

        # Show the retry prompt that speech_recognition_callback may have stored.
        if st.session_state.get("p01_error_message"):
            st.warning(st.session_state.p01_error_message)

        query = st.session_state.get("speech_input")
        if query:
            st.text(f"Speech Input: {query}")
            # Consume the transcript so later reruns do not answer it again.
            st.session_state.speech_input = None
            _handle_query(query, "processing.....")

    else:
        st.text_input("Query: ", key="input", on_change=_stash_text_query)
        query = st.session_state.pop("pending_text_query", None)
        if query:
            _handle_query(query, "typing.....")
|
|
|
|
|
with response_container:
    # Render the conversation as alternating user/bot bubbles. `past` and
    # `generated` are appended in lockstep, so index i pairs a question with its reply.
    bot_turns = st.session_state['generated']
    if bot_turns:
        user_turns = st.session_state['past']
        for idx, bot_text in enumerate(bot_turns):
            message(user_turns[idx], is_user=True, key=f"{idx}_user")
            message(bot_text, key=str(idx))
|
|
|
|