jeremierostan committed on
Commit
59b7b01
·
verified ·
1 Parent(s): d54b9d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -59
app.py CHANGED
@@ -13,24 +13,13 @@ from langchain.chains import create_retrieval_chain
13
  import os
14
  import markdown2
15
 
16
- # Retrieve username and password from environment variables
17
- username = os.environ.get("USERNAME")
18
- password = os.environ.get("PASSWORD")
19
-
20
- # Ensure both username and password are set
21
- if not username or not password:
22
- raise ValueError("Both USERNAME and PASSWORD must be set in the environment variables.")
23
-
24
-
25
- # Retrieve API keys from Hugging Face Spaces secrets
26
- openai_api_key = os.environ.get('OPENAI_API_KEY')
27
- groq_api_key = os.environ.get('GROQ_API_KEY')
28
- google_api_key = os.environ.get('GEMINI_API_KEY')
29
-
30
- # Initialize API clients with the API keys
31
- openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
32
- groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
33
- gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
34
 
35
  # Function to extract text from PDF
36
  def extract_pdf(pdf_path):
@@ -46,12 +35,12 @@ def split_text(text):
46
  return [Document(page_content=t) for t in splitter.split_text(text)]
47
 
48
  # Function to generate embeddings and store in vector database
49
- def generate_embeddings(docs):
50
- embeddings = OpenAIEmbeddings(api_key=openai_api_key)
51
  return FAISS.from_documents(docs, embeddings)
52
 
53
  # Function for query preprocessing
