File size: 6,610 Bytes
f411a8f f6e0c20 93407fc f6e0c20 93407fc f411a8f f6e0c20 93407fc f6e0c20 d08a783 f6e0c20 d08a783 93407fc f411a8f 01cd8f4 93407fc 01cd8f4 f6e0c20 5578ac6 01cd8f4 f6e0c20 01cd8f4 93407fc f6e0c20 93407fc f6e0c20 93407fc f6e0c20 d01f9f3 f6e0c20 93407fc f6e0c20 93407fc f6e0c20 93407fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import logging
import os

# Point every Hugging Face cache at a writable directory.
# huggingface_hub / transformers read HF_HOME and TRANSFORMERS_CACHE into
# module-level constants when they are first imported. These assignments must
# therefore run BEFORE the ML-library imports below; setting them afterwards
# only works where a cache dir is passed explicitly.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

import faiss  # noqa: E402
import pandas as pd  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from groq import Groq  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402
# Create the FastAPI application that the route handlers in this module use.
app = FastAPI()

# The Groq key is taken only from the environment; refuse to start without it.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY in (None, ""):
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")

# Shared Groq client used for every chat-completion request.
client = Groq(api_key=GROQ_API_KEY)
# Load the AI models; every download is cached under /tmp/huggingface.
_MODEL_CACHE_DIR = "/tmp/huggingface"
_SUMMARIZER_CHECKPOINT = "google/long-t5-tglobal-base"

# Sentence encoder for the disorder-detection index and its queries.
similarity_model = SentenceTransformer(
    "sentence-transformers/all-mpnet-base-v2",
    cache_folder=_MODEL_CACHE_DIR,
)
# Sentence encoder for the diagnostic-question index and its queries.
embedding_model = SentenceTransformer(
    "all-MiniLM-L6-v2",
    cache_folder=_MODEL_CACHE_DIR,
)
# Seq2seq model and matching tokenizer used to summarize chat sessions.
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
    _SUMMARIZER_CHECKPOINT, cache_dir=_MODEL_CACHE_DIR
)
summarization_tokenizer = AutoTokenizer.from_pretrained(
    _SUMMARIZER_CHECKPOINT, cache_dir=_MODEL_CACHE_DIR
)
def _read_first_csv(*paths):
    """Read and return the first CSV among *paths* that exists.

    Args:
        *paths: Candidate file paths, tried in order.

    Returns:
        pandas.DataFrame parsed from the first path that could be opened.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    for path in paths:
        try:
            return pd.read_csv(path)
        except FileNotFoundError:
            continue
    raise FileNotFoundError(f"none of {list(paths)} found in {os.getcwd()}")


# Load the datasets, failing fast at startup (not per request) if one is missing.
try:
    # NOTE(review): the name with a space before ".csv" is tried first to stay
    # compatible with the previous code; the space looks accidental, so the
    # unspaced name is accepted as a fallback. Confirm the deployed file name.
    recommendations_df = _read_first_csv(
        "treatment_recommendations .csv", "treatment_recommendations.csv"
    )
    questions_df = _read_first_csv("symptom_questions.csv")
except FileNotFoundError as e:
    # Listing the working directory makes a missing or misnamed file easy to spot.
    logging.error("Missing dataset file: %s. Available files: %s", e, os.listdir("."))
    # HTTPException is only meaningful inside a request handler; at import time
    # a plain RuntimeError (chained to the cause) is the clearer startup failure.
    raise RuntimeError(f"Dataset file not found: {e}") from e
def _build_flat_index(model, texts, index_cls):
    """Embed *texts* with *model* and load them into a new flat FAISS index.

    Args:
        model: Encoder exposing ``encode(texts, convert_to_numpy=True)``.
        texts: List of strings to embed.
        index_cls: FAISS flat index class, constructed with the embedding width.

    Returns:
        An ``(embeddings, index)`` pair: the NumPy embedding matrix and the
        index populated with it.
    """
    vectors = model.encode(texts, convert_to_numpy=True)
    flat_index = index_cls(vectors.shape[1])
    flat_index.add(vectors)
    return vectors, flat_index


# Disorder detection: inner-product index over the "Disorder" column.
# NOTE(review): inner product equals cosine similarity only for L2-normalized
# embeddings; confirm the model normalizes its output.
treatment_embeddings, index = _build_flat_index(
    similarity_model, recommendations_df["Disorder"].tolist(), faiss.IndexFlatIP
)
# Question retrieval: L2-distance index over the "Questions" column.
question_embeddings, question_index = _build_flat_index(
    embedding_model, questions_df["Questions"].tolist(), faiss.IndexFlatL2
)
from typing import List


class ChatRequest(BaseModel):
    """Request body for /get_questions: a single user chat message."""

    message: str


class SummaryRequest(BaseModel):
    """Request body carrying the whole chat session as a list of messages."""

    # List[str] (rather than a bare list) makes pydantic validate each entry.
    # Malformed items are then rejected as a client error, instead of crashing
    # " ".join(...) in the handlers with a 500.
    chat_history: List[str]
def retrieve_questions(user_input):
    """Return the single diagnostic question most relevant to *user_input*.

    The input is embedded with ``embedding_model`` and matched against
    ``question_index`` (top-1 nearest neighbour). A dataset row may hold
    several questions joined by ", "; the first non-blank one is returned.

    Args:
        user_input: The user's chat message.

    Returns:
        The selected question as a string, or an apology string when the
        index returns no hit.
    """
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)  # top-1 only
    best_row = int(indices[0][0])
    # FAISS reports -1 when it has no result (e.g. an empty index).
    if best_row == -1:
        return "I'm sorry, I couldn't find a relevant question."

    question_block = questions_df["Questions"].iloc[best_row]
    # NOTE(review): splitting on ", " assumes commas never occur inside a single
    # question; confirm against symptom_questions.csv.
    fragments = (part.strip() for part in question_block.split(", "))
    # str.split never returns an empty list, but its first fragment can be
    # blank; fall back to the whole block only if every fragment is blank.
    return next((q for q in fragments if q), question_block.strip())
def generate_empathetic_response(user_input, retrieved_question):
    """Rephrase a diagnostic question empathetically via the Groq API.

    Args:
        user_input: The user's latest chat message.
        retrieved_question: The diagnostic question chosen by
            ``retrieve_questions``.

    Returns:
        The model's single empathetic response, whitespace-stripped. Returns a
        fixed apology string if the API call fails or yields no content.
    """
    fallback = "I'm sorry, I couldn't process your request."
    # Prompt asks for exactly one rephrased question, prefixed by an
    # acknowledgement of the user's emotions.
    prompt = f"""
