# NOTE: Hugging Face Spaces page header ("Spaces / Sleeping") removed — it was
# extraction residue, not part of the source code.
import os
from typing import Any, Dict, List, Optional

from groq import Groq
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from pydantic import BaseModel, Field
class GroqLLM(LLM, BaseModel):
    """LangChain-compatible LLM wrapper around the Groq chat-completions API.

    Attributes:
        groq_api_key: API key used to authenticate with Groq.
        model_name: Groq model identifier to query.
        client: Groq SDK client, created once in ``__init__`` and reused
            for every call.
    """

    groq_api_key: str = Field(..., description="Groq API Key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    # Populated in __init__; declared as a field so pydantic allows the attribute.
    client: Optional[Any] = None

    def __init__(self, **data):
        super().__init__(**data)
        # Build the client once so repeated _call invocations reuse it.
        self.client = Groq(api_key=self.groq_api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization and telemetry.

        Must be a property: the LangChain ``LLM`` base class declares
        ``_llm_type`` as an abstract property, so a plain method would make
        ``self._llm_type`` evaluate to a bound method instead of a string.
        """
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send a single-turn prompt to Groq and return the completion text.

        Args:
            prompt: The fully rendered prompt string.
            stop: Optional stop sequences; forwarded to the API (the previous
                implementation silently discarded them).
            **kwargs: Extra keyword arguments passed through to
                ``chat.completions.create`` (e.g. temperature).

        Returns:
            The assistant message content of the first completion choice.
        """
        if stop is not None:
            # Forward stop sequences instead of dropping them.
            kwargs["stop"] = stop
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            **kwargs,
        )
        return completion.choices[0].message.content

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters (a property, per the LLM base class)."""
        return {"model_name": self.model_name}
class AutismResearchBot:
    """Conversational RAG assistant for autism assessment and therapy guidance.

    Combines a Groq-hosted LLM with a FAISS index of autism research papers
    (embedded with a PubMed-tuned sentence transformer) behind a
    ``ConversationalRetrievalChain`` that keeps per-session chat history.
    """

    def __init__(self, groq_api_key: str, index_path: str = "./"):
        """Build the LLM, vector store, memory, and retrieval chain.

        Args:
            groq_api_key: API key for the Groq service.
            index_path: Folder containing the saved FAISS index. The previous
                implementation ignored this parameter and hardcoded "./";
                the default now matches that behavior, and an explicit
                argument is honored.
        """
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile",  # Adjust the model as needed
        )

        # Embeddings must match the model used when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={"device": "cpu"},
        )
        # allow_dangerous_deserialization is required by langchain_community
        # for pickled FAISS metadata; only load indexes you created yourself.
        self.db = FAISS.load_local(
            index_path, self.embeddings, allow_dangerous_deserialization=True
        )

        # Conversation memory; output_key="answer" is required because the
        # chain also returns source_documents.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self):
        """Assemble the ConversationalRetrievalChain with the assessment prompt."""
        # The template scripts a structured assessment flow; the retrieved
        # {context} is reserved for the final therapy recommendations.
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to conduct a structured assessment and provide evidence-based therapy recommendations.

Context from scientific papers (use these details only for final therapy recommendations):
{context}

Chat History:
{chat_history}

Objective:
- Gather demographic information
- Present autism types for initial self-identification
- Conduct detailed assessment through naturalistic conversation
- Provide evidence-based therapy recommendations

Instructions:
1. Begin with collecting age and gender
2. Present main types of autism with brief descriptions
3. Ask targeted questions with relatable examples
4. Maintain a conversational, empathetic tone
5. Conclude with personalized therapy recommendations

Initial Introduction:
"Hello, I am an AI assistant specialized in autism research and diagnostics. To provide you with the most appropriate guidance, I'll need to gather some information. Let's start with some basic details:

1. Could you share the age and gender of the person seeking assessment?

Once you provide these details, I'll share some common types of autism spectrum conditions, and we can discuss which ones seem most relevant to your experience."

After receiving demographic information, present autism types:
"Thank you. There are several types of autism spectrum conditions. Please let me know which of these seems most relevant to your situation:

1. Social Communication Challenges
Example: Difficulty maintaining conversations, understanding social cues

2. Repetitive Behavior Patterns
Example: Strong adherence to routines, specific intense interests

3. Sensory Processing Differences
Example: Sensitivity to sounds, lights, or textures

4. Language Development Variations
Example: Delayed speech, unique communication patterns

5. Executive Function Challenges
Example: Difficulty with planning, organizing, and transitioning between tasks

Which of these patterns feels most familiar to your experience?"

Follow-up Questions Format:
"I understand you identify most with [selected type]. Let me ask you about some specific experiences:

[Question with example]
For instance: When you're in a social situation, do you find yourself [specific example from daily life]?"

Continue natural conversation flow with examples for each question:
- Include real-life scenarios
- Relate questions to age-appropriate situations
- Provide clear, concrete examples
- Allow for open-ended responses

Final Assessment and Therapy Recommendations:
"Based on our detailed discussion and the patterns you've described, I can now share some evidence-based therapy recommendations tailored to your specific needs..."

Question:
{question}

Answer:"""

        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"],
        )

        # "stuff" packs all retrieved docs into one prompt; k=3 keeps the
        # context within the model's window.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3},
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PROMPT,
            },
            return_source_documents=True,
        )
        return chain

    def answer_question(self, question: str):
        """Process a question and return the answer along with source documents.

        Args:
            question: The user's message for this turn.

        Returns:
            Dict with:
                'answer': the model's reply string.
                'sources': list of dicts, each with a 200-char 'content'
                    preview and the document 'metadata'.
        """
        result = self.qa_chain({"question": question})

        answer = result["answer"]
        sources = result["source_documents"]

        # Format sources for reference (truncated previews + metadata).
        source_info = [
            {
                "content": doc.page_content[:200] + "...",
                "metadata": doc.metadata,
            }
            for doc in sources
        ]

        return {
            "answer": answer,
            "sources": source_info,
        }