Spaces:
Sleeping
Sleeping
import streamlit as st | |
import time | |
import os | |
from autism_chatbot import * | |
class StreamHandler:
    """Accumulate streamed text and redraw it into a display placeholder.

    Every call to ``append_text`` extends the running buffer and re-renders
    the *whole* buffer via the placeholder's ``markdown`` method, so the
    placeholder always shows everything received so far.
    """

    def __init__(self, placeholder):
        # ``placeholder`` only needs a ``markdown(str)`` method (this file passes ``st.empty()``).
        self.text_container = placeholder
        self.text = ""

    def append_text(self, text: str) -> None:
        """Append ``text`` to the buffer and re-render the full buffer."""
        updated = self.text + text
        self.text = updated
        self.text_container.markdown(updated)
class StreamingGroqLLM(GroqLLM):
    """``GroqLLM`` variant that streams completion text as it arrives.

    Each non-empty content delta is forwarded to ``stream_handler`` (when one
    is set) so the UI can render partial output. The concatenated text is
    returned once the stream is exhausted.
    """

    stream_handler: Any = Field(None, description="Stream handler for real-time output")
    # Pause after each chunk handed to the stream handler (a simple typing effect).
    # The default keeps the previous fixed 0.05 s delay; set to 0 to disable it.
    stream_delay: float = Field(0.05, description="Seconds to pause after each streamed chunk")

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send ``prompt`` as a single user message and stream the reply.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional stop sequences. They are forwarded to the completion
                request when non-empty; previously they were silently dropped.
            **kwargs: Extra keyword arguments forwarded to
                ``self.client.chat.completions.create``.

        Returns:
            The concatenation of all streamed content deltas.
        """
        request_kwargs = dict(kwargs)
        if stop:
            request_kwargs["stop"] = stop

        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            stream=True,
            **request_kwargs,
        )

        pieces: List[str] = []
        for chunk in completion:
            # Guard against chunks with an empty ``choices`` list; some
            # OpenAI-compatible streams emit such (e.g. usage-only) chunks.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is None:
                continue
            pieces.append(content)
            if self.stream_handler:
                self.stream_handler.append_text(content)
                if self.stream_delay > 0:
                    time.sleep(self.stream_delay)
        return "".join(pieces)
class StreamingAutismResearchBot(AutismResearchBot):
    """``AutismResearchBot`` whose LLM streams tokens to a ``StreamHandler``.

    Builds its own LLM, embeddings, FAISS store and conversation memory. The
    QA chain comes from the inherited ``_create_qa_chain``.
    """

    def __init__(
        self,
        groq_api_key: str,
        stream_handler: StreamHandler,
        index_path: str = "index.faiss",
        model_name: str = "llama-3.3-70b-versatile",
    ):
        """Create the streaming bot.

        Args:
            groq_api_key: API key passed to ``StreamingGroqLLM``.
            stream_handler: Object that receives each streamed text chunk.
            index_path: Path to the FAISS ``.faiss`` file. Its directory is the
                folder given to ``FAISS.load_local`` and its file stem is the
                ``index_name``. The default loads ``./index.faiss`` and
                ``./index.pkl``. Previously this argument was ignored and
                ``./`` was always used.
            model_name: Groq model id used by the LLM.
        """
        # NOTE(review): super().__init__() is intentionally not called; presumably
        # the parent would build a non-streaming LLM. Confirm against AutismResearchBot.
        self.llm = StreamingGroqLLM(
            groq_api_key=groq_api_key,
            model_name=model_name,
            stream_handler=stream_handler,
        )
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'},
        )

        folder = os.path.dirname(index_path) or "./"
        index_name = os.path.splitext(os.path.basename(index_path))[0] or "index"
        # SECURITY: allow_dangerous_deserialization unpickles the stored docstore.
        # Only load index files you built yourself, never user-supplied ones.
        self.db = FAISS.load_local(
            folder,
            self.embeddings,
            index_name=index_name,
            allow_dangerous_deserialization=True,
        )

        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        self.qa_chain = self._create_qa_chain()
def _timestamp() -> str:
    """Return the current local time formatted like ``03:07 PM``."""
    return time.strftime('%I:%M %p')


def _inject_styles() -> None:
    """Inject the page-wide custom CSS."""
    # NOTE(review): no element rendered in this file uses `.chat-message` or `.timestamp`.
    # `.stTextInput` targets st.text_input, while this page only uses st.chat_input.
    # These rules may therefore have no visible effect; confirm before relying on them.
    st.markdown("""
    <style>
    /* Main background color */
    .stApp {
        background-color: #0000ff; /* Blue background */
        max-width: 1200px;
        margin: 0 auto;
    }
    /* Style for markdown text */
    .stMarkdown {
        font-size: 16px;
    }
    /* Chat message styling */
    .chat-message {
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
        background-color: white; /* White background for messages */
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
    }
    /* Timestamp styling */
    .timestamp {
        font-size: 0.8em;
        color: #666;
    }
    /* Custom styling for chat containers */
    .stChatMessage {
        background-color: white;
        border-radius: 10px;
        padding: 10px;
        margin: 10px 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
    }
    /* Input box styling */
    .stTextInput>div>div>input {
        background-color: white;
        border-radius: 20px;
    }
    </style>
    """, unsafe_allow_html=True)


def _render_history() -> None:
    """Redraw every stored chat message with the time it was recorded."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(str(message["content"]))
            # Messages stored without a "time" key fall back to the current time.
            st.caption(message.get("time") or _timestamp())