54
- def preprocess_query(query):
55
  prompt = ChatPromptTemplate.from_template("""
56
  Transform the following query into a more detailed, keyword-rich affitmative statement that could appear in official data protection regulation documents:
57
  Query: {query}
@@ -61,7 +50,7 @@ def preprocess_query(query):
61
  return chain.invoke({"query": query}).content
62
 
63
  # Function to create RAG chain with Groq
64
- def create_rag_chain(vector_store):
65
  prompt = ChatPromptTemplate.from_messages([
66
  ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following passages of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
67
  ("human", "{input}")
@@ -70,7 +59,7 @@ def create_rag_chain(vector_store):
70
  return create_retrieval_chain(vector_store.as_retriever(), document_chain)
71
 
72
  # Function for Gemini response with long context
73
- def gemini_response(query, full_pdf_content):
74
  prompt = ChatPromptTemplate.from_messages([
75
  ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
76
  ("human", "{input}")
@@ -79,7 +68,7 @@ def gemini_response(query, full_pdf_content):
79
  return chain.invoke({"context": full_pdf_content, "input": query}).content
80
 
81
  # Function to generate final response
82
- def generate_final_response(query, response1, response2):
83
  prompt = ChatPromptTemplate.from_template("""
84
  As an AI assistant specializing in data protection and compliance for educators:
85
  [hidden states, scrartchpad]
@@ -89,28 +78,64 @@ def generate_final_response(query, response1, response2):
89
  [Output]
90
  4. Based on Steps 1, 2, and 3: Provide an explanation of the relevant regulatory requirements and provide practical advice on how to meet them in the context of the user question.
91
  Important: the final output should be a direct response to the query. Strip it of all reference to steps 1, 2, 3.
92
-
93
  User Query: {query}
94
-
95
  Response 1: {response1}
96
-
97
  Response 2: {response2}
98
-
99
  Your synthesized response:
100
  """)
101
  chain = prompt | openai_client
102
  return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
103
 
104
- # Function to process the query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def process_query(user_query):
 
106
  try:
107
- preprocessed_query = preprocess_query(user_query)
 
 
 
 
108
  print(f"Original query: {user_query}")
109
  print(f"Preprocessed query: {preprocessed_query}")
110
 
111
- rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
112
- gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
113
- final_response = generate_final_response(user_query, rag_response, gemini_resp)
114
  final_output = "## Final (GPT-4o) Response:\n\n" + final_response
115
  html_content = markdown2.markdown(final_output)
116
  return rag_response, gemini_resp, html_content
@@ -118,32 +143,37 @@ def process_query(user_query):
118
  error_message = f"An error occurred: {str(e)}"
119
  return error_message, error_message, error_message
120
 
121
- # Initialize
122
- pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
123
- full_pdf_content = ""
124
- all_documents = []
125
-
126
- for pdf_path in pdf_paths:
127
- extracted_text = extract_pdf(pdf_path)
128
- full_pdf_content += extracted_text + "\n\n"
129
- all_documents.extend(split_text(extracted_text))
130
-
131
- vector_store = generate_embeddings(all_documents)
132
- rag_chain = create_rag_chain(vector_store)
133
-
134
  # Gradio interface
135
- iface = gr.Interface(
136
- fn=process_query,
137
- inputs=gr.Textbox(label="Ask your data protection related question"),
138
- outputs=[
139
- gr.Textbox(label="RAG Pipeline (Llama3.1) Response"),
140
- gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
141
- gr.HTML(label="Final (GPT-4) Response")
142
- ],
143
- title="Data Protection Team",
144
- description="Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).",
145
- allow_flagging="never"
146
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  # Launch the interface
149
- iface.launch(auth=(username, password))
 
13
  import os
14
  import markdown2
15
 
16
def create_api_clients(openai_key, groq_key, gemini_key):
    """Build the three chat-model clients from user-supplied API keys.

    Returns a 3-tuple in this order: (OpenAI GPT-4o client,
    Groq Llama client, Gemini 1.5 Pro client).
    """
    openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_key)
    groq_client = ChatGroq(
        model="llama-3.3-70b-versatile", temperature=0, api_key=groq_key
    )
    gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=gemini_key)
    return openai_client, groq_client, gemini_client
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Function to extract text from PDF
25
  def extract_pdf(pdf_path):
 
35
  return [Document(page_content=t) for t in splitter.split_text(text)]
36
 
37
  # Function to generate embeddings and store in vector database
38
def generate_embeddings(docs, openai_key):
    """Embed *docs* with OpenAI embeddings and index them in a FAISS store."""
    embedder = OpenAIEmbeddings(api_key=openai_key)
    return FAISS.from_documents(docs, embedder)
41
 
42
  # Function for query preprocessing
43
+ def preprocess_query(query, openai_client):
44
  prompt = ChatPromptTemplate.from_template("""
45
  Transform the following query into a more detailed, keyword-rich affitmative statement that could appear in official data protection regulation documents:
46
  Query: {query}
 
50
  return chain.invoke({"query": query}).content
51
 
52
  # Function to create RAG chain with Groq
53
+ def create_rag_chain(vector_store, groq_client):
54
  prompt = ChatPromptTemplate.from_messages([
55
  ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following passages of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
56
  ("human", "{input}")
 
59
  return create_retrieval_chain(vector_store.as_retriever(), document_chain)
60
 
61
  # Function for Gemini response with long context
62
+ def gemini_response(query, full_pdf_content, gemini_client):
63
  prompt = ChatPromptTemplate.from_messages([
64
  ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
65
  ("human", "{input}")
 
68
  return chain.invoke({"context": full_pdf_content, "input": query}).content
69
 
70
  # Function to generate final response
71
+ def generate_final_response(query, response1, response2, openai_client):
72
  prompt = ChatPromptTemplate.from_template("""
73
  As an AI assistant specializing in data protection and compliance for educators:
74
  [hidden states, scrartchpad]
 
78
  [Output]
79
  4. Based on Steps 1, 2, and 3: Provide an explanation of the relevant regulatory requirements and provide practical advice on how to meet them in the context of the user question.
80
  Important: the final output should be a direct response to the query. Strip it of all reference to steps 1, 2, 3.
 
81
  User Query: {query}
 
82
  Response 1: {response1}
 
83
  Response 2: {response2}
 
84
  Your synthesized response:
85
  """)
86
  chain = prompt | openai_client
87
  return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
88
 
89
class APIState:
    """Mutable container for clients and RAG artifacts shared by the Gradio callbacks.

    All fields start empty and are populated by ``initialize_system`` once the
    user supplies API keys.
    """

    def __init__(self):
        # Chat-model clients (set by initialize_system)
        self.openai_client = None
        self.groq_client = None
        self.gemini_client = None
        # Retrieval artifacts built from the regulation PDFs
        self.vector_store = None
        self.rag_chain = None
        # Concatenated raw text of all PDFs, used for the long-context path
        self.full_pdf_content = ""


# Module-level singleton shared by the UI callbacks.
api_state = APIState()
99
+
100
def initialize_system(openai_key, groq_key, gemini_key):
    """Initialize the system with provided API keys.

    Builds the three model clients, extracts the regulation PDFs, and creates
    the FAISS vector store plus RAG chain, storing everything on ``api_state``.
    Returns a human-readable status string for the Gradio UI (never raises).
    """
    try:
        # Initialize API clients
        api_state.openai_client, api_state.groq_client, api_state.gemini_client = (
            create_api_clients(openai_key, groq_key, gemini_key)
        )

        # Process PDFs. Accumulate content locally and assign once: the old
        # `+=` onto api_state.full_pdf_content duplicated all PDF text every
        # time the user clicked "Initialize System" again.
        pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
        all_documents = []
        content_parts = []

        for pdf_path in pdf_paths:
            extracted_text = extract_pdf(pdf_path)
            content_parts.append(extracted_text + "\n\n")
            all_documents.extend(split_text(extracted_text))

        api_state.full_pdf_content = "".join(content_parts)

        # Generate embeddings and create RAG chain
        api_state.vector_store = generate_embeddings(all_documents, openai_key)
        api_state.rag_chain = create_rag_chain(api_state.vector_store, api_state.groq_client)

        return "System initialized successfully!"
    except Exception as e:
        # Surface the failure to the UI status box instead of crashing the app.
        return f"Initialization failed: {str(e)}"
124
+
125
def process_query(user_query):
    """Process user query using initialized clients"""
    try:
        # Refuse to run until initialize_system has populated every field.
        ready = all([
            api_state.openai_client,
            api_state.groq_client,
            api_state.gemini_client,
            api_state.vector_store,
            api_state.rag_chain,
        ])
        if not ready:
            return "Please initialize the system with API keys first.", "", ""

        preprocessed_query = preprocess_query(user_query, api_state.openai_client)
        print(f"Original query: {user_query}")
        print(f"Preprocessed query: {preprocessed_query}")

        # Run both retrieval strategies, then synthesize with GPT-4o.
        rag_response = api_state.rag_chain.invoke({"input": preprocessed_query})["answer"]
        gemini_resp = gemini_response(
            preprocessed_query, api_state.full_pdf_content, api_state.gemini_client
        )
        final_response = generate_final_response(
            user_query, rag_response, gemini_resp, api_state.openai_client
        )
        final_output = "## Final (GPT-4o) Response:\n\n" + final_response
        html_content = markdown2.markdown(final_output)
        return rag_response, gemini_resp, html_content
    except Exception as e:  # NOTE(review): this line is elided in the diff hunk; reconstructed from the `{str(e)}` usage below — confirm against the full file
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message, error_message
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Gradio interface
147
# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Protection Team")
    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).")

    # API-key entry row; keys are masked in the browser.
    with gr.Row():
        openai_key_input = gr.Textbox(label="OpenAI API Key", type="password")
        groq_key_input = gr.Textbox(label="Groq API Key", type="password")
        gemini_key_input = gr.Textbox(label="Gemini API Key", type="password")

    init_button = gr.Button("Initialize System")
    init_output = gr.Textbox(label="Initialization Status")

    query_input = gr.Textbox(label="Ask your data protection related question")
    submit_button = gr.Button("Submit Query")

    # One output widget per pipeline stage.
    rag_output = gr.Textbox(label="RAG Pipeline (Llama3.1) Response")
    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
    final_output = gr.HTML(label="Final (GPT-4) Response")

    # Wire the buttons to the backend callbacks.
    init_button.click(
        initialize_system,
        inputs=[openai_key_input, groq_key_input, gemini_key_input],
        outputs=init_output,
    )
    submit_button.click(
        process_query,
        inputs=query_input,
        outputs=[rag_output, gemini_output, final_output],
    )

# Launch the interface
iface.launch()