mindspark121 commited on
Commit
f6e0c20
Β·
verified Β·
1 Parent(s): 879f608

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -19
app.py CHANGED
@@ -1,17 +1,22 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import pandas as pd
6
  import os
 
7
  from groq import Groq
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
 
10
  app = FastAPI()
11
 
12
- # βœ… Set up Groq API key
13
- os.environ["GROQ_API_KEY"] = "your-groq-api-key" # Replace with your actual API key
14
- client = Groq(api_key=os.environ["GROQ_API_KEY"])
 
 
 
15
 
16
  # βœ… Load AI Models
17
  similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
@@ -20,8 +25,12 @@ summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglo
20
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
21
 
22
  # βœ… Load datasets
23
- recommendations_df = pd.read_csv("treatment_recommendations.csv")
24
- questions_df = pd.read_csv("symptom_questions.csv")
 
 
 
 
25
 
26
  # βœ… FAISS Index for disorder detection
27
  treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
@@ -46,10 +55,13 @@ def retrieve_questions(user_input):
46
  input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
47
  _, indices = question_index.search(input_embedding, 1) # βœ… Retrieve only 1 question
48
 
 
 
 
49
  # βœ… Extract only the first meaningful question
50
  question_block = questions_df["Questions"].iloc[indices[0][0]]
51
  split_questions = question_block.split(", ")
52
- best_question = split_questions[0] # βœ… Select the first clear question
53
 
54
  return best_question # βœ… Return a single question as a string
55
 
@@ -73,17 +85,20 @@ def generate_empathetic_response(user_input, retrieved_question):
73
  Generate only one empathetic response.
74
  """
75
 
76
- chat_completion = client.chat.completions.create(
77
- messages=[
78
- {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
79
- {"role": "user", "content": prompt}
80
- ],
81
- model="llama3-8b", # βœ… Use Groq's LLaMA-3 Model
82
- temperature=0.8, # Adjust for natural variation
83
- top_p=0.9
84
- )
85
-
86
- return chat_completion.choices[0].message.content # βœ… Return only one response
 
 
 
87
 
88
  # βœ… API Endpoint: Get Empathetic Questions (Hybrid RAG)
89
  @app.post("/get_questions")
@@ -111,6 +126,10 @@ def detect_disorders(request: SummaryRequest):
111
  full_chat_text = " ".join(request.chat_history)
112
  text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
113
  distances, indices = index.search(text_embedding, 3)
 
 
 
 
114
  disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
115
  return {"disorders": disorders}
116
 
@@ -121,6 +140,6 @@ def get_treatment(request: SummaryRequest):
121
  detected_disorders = detect_disorders(request)["disorders"]
122
  treatments = {
123
  disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
124
- for disorder in detected_disorders
125
  }
126
  return {"treatments": treatments}
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import pandas as pd
6
  import os
7
+ import logging
8
  from groq import Groq
9
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
10
 
11
+ # βœ… Initialize FastAPI
12
  app = FastAPI()
13
 
14
# βœ… Securely fetch the Groq API key from the environment.
# SECURITY: os.getenv() takes the *name* of an environment variable, not the
# secret itself. The previous code passed the literal API key as the variable
# name — leaking the secret in source control AND guaranteeing a None result.
# NOTE(review): the key that was committed here is compromised and must be
# revoked/rotated in the Groq console.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is missing. Set it as an environment variable.")

client = Groq(api_key=GROQ_API_KEY)
20
 
21
  # βœ… Load AI Models
22
  similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
 
25
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
26
 
27
# βœ… Load datasets used for disorder retrieval and question lookup.
# This runs at module import time, so a missing file must abort startup.
# fastapi.HTTPException is only meaningful inside a request handler — raising
# it here would not produce an HTTP 500; a plain RuntimeError (chained to the
# original error) is the correct way to fail fast at startup.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    logging.error("Missing dataset file: %s", e)  # lazy %-args, not f-string
    raise RuntimeError("Dataset files are missing.") from e
34
 
35
  # βœ… FAISS Index for disorder detection
36
  treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
 
55
  input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
56
  _, indices = question_index.search(input_embedding, 1) # βœ… Retrieve only 1 question
57
 
58
+ if indices[0][0] == -1:
59
+ return "I'm sorry, I couldn't find a relevant question."
60
+
61
  # βœ… Extract only the first meaningful question
62
  question_block = questions_df["Questions"].iloc[indices[0][0]]
63
  split_questions = question_block.split(", ")
64
+ best_question = split_questions[0] if split_questions else question_block # βœ… Select the first clear question
65
 
66
  return best_question # βœ… Return a single question as a string
67
 
 
85
  Generate only one empathetic response.
86
  """
87
 
88
+ try:
89
+ chat_completion = client.chat.completions.create(
90
+ messages=[
91
+ {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
92
+ {"role": "user", "content": prompt}
93
+ ],
94
+ model="llama3-8b", # βœ… Use Groq's LLaMA-3 Model
95
+ temperature=0.8,
96
+ top_p=0.9
97
+ )
98
+ return chat_completion.choices[0].message.content # βœ… Return only one response
99
+ except Exception as e:
100
+ logging.error(f"Groq API error: {e}")
101
+ return "I'm sorry, I couldn't process your request."
102
 
103
  # βœ… API Endpoint: Get Empathetic Questions (Hybrid RAG)
104
  @app.post("/get_questions")
 
126
  full_chat_text = " ".join(request.chat_history)
127
  text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
128
  distances, indices = index.search(text_embedding, 3)
129
+
130
+ if indices[0][0] == -1:
131
+ return {"disorders": "No matching disorder found."}
132
+
133
  disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
134
  return {"disorders": disorders}
135
 
 
140
  detected_disorders = detect_disorders(request)["disorders"]
141
  treatments = {
142
  disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
143
+ for disorder in detected_disorders if disorder in recommendations_df["Disorder"].values
144
  }
145
  return {"treatments": treatments}