# ASD / app.py
# Last updated by Muhammadbilal10101 — commit ae4a55d ("Update app.py", verified).
import streamlit as st
import time
import os
from autism_chatbot import *
class StreamHandler:
    """Collect streamed text fragments and redraw them into a placeholder.

    Every call to :meth:`append_text` extends the running buffer and re-renders
    the *entire* buffer (not just the new fragment) as Markdown, so the
    placeholder always shows the full text received so far.
    """

    def __init__(self, placeholder):
        # Target exposing ``markdown(str)``; in this app an ``st.empty()`` slot.
        self.text_container = placeholder
        # Everything appended since construction (or since a caller reset it).
        self.text = ""

    def append_text(self, text: str) -> None:
        """Add *text* to the buffer and redraw the whole buffer."""
        combined = self.text + text
        self.text = combined
        self.text_container.markdown(combined)
class StreamingGroqLLM(GroqLLM):
    """GroqLLM variant that streams completion chunks to a stream handler.

    Chunks are pushed to ``stream_handler.append_text`` as they arrive, and the
    concatenated completion is still returned from ``_call`` as one string.
    """

    stream_handler: Any = Field(None, description="Stream handler for real-time output")
    # Pause after each rendered chunk to produce a visible "typing" effect.
    # Set to 0 to render as fast as chunks arrive.
    stream_delay: float = Field(
        0.05, description="Seconds to sleep after each chunk sent to the stream handler"
    )

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send *prompt* as a single user message and stream back the reply.

        Args:
            prompt: Text sent as the only ``user`` message.
            stop: Optional stop sequences; forwarded to the completion request
                (they were previously dropped).
            **kwargs: Extra arguments passed to ``chat.completions.create``.

        Returns:
            The concatenation of all non-empty content deltas.
        """
        if stop:
            # Don't override an explicit ``stop`` a caller put in kwargs.
            kwargs.setdefault("stop", stop)
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            stream=True,
            **kwargs,
        )
        pieces = []
        for chunk in completion:
            # Guard: a streamed chunk may carry an empty ``choices`` list.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is None:
                continue
            pieces.append(content)
            if self.stream_handler:
                self.stream_handler.append_text(content)
                if self.stream_delay > 0:
                    time.sleep(self.stream_delay)
        return "".join(pieces)
class StreamingAutismResearchBot(AutismResearchBot):
    """AutismResearchBot whose LLM streams its answers into a StreamHandler.

    ``__init__`` does not call ``AutismResearchBot.__init__``. It builds
    ``llm``, ``embeddings``, ``db``, ``memory`` and ``qa_chain`` itself so the
    streaming LLM can be injected.
    NOTE(review): assumes the parent's ``_create_qa_chain`` needs only these
    attributes — confirm against autism_chatbot.
    """

    def __init__(self, groq_api_key: str, stream_handler: StreamHandler, index_path: str = "index.faiss"):
        """Build the streaming LLM, embeddings, vector store, memory and QA chain.

        Args:
            groq_api_key: API key handed to ``StreamingGroqLLM``.
            stream_handler: Receives each streamed chunk of the answer.
            index_path: Path to the FAISS index file. Its directory is used as
                ``folder_path`` and its stem as ``index_name`` for
                ``FAISS.load_local``. The default loads ``./index.faiss``, which
                is the same file the previous hard-coded ``"./"`` loaded.
        """
        self.llm = StreamingGroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile",
            stream_handler=stream_handler,
        )
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'},
        )
        # Previously index_path was accepted but ignored. Split it into the
        # folder and index name that FAISS.load_local expects.
        folder_path = os.path.dirname(index_path) or "."
        index_name = os.path.splitext(os.path.basename(index_path))[0]
        # SECURITY: allow_dangerous_deserialization unpickles the stored docstore.
        # Only load index files produced by a trusted source (e.g. shipped with the app).
        self.db = FAISS.load_local(
            folder_path,
            self.embeddings,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        self.qa_chain = self._create_qa_chain()
def _timestamp():
    """Return the current time formatted like ``03:07 PM``."""
    # NOTE(review): time.strftime uses the server's local time zone, not the viewer's.
    return time.strftime('%I:%M %p')


def _inject_custom_css():
    """Inject the app-wide CSS overrides (background, chat bubbles, input box)."""
    st.markdown("""
    <style>
    /* Main background color */
    .stApp {
        background-color: #0000ff; /* Blue background */
        max-width: 1200px;
        margin: 0 auto;
    }
    /* Style for markdown text */
    .stMarkdown {
        font-size: 16px;
    }
    /* Chat message styling */
    .chat-message {
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
        background-color: white; /* White background for messages */
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
    }
    /* Timestamp styling */
    .timestamp {
        font-size: 0.8em;
        color: #666;
    }
    /* Custom styling for chat containers */
    .stChatMessage {
        background-color: white;
        border-radius: 10px;
        padding: 10px;
        margin: 10px 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
    }
    /* Input box styling */
    .stTextInput>div>div>input {
        background-color: white;
        border-radius: 20px;
    }
    </style>
    """, unsafe_allow_html=True)


def _render_message(role, content, stamp=None):
    """Draw one chat bubble for *role* with *content* and an optional time caption."""
    with st.chat_message(role):
        st.write(content)
        if stamp:
            st.caption(stamp)


def _get_bot(placeholder):
    """Return the session's bot, with its stream handler pointed at *placeholder*.

    Creates the bot on first use. If ``GROQ_API_KEY`` is missing, shows an
    error and stops the script run instead of constructing a bot with no key.
    """
    bot = st.session_state.bot
    if bot is None:
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            st.error("GROQ_API_KEY is not set, so the assistant cannot answer questions.")
            st.stop()
        bot = StreamingAutismResearchBot(
            groq_api_key=api_key,
            stream_handler=StreamHandler(placeholder),
        )
        st.session_state.bot = bot
    else:
        # Reuse the existing handler. Clear its buffer and retarget it at this
        # answer's placeholder.
        handler = bot.llm.stream_handler
        handler.text = ""
        handler.text_container = placeholder
    return bot


def main():
    """Render the chat UI and answer the newest prompt, if any.

    Each message in ``st.session_state.messages`` is a dict with ``role`` and
    ``content``, plus a ``time`` string recorded when the message was created.
    Timestamps used to be recomputed on every rerun, which showed the current
    time for all past messages.
    """
    st.set_page_config(
        page_title="Autism Research Assistant",
        page_icon="🧩",
        layout="wide",
    )
    _inject_custom_css()

    st.title("🧩 Autism Research Assistant")
    st.markdown("""
    Welcome to your AI-powered autism research assistant. I'm here to provide evidence-based
    assessments and therapy recommendations based on scientific research.
    """)

    if 'messages' not in st.session_state:
        st.session_state.messages = [
            {
                "role": "assistant",
                "content": "Hello! I'm your autism research assistant. How can I help you today?",
                "time": _timestamp(),
            }
        ]
    if 'bot' not in st.session_state:
        # Built lazily on the first question so it can be bound to that
        # answer's streaming placeholder.
        st.session_state.bot = None

    for message in st.session_state.messages:
        # Messages stored by older versions of this app have no "time" key.
        _render_message(message["role"], message["content"], message.get("time"))

    prompt = st.chat_input("Type your message here...")
    if not prompt:
        return

    user_time = _timestamp()
    _render_message("user", prompt, user_time)
    st.session_state.messages.append({"role": "user", "content": prompt, "time": user_time})

    with st.chat_message("assistant"):
        # Streamed chunks render here, then get replaced by the final answer.
        stream_placeholder = st.empty()
        bot = _get_bot(stream_placeholder)
        response = bot.answer_question(prompt)
        answer = response['answer']
        stream_placeholder.empty()
        st.write(answer)
        reply_time = _timestamp()
        st.caption(reply_time)

    st.session_state.messages.append({"role": "assistant", "content": answer, "time": reply_time})
# Entry point. NOTE(review): assumes the app is launched with `streamlit run`,
# which executes this file as __main__ on every rerun — confirm for your deploy.
if __name__ == "__main__":
    main()