The user said: "{user_input}"
Relevant Question:
- {retrieved_question}
You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
Acknowledge the user's emotions before asking the question.
Example format:
- "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"
Generate only one empathetic response.
"""
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",  # Groq-hosted LLaMA-3 model
            temperature=0.8,
            top_p=0.9,
        )
        content = chat_completion.choices[0].message.content
    except Exception:
        # Boundary with an external service: keep the full traceback in the
        # logs, but degrade gracefully for the caller.
        logging.exception("Groq API error")
        return fallback
    # Guard against a None/empty content field so the endpoint never returns null.
    return content.strip() if content else fallback
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Return one empathetic follow-up question for the user's message.

    Hybrid RAG: the closest diagnostic question is retrieved with FAISS, then
    the Groq LLM rephrases it with empathy.
    """
    user_message = request.message
    return {
        "question": generate_empathetic_response(
            user_message, retrieve_questions(user_message)
        )
    }
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize the full chat session with the seq2seq summarization model.

    Args:
        request: Body whose ``chat_history`` messages are joined with spaces.

    Returns:
        ``{"summary": <text>}``. The summary is ``""`` when the joined history
        is empty or whitespace-only.
    """
    chat_text = " ".join(request.chat_history).strip()
    if not chat_text:
        # Nothing to summarize: skip a costly beam search over a bare prompt.
        return {"summary": ""}
    # Inputs beyond 4096 tokens are truncated by the tokenizer.
    inputs = summarization_tokenizer(
        "summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True
    )
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Detect the disorders most similar to the full chat history.

    Args:
        request: Body whose ``chat_history`` messages are joined with spaces.

    Returns:
        ``{"disorders": [name, ...]}`` with up to three disorder names, in the
        order FAISS ranks them. Returns
        ``{"disorders": "No matching disorder found."}`` when there is no hit.
    """
    full_chat_text = " ".join(request.chat_history)
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)
    # FAISS pads missing results with -1 when the index holds fewer than k
    # vectors. iloc[-1] would silently return the LAST row, so drop them all
    # (not just the first slot).
    hits = [int(i) for i in indices[0] if i != -1]
    if not hits:
        return {"disorders": "No matching disorder found."}
    disorders = [recommendations_df["Disorder"].iloc[i] for i in hits]
    return {"disorders": disorders}
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Return treatment recommendations for the disorders detected in the chat.

    Args:
        request: Body forwarded unchanged to ``detect_disorders``.

    Returns:
        ``{"treatments": {disorder: recommendation, ...}}``, empty when no
        disorder was detected. When a disorder appears on several rows, the
        first row's recommendation is used.
    """
    detected_disorders = detect_disorders(request)["disorders"]
    # detect_disorders returns a message string (not a list) when nothing
    # matched; iterating it would walk the string's characters.
    if not isinstance(detected_disorders, list):
        return {"treatments": {}}

    disorder_column = recommendations_df["Disorder"]
    treatments = {}
    for disorder in detected_disorders:
        # One boolean-mask scan per disorder (instead of a membership test plus a mask).
        matches = recommendations_df.loc[disorder_column == disorder, "Treatment Recommendation"]
        if not matches.empty:
            treatments[disorder] = matches.iloc[0]
    return {"treatments": treatments}
|