Update app.py
app.py
CHANGED
@@ -13,24 +13,13 @@ from langchain.chains import create_retrieval_chain
 import os
 import markdown2
 
-
-
-
-
-
-
-
-
-
-# Retrieve API keys from Hugging Face Spaces secrets
-openai_api_key = os.environ.get('OPENAI_API_KEY')
-groq_api_key = os.environ.get('GROQ_API_KEY')
-google_api_key = os.environ.get('GEMINI_API_KEY')
-
-# Initialize API clients with the API keys
-openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
-groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
-gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
+def create_api_clients(openai_key, groq_key, gemini_key):
+    """Initialize API clients with provided keys"""
+    return (
+        ChatOpenAI(model_name="gpt-4o", api_key=openai_key),
+        ChatGroq(model="llama-3.3-70b-versatile", temperature=0, api_key=groq_key),
+        ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=gemini_key)
+    )
 
 # Function to extract text from PDF
 def extract_pdf(pdf_path):
@@ -46,12 +35,12 @@ def split_text(text):
     return [Document(page_content=t) for t in splitter.split_text(text)]
 
 # Function to generate embeddings and store in vector database
-def generate_embeddings(docs):
-    embeddings = OpenAIEmbeddings(api_key=openai_api_key)
+def generate_embeddings(docs, openai_key):
+    embeddings = OpenAIEmbeddings(api_key=openai_key)
     return FAISS.from_documents(docs, embeddings)
 
 # Function for query preprocessing
-def preprocess_query(query):
+def preprocess_query(query, openai_client):
     prompt = ChatPromptTemplate.from_template("""
    Transform the following query into a more detailed, keyword-rich affirmative statement that could appear in official data protection regulation documents:
    Query: {query}
@@ -61,7 +50,7 @@ def preprocess_query(query):
     return chain.invoke({"query": query}).content
 
 # Function to create RAG chain with Groq
-def create_rag_chain(vector_store):
+def create_rag_chain(vector_store, groq_client):
     prompt = ChatPromptTemplate.from_messages([
         ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following passages of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
         ("human", "{input}")
@@ -70,7 +59,7 @@ def create_rag_chain(vector_store):
     return create_retrieval_chain(vector_store.as_retriever(), document_chain)
 
 # Function for Gemini response with long context
-def gemini_response(query, full_pdf_content):
+def gemini_response(query, full_pdf_content, gemini_client):
     prompt = ChatPromptTemplate.from_messages([
         ("system", "You are an AI assistant helping with data protection and regulation compliance related queries. Use the following full content of official regulation documents to provide practical advice on how to meet regulatory requirements in the context of the user question:\n\n{context}"),
         ("human", "{input}")
@@ -79,7 +68,7 @@ def gemini_response(query, full_pdf_content):
     return chain.invoke({"context": full_pdf_content, "input": query}).content
 
 # Function to generate final response
-def generate_final_response(query, response1, response2):
+def generate_final_response(query, response1, response2, openai_client):
     prompt = ChatPromptTemplate.from_template("""
    As an AI assistant specializing in data protection and compliance for educators:
    [hidden states, scratchpad]
@@ -89,28 +78,64 @@ def generate_final_response(query, response1, response2):
    [Output]
    4. Based on Steps 1, 2, and 3: Provide an explanation of the relevant regulatory requirements and provide practical advice on how to meet them in the context of the user question.
    Important: the final output should be a direct response to the query. Strip it of all references to steps 1, 2, 3.
-
    User Query: {query}
-
    Response 1: {response1}
-
    Response 2: {response2}
-
    Your synthesized response:
    """)
     chain = prompt | openai_client
     return chain.invoke({"query": query, "response1": response1, "response2": response2}).content
 
-
+class APIState:
+    def __init__(self):
+        self.openai_client = None
+        self.groq_client = None
+        self.gemini_client = None
+        self.vector_store = None
+        self.rag_chain = None
+        self.full_pdf_content = ""
+
+api_state = APIState()
+
+def initialize_system(openai_key, groq_key, gemini_key):
+    """Initialize the system with provided API keys"""
+    try:
+        # Initialize API clients
+        api_state.openai_client, api_state.groq_client, api_state.gemini_client = create_api_clients(
+            openai_key, groq_key, gemini_key
+        )
+
+        # Process PDFs
+        pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
+        all_documents = []
+
+        for pdf_path in pdf_paths:
+            extracted_text = extract_pdf(pdf_path)
+            api_state.full_pdf_content += extracted_text + "\n\n"
+            all_documents.extend(split_text(extracted_text))
+
+        # Generate embeddings and create RAG chain
+        api_state.vector_store = generate_embeddings(all_documents, openai_key)
+        api_state.rag_chain = create_rag_chain(api_state.vector_store, api_state.groq_client)
+
+        return "System initialized successfully!"
+    except Exception as e:
+        return f"Initialization failed: {str(e)}"
+
 def process_query(user_query):
+    """Process user query using initialized clients"""
     try:
-        preprocessed_query = preprocess_query(user_query)
+        if not all([api_state.openai_client, api_state.groq_client, api_state.gemini_client,
+                    api_state.vector_store, api_state.rag_chain]):
+            return "Please initialize the system with API keys first.", "", ""
+
+        preprocessed_query = preprocess_query(user_query, api_state.openai_client)
         print(f"Original query: {user_query}")
         print(f"Preprocessed query: {preprocessed_query}")
 
-        rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
-        gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
-        final_response = generate_final_response(user_query, rag_response, gemini_resp)
+        rag_response = api_state.rag_chain.invoke({"input": preprocessed_query})["answer"]
+        gemini_resp = gemini_response(preprocessed_query, api_state.full_pdf_content, api_state.gemini_client)
+        final_response = generate_final_response(user_query, rag_response, gemini_resp, api_state.openai_client)
         final_output = "## Final (GPT-4o) Response:\n\n" + final_response
         html_content = markdown2.markdown(final_output)
         return rag_response, gemini_resp, html_content
@@ -118,32 +143,37 @@ def process_query(user_query):
         error_message = f"An error occurred: {str(e)}"
         return error_message, error_message, error_message
 
-# Initialize
-pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
-full_pdf_content = ""
-all_documents = []
-
-for pdf_path in pdf_paths:
-    extracted_text = extract_pdf(pdf_path)
-    full_pdf_content += extracted_text + "\n\n"
-    all_documents.extend(split_text(extracted_text))
-
-vector_store = generate_embeddings(all_documents)
-rag_chain = create_rag_chain(vector_store)
-
 # Gradio interface
-
-
-
-
-
-    gr.Textbox(label="
-    gr.
-
-
-
-
+with gr.Blocks() as iface:
+    gr.Markdown("# Data Protection Team")
+    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).")
+
+    with gr.Row():
+        openai_key_input = gr.Textbox(label="OpenAI API Key", type="password")
+        groq_key_input = gr.Textbox(label="Groq API Key", type="password")
+        gemini_key_input = gr.Textbox(label="Gemini API Key", type="password")
+
+    init_button = gr.Button("Initialize System")
+    init_output = gr.Textbox(label="Initialization Status")
+
+    query_input = gr.Textbox(label="Ask your data protection related question")
+    submit_button = gr.Button("Submit Query")
+
+    rag_output = gr.Textbox(label="RAG Pipeline (Llama 3.3) Response")
+    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
+    final_output = gr.HTML(label="Final (GPT-4o) Response")
+
+    init_button.click(
+        initialize_system,
+        inputs=[openai_key_input, groq_key_input, gemini_key_input],
+        outputs=init_output
+    )
+
+    submit_button.click(
+        process_query,
+        inputs=query_input,
+        outputs=[rag_output, gemini_output, final_output]
+    )
 
 # Launch the interface
-iface.launch(
+iface.launch()
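
For reference, below is a minimal sketch of the two-step flow the new interface wires up. It is hypothetical test scaffolding, not part of the commit: it assumes initialize_system and process_query are in scope (e.g. app.py with the iface.launch() line temporarily disabled, so importing the module does not start the server), that the three regulation PDFs are present, and that the key strings are placeholders for real credentials.

# Step 1: build clients, embeddings, and the RAG chain from user-supplied keys,
# exactly as the "Initialize System" button does.
status = initialize_system(
    openai_key="sk-...",   # placeholder, not a real key
    groq_key="gsk_...",
    gemini_key="AIza...",
)
print(status)  # "System initialized successfully!" or "Initialization failed: ..."

# Step 2: ask a question, as the "Submit Query" button does.
rag, gemini, final_html = process_query(
    "What does FERPA require before sharing student records with parents?"
)
print(final_html)  # markdown2-rendered HTML of the synthesized GPT-4o answer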