"""FastAPI service: empathetic diagnostic questions, chat summaries, and
disorder / treatment lookup backed by FAISS indexes and Hugging Face models."""
import logging
import os

# Point every Hugging Face cache at a writable directory. This MUST happen
# before transformers / sentence_transformers (and thus huggingface_hub) are
# imported: huggingface_hub resolves its cache paths from the environment at
# import time, so setting these afterwards has no effect on its defaults.
HF_CACHE_DIR = "/tmp/huggingface"
os.environ["HF_HOME"] = HF_CACHE_DIR
# NOTE(review): TRANSFORMERS_CACHE is presumably only honoured by older
# transformers releases (HF_HOME supersedes it) — kept for compatibility.
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_CACHE_DIR

import faiss  # noqa: E402
import pandas as pd  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402,F401
from groq import Groq  # noqa: E402
from pydantic import BaseModel  # noqa: E402,F401
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402

app = FastAPI()

# Fail fast at startup when the Groq API key is not configured.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")
client = Groq(api_key=GROQ_API_KEY)

# Models, all cached under HF_CACHE_DIR.
# similarity_model: embeds disorder names / full chats for disorder detection.
similarity_model = SentenceTransformer(
    "sentence-transformers/all-mpnet-base-v2", cache_folder=HF_CACHE_DIR
)
# embedding_model: embeds diagnostic questions and user messages.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_CACHE_DIR)
# Seq2seq model + tokenizer used to summarise a whole chat session.
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
)
summarization_tokenizer = AutoTokenizer.from_pretrained(
    "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
)

# Debug aid: show the working directory's contents so a missing CSV is easy
# to diagnose from the startup output.
print("🔍 Available Files:", os.listdir("."))

try:
    # NOTE(review): the space before ".csv" looks accidental, but it must match
    # the file on disk — confirm before renaming either side.
    recommendations_df = pd.read_csv("treatment_recommendations .csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error("❌ Missing dataset file: %s", e)
    # This runs at import time, outside any request, so an HTTPException could
    # never become an HTTP response; abort startup with the original error.
    raise

# Inner-product index over disorder-name embeddings. The vectors are
# L2-normalised so that inner product equals cosine similarity.
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
# FAISS (L2-distance) index over the rows of questions_df["Questions"].
question_embeddings = embedding_model.encode(
    questions_df["Questions"].tolist(), convert_to_numpy=True
)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

from typing import List  # noqa: E402  (typing.List keeps Python < 3.9 working)

# Number of nearest disorders returned by /detect_disorders.
DISORDER_TOP_K = 3


class ChatRequest(BaseModel):
    """Request body for /get_questions: the user's latest message."""

    message: str


class SummaryRequest(BaseModel):
    """Request body for the end-of-session endpoints."""

    # One entry per chat message; entries are joined with spaces before use.
    # Typed as strings so malformed items are rejected with 422 instead of
    # crashing " ".join with a 500.
    chat_history: List[str]


def _join_chat_history(chat_history):
    """Join a transcript into one string, rejecting an empty one with HTTP 400.

    Without this guard an empty transcript would still be summarised and
    matched against disorders, yielding results with no supporting evidence.
    """
    text = " ".join(chat_history)
    if not text.strip():
        raise HTTPException(
            status_code=400,
            detail="chat_history must contain at least one non-empty message.",
        )
    return text


def retrieve_questions(user_input):
    """Return the single diagnostic question closest to *user_input*.

    Embeds the input with embedding_model, looks up the nearest row in
    question_index, and returns the first ", "-separated entry of that row
    of questions_df["Questions"]. Returns a fallback apology string when the
    index yields no result.
    """
    fallback = "I'm sorry, I couldn't find a relevant question."
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    row = int(indices[0][0])
    if row == -1:  # FAISS reports -1 when there is no neighbour (empty index)
        return fallback

    question_block = questions_df["Questions"].iloc[row]
    # NOTE(review): assumes one CSV cell holds several questions separated by
    # ", " — a question that itself contains ", " is cut short; confirm format.
    best_question = question_block.split(", ")[0].strip()
    # If the cell starts with a separator, fall back to the whole cell.
    return best_question or question_block


def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq-hosted LLaMA model to rephrase *retrieved_question* empathetically.

    Returns the model's reply text, or an apology string if the API call
    fails or returns no content (errors are logged, not raised).
    """
    fallback = "I'm sorry, I couldn't process your request."
    prompt = f"""
    The user said: "{user_input}"

    Relevant Question:
    - {retrieved_question}

    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    Acknowledge the user's emotions before asking the question.

    Example format:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

    Generate only one empathetic response.
    """
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
    except Exception as e:  # API boundary: degrade to a fallback reply
        logging.exception("Groq API error: %s", e)
        return fallback
    content = chat_completion.choices[0].message.content
    return content or fallback


@app.post("/get_questions")
def get_recommended_questions(request: ChatRequest):
    """Retrieve the most relevant diagnostic question and rephrase it empathetically."""
    retrieved_question = retrieve_questions(request.message)
    empathetic_response = generate_empathetic_response(request.message, retrieved_question)
    return {"question": empathetic_response}


@app.post("/summarize_chat")
def summarize_chat(request: SummaryRequest):
    """Summarise the full chat session with the seq2seq summarisation model.

    Raises HTTP 400 for an empty transcript. Input beyond 4096 tokens is
    truncated; the summary is capped at 500 tokens.
    """
    chat_text = _join_chat_history(request.chat_history)
    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # explicit mask; avoids HF warning
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}


@app.post("/detect_disorders")
def detect_disorders(request: SummaryRequest):
    """Return the disorders whose names are most similar to the full chat.

    Response is {"disorders": [name, ...]} with up to DISORDER_TOP_K names, or
    {"disorders": "No matching disorder found."} when the index returns no
    hits. Raises HTTP 400 for an empty transcript.
    """
    full_chat_text = _join_chat_history(request.chat_history)
    text_embedding = similarity_model.encode(
        [full_chat_text], convert_to_numpy=True, normalize_embeddings=True
    )
    _, indices = index.search(text_embedding, DISORDER_TOP_K)
    # FAISS pads missing results with -1 when the index holds fewer than k
    # vectors; iloc[-1] would silently return the last row, so drop them.
    hits = [int(i) for i in indices[0] if i != -1]
    if not hits:
        return {"disorders": "No matching disorder found."}
    disorders = [recommendations_df["Disorder"].iloc[i] for i in hits]
    return {"disorders": disorders}
# ✅ API Endpoint: Get Treatment Recommendations
@app.post("/get_treatment")
def get_treatment(request: SummaryRequest):
    """Map each disorder detected in the chat to its treatment recommendation.

    Reuses detect_disorders(). When that reports no match (a message string
    rather than a list), returns an empty mapping instead of iterating the
    string's characters. For each disorder, the first matching row's
    "Treatment Recommendation" is used; an empty (NaN) cell becomes None so
    the response stays JSON-serialisable.
    """
    detected_disorders = detect_disorders(request)["disorders"]
    if isinstance(detected_disorders, str):
        detected_disorders = []

    treatments = {}
    for disorder in detected_disorders:
        matches = recommendations_df.loc[
            recommendations_df["Disorder"] == disorder, "Treatment Recommendation"
        ]
        if matches.empty:
            continue
        recommendation = matches.iloc[0]
        treatments[disorder] = None if pd.isna(recommendation) else recommendation
    return {"treatments": treatments}