import os

# ✅ Set a writable cache directory before the model libraries are imported,
# since they read these environment variables at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

import logging
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from groq import Groq
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# ✅ Initialize FastAPI
app = FastAPI()

# ✅ Securely fetch the API key from the environment
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")

client = Groq(api_key=GROQ_API_KEY)


# ✅ Load AI models (cached under /tmp/huggingface)
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
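# Roles of the models above: the mpnet model embeds chat text for disorder
# matching, the MiniLM model embeds user messages for question retrieval, and
# LongT5 produces the end-of-session summary. All three are fetched into
# /tmp/huggingface on first startup.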

# ✅ Check which files are present before loading the datasets
print("🔍 Available files:", os.listdir("."))

# ✅ Load datasets with error handling
try:
    recommendations_df = pd.read_csv("treatment_recommendations .csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error(f"❌ Missing dataset file: {e}")
    # This runs at import time, outside any request, so an HTTPException would
    # never reach a client; fail fast with a regular exception instead.
    raise RuntimeError(f"Dataset file not found: {e}")

# ✅ FAISS index for disorder detection.
# IndexFlatIP scores by raw inner product, so L2-normalize the embeddings
# first to make the scores equivalent to cosine similarity.
treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
faiss.normalize_L2(treatment_embeddings)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# ✅ FAISS index for question retrieval
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
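# A minimal sketch of how these indices are queried later in this file
# (the 384-dim shape assumes all-MiniLM-L6-v2's output size; illustrative only):
#   query = embedding_model.encode(["I feel anxious"], convert_to_numpy=True)  # shape (1, 384)
#   distances, indices = question_index.search(query, 1)                       # each shaped (1, 1)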

# ✅ Request models
class ChatRequest(BaseModel):
    message: str

class SummaryRequest(BaseModel):
    chat_history: List[str]  # full list of chat messages
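# Example request bodies the endpoints below accept (illustrative values):
#   ChatRequest:    {"message": "I can't sleep and I feel on edge."}
#   SummaryRequest: {"chat_history": ["I've been feeling low.", "How long has this lasted?"]}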

# ✅ Retrieve the most relevant question
def retrieve_questions(user_input: str) -> str:
    """Retrieve the single most relevant diagnostic question using FAISS."""
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)  # ✅ retrieve only one question

    if indices[0][0] == -1:
        return "I'm sorry, I couldn't find a relevant question."

    # ✅ The stored entry may contain several comma-separated questions;
    # keep only the first one and return it as a single string.
    question_block = questions_df["Questions"].iloc[indices[0][0]]
    return question_block.split(", ")[0]

# ✅ Groq API for rephrasing
def generate_empathetic_response(user_input, retrieved_question):
    """Use the Groq API (LLaMA 3) to generate one empathetic response."""

    # ✅ Prompt constrained to produce exactly one question
    prompt = f"""
    The user said: "{user_input}"

    Relevant Question:
    - {retrieved_question}

    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    Acknowledge the user's emotions before asking the question.

    Example format:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

    Generate only one empathetic response.
    """

    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt}
            ],
            model="llama-3.3-70b-versatile",  # ✅ Groq's LLaMA 3 model
            temperature=0.8,
            top_p=0.9
        )
        return chat_completion.choices[0].message.content  # ✅ return only one response
    except Exception as e:
        logging.error(f"Groq API error: {e}")
        return "I'm sorry, I couldn't process your request."

# ✅ API endpoint: get empathetic questions (hybrid RAG)
@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the most relevant diagnostic question and rephrase it empathetically via the Groq API."""
    retrieved_question = retrieve_questions(request.message)
    empathetic_response = generate_empathetic_response(request.message, retrieved_question)

    return {"question": empathetic_response}

# ✅ API endpoint: summarize chat
@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarize the full chat session at the end."""
    chat_text = " ".join(request.chat_history)
    inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
    summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}

# ✅ API endpoint: detect disorders
@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Detect psychiatric disorders from the full chat history at the end."""
    full_chat_text = " ".join(request.chat_history)
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    faiss.normalize_L2(text_embedding)  # match the normalized index vectors
    distances, indices = index.search(text_embedding, 3)

    if indices[0][0] == -1:
        # Return a list so downstream callers such as /get_treatment can iterate safely.
        return {"disorders": []}

    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
    return {"disorders": disorders}

# ✅ API endpoint: get treatment recommendations
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Retrieve treatment recommendations based on the detected disorders."""
    detected_disorders = detect_disorders(request)["disorders"]
    treatments = {
        disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
        for disorder in detected_disorders if disorder in recommendations_df["Disorder"].values
    }
    return {"treatments": treatments}