VanguardAI committed
Commit 8f7c5f5 · verified · 1 Parent(s): ac48055

Update app.py

Files changed (1): app.py (+196 −60)
app.py CHANGED
@@ -8,16 +8,20 @@ from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusion3Pipeline
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
-from llama_index.core.agent import ReActAgent
-from llama_index.core.tools import FunctionTool
-from llama_index.llms.groq import Groq
+from langchain.agents import AgentExecutor, create_react_agent
+from langchain.tools import BaseTool
+from langchain_groq import ChatGroq
 from PIL import Image
 from tavily import TavilyClient
 import requests
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
-from llama_index.core.chat_engine.types import AgentChatResponse
-from llama_index.core import VectorStoreIndex
+from langchain.schema import AIMessage
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import TextLoader
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.chains import RetrievalQA
 
 # Initialize models and clients
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'
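
Note: the new `langchain.embeddings`, `langchain.vectorstores`, `langchain.document_loaders`, and `langchain.text_splitter` paths are the legacy locations; LangChain 0.1+ re-homes these classes in `langchain_community` and `langchain_text_splitters` (and `AIMessage` appears unused in the visible diff). If the Space pins a recent LangChain, the equivalent imports would look like this sketch (a suggestion, not part of this commit):

    # Same classes, newer package paths (LangChain >= 0.1); requires the
    # langchain-community and langchain-text-splitters packages.
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_community.document_loaders import TextLoader
    from langchain_text_splitters import CharacterTextSplitter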
@@ -48,38 +52,71 @@ def play_voice_output(response):
     return "output.wav"
 
 # NumPy Code Calculator Tool
-def numpy_code_calculator(query):
-    try:
-        # Assume query is a request for a numpy computation
-        local_dict = {"np": np}
-        exec(query, local_dict)
-        result = local_dict.get("result", "No result found")
-        return str(result)
-    except Exception as e:
-        return f"Error: {e}"
+class NumpyCodeCalculator(BaseTool):
+    name = "Numpy"
+    description = "Useful for performing numpy computations"
+
+    def _run(self, query: str) -> str:
+        try:
+            local_dict = {"np": np}
+            exec(query, local_dict)
+            result = local_dict.get("result", "No result found")
+            return str(result)
+        except Exception as e:
+            return f"Error: {e}"
 
 # Web Search Tool
-def web_search(query):
-    answer = tavily_client.qna_search(query=query)
-    return answer
+class WebSearch(BaseTool):
+    name = "Web"
+    description = "Useful for searching the web for information"
+
+    def _run(self, query: str) -> str:
+        answer = tavily_client.qna_search(query=query)
+        return answer
 
 # Image Generation Tool
-def image_generation(query):
-    image = pipe(
-        query,
-        negative_prompt="",
-        num_inference_steps=15,
-        guidance_scale=7.0,
-    ).images[0]
-    image.save("output.jpg")
-    return "output.jpg"
+class ImageGeneration(BaseTool):
+    name = "Image"
+    description = "Useful for generating images based on text descriptions"
+
+    def _run(self, query: str) -> str:
+        image = pipe(
+            query,
+            negative_prompt="",
+            num_inference_steps=15,
+            guidance_scale=7.0,
+        ).images[0]
+        image.save("output.jpg")
+        return "output.jpg"
 
 # Document Question Answering Tool
-def document_question_answering(query, docs):
-    index = VectorStoreIndex.from_documents(docs)
-    query_engine = index.as_query_engine(similarity_top_k=3)
-    response = query_engine.query(query)
-    return str(response)
+class DocumentQuestionAnswering(BaseTool):
+    name = "Document"
+    description = "Useful for answering questions about a specific document"
+
+    def __init__(self, document):
+        super().__init__()
+        self.document = document
+        self.qa_chain = self._setup_qa_chain()
+
+    def _setup_qa_chain(self):
+        loader = TextLoader(self.document)
+        documents = loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        texts = text_splitter.split_documents(documents)
+        embeddings = HuggingFaceEmbeddings()
+        db = FAISS.from_documents(texts, embeddings)
+        retriever = db.as_retriever()
+        qa_chain = RetrievalQA.from_chain_type(
+            llm=ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY")),
+            chain_type="stuff",
+            retriever=retriever,
+        )
+        return qa_chain
+
+    def _run(self, query: str) -> str:
+        response = self.qa_chain.run(query)
+        return str(response)
 
 # Function to handle different input types and choose the right tool
 def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
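
Note: on recent LangChain versions `BaseTool` is a Pydantic model, so the un-annotated class attributes (`name = "Numpy"`) and the `self.document = ...` assignments in `DocumentQuestionAnswering.__init__` can raise validation errors. A minimal Pydantic-safe sketch, assuming the commit's `_setup_qa_chain` helper is kept as-is:

    from typing import Any, Optional
    from langchain.tools import BaseTool

    class DocumentQuestionAnswering(BaseTool):
        # Field overrides need type annotations once BaseTool is a Pydantic model.
        name: str = "Document"
        description: str = "Useful for answering questions about a specific document"
        document: str                   # document path, now a declared field
        qa_chain: Optional[Any] = None  # built lazily on first use

        def _run(self, query: str) -> str:
            if self.qa_chain is None:
                self.qa_chain = self._setup_qa_chain()  # helper defined in this commit
            return str(self.qa_chain.run(query))

Constructed as `DocumentQuestionAnswering(document="notes.txt")`, this also defers the FAISS build until the tool is actually called. Separately, the `Numpy` tool `exec`s model-generated code with no sandboxing, which is risky on a public Space.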
@@ -93,43 +130,38 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
         user_prompt = transcription.text
 
     tools = [
-        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
-        FunctionTool.from_defaults(fn=image_generation, name="Image"),
+        NumpyCodeCalculator(),
+        ImageGeneration(),
     ]
 
     # Add the web search tool only if websearch mode is enabled
     if websearch:
-        tools.append(FunctionTool.from_defaults(fn=web_search, name="Web"))
+        tools.append(WebSearch())
 
     # Add the document question answering tool only if a document is provided
     if document:
-        docs = LlamaParse(result_type="text").load_data(document)
-        tools.append(FunctionTool.from_defaults(fn=document_question_answering, name="Document", docs=docs))
+        tools.append(DocumentQuestionAnswering(document))
 
-    llm = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
-    agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)
+    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
+    agent = create_react_agent(llm, tools, verbose=True)
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
     if image:
         image = Image.open(image).convert('RGB')
         messages = [{"role": "user", "content": [image, user_prompt]}]
         response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
     else:
-        response = agent.chat(user_prompt)
-
-        # Extract the content from AgentChatResponse to return as a string
-        if isinstance(response, AgentChatResponse):
-            response = response.response
+        response = agent_executor.run(user_prompt)
 
     return response
 
-
-# Gradio UI Setup
+
 def create_ui():
     with gr.Blocks(css="""
     /* Overall Styling */
     body {
-        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-        background-color: #f4f4f4;
+        font-family: 'Poppins', sans-serif;
+        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        margin: 0;
        padding: 0;
        color: #333;
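
Note: the agent wiring in the hunk above does not match the LangChain API: `create_react_agent(llm, tools, prompt)` requires a prompt and takes no `verbose` flag (verbosity belongs on `AgentExecutor`), and `AgentExecutor.run` is deprecated in favor of `invoke`. A hedged sketch of the corrected wiring, assuming the community `hwchase17/react` hub prompt is acceptable and `tools`/`MODEL`/`user_prompt` are as defined in the commit:

    import os
    from langchain import hub                       # needs the langchainhub package
    from langchain.agents import AgentExecutor, create_react_agent
    from langchain_groq import ChatGroq

    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
    prompt = hub.pull("hwchase17/react")            # standard ReAct prompt template
    agent = create_react_agent(llm, tools, prompt)  # no verbose= parameter here
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True,
                                   handle_parsing_errors=True)
    response = agent_executor.invoke({"input": user_prompt})["output"]

`handle_parsing_errors=True` keeps one malformed ReAct step from crashing the whole request.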
@@ -139,8 +171,14 @@ def create_ui():
     .gradio-container h1 {
         text-align: center;
         padding: 20px 0;
-        background-color: #007bff; /* Example color */
+        background: linear-gradient(45deg, #007bff, #00c6ff);
         color: white;
+        font-size: 2.5em;
+        font-weight: bold;
+        letter-spacing: 1px;
+        text-transform: uppercase;
+        margin: 0;
+        box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
     }
 
     /* Input Area Styling */
@@ -149,6 +187,10 @@
         justify-content: space-around;
         align-items: center;
         padding: 20px;
+        background-color: white;
+        border-radius: 10px;
+        box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
+        margin-bottom: 20px;
     }
 
     .gradio-container .gr-column {
@@ -159,40 +201,135 @@ def create_ui():
     /* Textbox Styling */
     .gradio-container textarea {
         width: calc(100% - 20px);
-        padding: 10px;
-        border: 2px solid #ccc;
-        border-radius: 5px;
-        font-size: 16px;
+        padding: 15px;
+        border: 2px solid #007bff;
+        border-radius: 8px;
+        font-size: 1.1em;
+        transition: border-color 0.3s, box-shadow 0.3s;
+    }
+
+    .gradio-container textarea:focus {
+        border-color: #00c6ff;
+        box-shadow: 0px 0px 8px rgba(0, 198, 255, 0.5);
+        outline: none;
     }
 
     /* Button Styling */
     .gradio-container button {
-        background-color: #007bff; /* Example color */
+        background: linear-gradient(45deg, #007bff, #00c6ff);
         color: white;
-        padding: 12px 20px;
+        padding: 15px 25px;
         border: none;
-        border-radius: 5px;
+        border-radius: 8px;
         cursor: pointer;
-        font-size: 16px;
-        transition: background-color 0.3s;
+        font-size: 1.2em;
+        font-weight: bold;
+        transition: background 0.3s, transform 0.3s;
+        box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
     }
 
     .gradio-container button:hover {
-        background-color: #0056b3; /* Example darker color */
+        background: linear-gradient(45deg, #0056b3, #009bff);
+        transform: translateY(-3px);
+    }
+
+    .gradio-container button:active {
+        transform: translateY(0);
     }
 
     /* Output Area Styling */
     .gradio-container .output-area {
         padding: 20px;
         text-align: center;
+        background-color: #f7f9fc;
+        border-radius: 10px;
+        box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
+        margin-top: 20px;
     }
 
     /* Image Styling */
     .gradio-container img {
         max-width: 100%;
         height: auto;
-        border-radius: 5px;
+        border-radius: 10px;
+        box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
+        transition: transform 0.3s, box-shadow 0.3s;
+    }
+
+    .gradio-container img:hover {
+        transform: scale(1.05);
+        box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.2);
+    }
+
+    /* Checkbox Styling */
+    .gradio-container input[type="checkbox"] {
+        width: 20px;
+        height: 20px;
+        cursor: pointer;
+        accent-color: #007bff;
+        transition: transform 0.3s;
+    }
+
+    .gradio-container input[type="checkbox"]:checked {
+        transform: scale(1.2);
+    }
+
+    /* Audio and Document Upload Styling */
+    .gradio-container .gr-file-upload input[type="file"] {
+        width: 100%;
+        padding: 10px;
+        border: 2px solid #007bff;
+        border-radius: 8px;
+        cursor: pointer;
+        background-color: white;
+        transition: border-color 0.3s, background-color 0.3s;
+    }
+
+    .gradio-container .gr-file-upload input[type="file"]:hover {
+        border-color: #00c6ff;
+        background-color: #f0f8ff;
+    }
+
+    /* Advanced Tooltip Styling */
+    .gradio-container .gr-tooltip {
+        position: relative;
+        display: inline-block;
+        cursor: pointer;
+    }
+
+    .gradio-container .gr-tooltip .tooltiptext {
+        visibility: hidden;
+        width: 200px;
+        background-color: black;
+        color: #fff;
+        text-align: center;
+        border-radius: 6px;
+        padding: 5px;
+        position: absolute;
+        z-index: 1;
+        bottom: 125%;
+        left: 50%;
+        margin-left: -100px;
+        opacity: 0;
+        transition: opacity 0.3s;
+    }
+
+    .gradio-container .gr-tooltip:hover .tooltiptext {
+        visibility: visible;
+        opacity: 1;
     }
+
+    /* Footer Styling */
+    .gradio-container footer {
+        text-align: center;
+        padding: 10px;
+        background: #007bff;
+        color: white;
+        font-size: 0.9em;
+        border-radius: 0 0 10px 10px;
+        box-shadow: 0px -2px 8px rgba(0, 0, 0, 0.1);
+    }
+
     """) as demo:
         gr.Markdown("# AI Assistant")
         with gr.Row():
@@ -257,7 +394,6 @@ def main_interface(user_prompt, image=None, audio=None, voice_only=False, websea
     else:
        return response, None
 
-
 # Launch the UI
 demo = create_ui()
 demo.launch()
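
Note: a quick way to verify the migrated agent path without the UI is to call `handle_input` directly; a hypothetical smoke test (not part of the commit; assumes GROQ_API_KEY, and TAVILY_API_KEY for the web branch, are set):

    if __name__ == "__main__":
        # Text-only input routes through the ReAct agent rather than the VQA model.
        print(handle_input("Use the Numpy tool to compute result = np.mean([1, 2, 3])"))
        print(handle_input("What is LangChain?", websearch=True))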