mindspark121 commited on
Commit
93407fc
Β·
verified Β·
1 Parent(s): 9162ece

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import pandas as pd
6
+ import os
7
+ from groq import Groq
8
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
+
10
+ app = FastAPI()
11
+
12
+ # βœ… Set up Groq API key
13
+ os.environ["GROQ_API_KEY"] = "your-groq-api-key" # Replace with your actual API key
14
+ client = Groq(api_key=os.environ["GROQ_API_KEY"])
15
+
16
+ # βœ… Load AI Models
17
+ similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
18
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
19
+ summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
20
+ summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
21
+
22
+ # βœ… Load datasets
23
+ recommendations_df = pd.read_csv("treatment_recommendations.csv")
24
+ questions_df = pd.read_csv("symptom_questions.csv")
25
+
26
+ # βœ… FAISS Index for disorder detection
27
+ treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
28
+ index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
29
+ index.add(treatment_embeddings)
30
+
31
+ # βœ… FAISS Index for Question Retrieval
32
+ question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
33
+ question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
34
+ question_index.add(question_embeddings)
35
+
36
+ # βœ… Request Model
37
+ class ChatRequest(BaseModel):
38
+ message: str
39
+
40
+ class SummaryRequest(BaseModel):
41
+ chat_history: list # List of messages
42
+
43
+ # βœ… Retrieve the most relevant question
44
+ def retrieve_questions(user_input):
45
+ """Retrieve the most relevant individual diagnostic question using FAISS."""
46
+ input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
47
+ _, indices = question_index.search(input_embedding, 1) # βœ… Retrieve only 1 question
48
+
49
+ # βœ… Extract only the first meaningful question
50
+ question_block = questions_df["Questions"].iloc[indices[0][0]]
51
+ split_questions = question_block.split(", ")
52
+ best_question = split_questions[0] # βœ… Select the first clear question
53
+
54
+ return best_question # βœ… Return a single question as a string
55
+
56
+ # βœ… Groq API for rephrasing
57
+ def generate_empathetic_response(user_input, retrieved_question):
58
+ """Use Groq API (LLaMA-3) to generate one empathetic response."""
59
+
60
+ # βœ… Improved Prompt: Only One Question
61
+ prompt = f"""
62
+ The user said: "{user_input}"
63
+
64
+ Relevant Question:
65
+ - {retrieved_question}
66
+
67
+ You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
68
+ Acknowledge the user's emotions before asking the question.
69
+
70
+ Example format:
71
+ - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"
72
+
73
+ Generate only one empathetic response.
74
+ """
75
+
76
+ chat_completion = client.chat.completions.create(
77
+ messages=[
78
+ {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
79
+ {"role": "user", "content": prompt}
80
+ ],
81
+ model="llama3-8b", # βœ… Use Groq's LLaMA-3 Model
82
+ temperature=0.8, # Adjust for natural variation
83
+ top_p=0.9
84
+ )
85
+
86
+ return chat_completion.choices[0].message.content # βœ… Return only one response
87
+
88
+ # βœ… API Endpoint: Get Empathetic Questions (Hybrid RAG)
89
+ @app.post("/get_questions")
90
+ def get_recommended_questions(request: ChatRequest):
91
+ """Retrieve the most relevant diagnostic question and make it more empathetic using Groq API."""
92
+ retrieved_question = retrieve_questions(request.message)
93
+ empathetic_response = generate_empathetic_response(request.message, retrieved_question)
94
+
95
+ return {"question": empathetic_response}
96
+
97
+ # βœ… API Endpoint: Summarize Chat
98
+ @app.post("/summarize_chat")
99
+ def summarize_chat(request: SummaryRequest):
100
+ """Summarize full chat session at the end."""
101
+ chat_text = " ".join(request.chat_history)
102
+ inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
103
+ summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
104
+ summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
105
+ return {"summary": summary}
106
+
107
+ # βœ… API Endpoint: Detect Disorders
108
+ @app.post("/detect_disorders")
109
+ def detect_disorders(request: SummaryRequest):
110
+ """Detect psychiatric disorders from full chat history at the end."""
111
+ full_chat_text = " ".join(request.chat_history)
112
+ text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
113
+ distances, indices = index.search(text_embedding, 3)
114
+ disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
115
+ return {"disorders": disorders}
116
+
117
+ # βœ… API Endpoint: Get Treatment Recommendations
118
+ @app.post("/get_treatment")
119
+ def get_treatment(request: SummaryRequest):
120
+ """Retrieve treatment recommendations based on detected disorders."""
121
+ detected_disorders = detect_disorders(request)["disorders"]
122
+ treatments = {
123
+ disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
124
+ for disorder in detected_disorders
125
+ }
126
+ return {"treatments": treatments}