def _prepare_bot(placeholder):
    """Return the session bot with its stream handler pointed at ``placeholder``.

    Creates the bot on first use. Returns None after showing an error when the
    ``GROQ_API_KEY`` environment variable is missing or empty.
    """
    bot = st.session_state.bot
    if bot is None:
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            st.error("GROQ_API_KEY is not set, so the assistant cannot answer right now.")
            return None
        bot = StreamingAutismResearchBot(
            groq_api_key=api_key,
            stream_handler=StreamHandler(placeholder),
        )
        st.session_state.bot = bot
    else:
        # Reuse the existing handler: clear its buffer and redirect output
        # into this turn's placeholder.
        handler = bot.llm.stream_handler
        handler.text = ""
        handler.text_container = placeholder
    return bot


def main():
    """Render the chat page, replay the history, and answer a new question."""
    st.set_page_config(
        page_title="Autism Research Assistant",
        page_icon="🧩",
        layout="wide",
    )
    _inject_styles()

    st.title("🧩 Autism Research Assistant")
    st.markdown("""
    Welcome to your AI-powered autism research assistant. I'm here to provide evidence-based
    assessments and therapy recommendations based on scientific research.
    """)

    # Each stored message records its own timestamp so history times stay fixed across reruns.
    if 'messages' not in st.session_state:
        st.session_state.messages = [
            {
                "role": "assistant",
                "content": "Hello! I'm your autism research assistant. How can I help you today?",
                "time": _timestamp(),
            }
        ]
    # The bot itself is constructed on the first question (see _prepare_bot).
    if 'bot' not in st.session_state:
        st.session_state.stream_container = None
        st.session_state.bot = None

    _render_history()

    prompt = st.chat_input("Type your message here...")
    if not prompt:
        return

    sent_at = _timestamp()
    with st.chat_message("user"):
        st.write(prompt)
        st.caption(sent_at)
    st.session_state.messages.append({"role": "user", "content": prompt, "time": sent_at})

    with st.chat_message("assistant"):
        # Streamed chunks are rendered here while the answer is generated.
        stream_placeholder = st.empty()
        bot = _prepare_bot(stream_placeholder)
        if bot is None:
            return
        try:
            response = bot.answer_question(prompt)
        except Exception as exc:  # UI boundary: show the failure instead of a raw traceback
            stream_placeholder.empty()
            st.error(f"Sorry, something went wrong while generating the answer: {exc}")
            return
        answer = response['answer']
        # Replace the incrementally streamed text with the final answer.
        stream_placeholder.empty()
        st.write(answer)
        answered_at = _timestamp()
        st.caption(answered_at)

    st.session_state.messages.append({"role": "assistant", "content": answer, "time": answered_at})
if __name__ == "__main__":
    # Script entry point, e.g. when launched with `streamlit run <this file>`.
    main()