VanguardAI commited on
Commit
2a200be
·
verified ·
1 Parent(s): 6da2d3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -114
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import os
4
  import numpy as np
5
  from groq import Groq
6
- import spaces
7
  from transformers import AutoModel, AutoTokenizer
8
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
9
  from parler_tts import ParlerTTSForConditionalGeneration
@@ -12,13 +11,14 @@ from langchain_community.embeddings import OpenAIEmbeddings
12
  from langchain_community.vectorstores import Chroma
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.chains import RetrievalQA
15
- from langchain_community.llms import OpenAI
16
  from PIL import Image
17
  from decord import VideoReader, cpu
 
18
  import requests
19
  from huggingface_hub import hf_hub_download
20
  from safetensors.torch import load_file
21
 
 
22
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
23
  MODEL = 'llama3-groq-70b-8192-tool-use-preview'
24
 
@@ -39,7 +39,10 @@ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
39
  image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
  image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")
41
 
42
- # Initialize voice-only mode
 
 
 
43
  def play_voice_output(response):
44
  description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
45
  input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
@@ -49,18 +52,6 @@ def play_voice_output(response):
49
  sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
50
  return "output.wav"
51
 
52
- # Web search function
53
- def web_search(query):
54
- api_key = os.environ.get("BING_API_KEY")
55
- search_url = "https://api.bing.microsoft.com/v7.0/search"
56
- headers = {"Ocp-Apim-Subscription-Key": api_key}
57
- params = {"q": query, "textDecorations": True, "textFormat": "HTML"}
58
- response = requests.get(search_url, headers=headers, params=params)
59
- response.raise_for_status()
60
- search_results = response.json()
61
- snippets = [result['snippet'] for result in search_results.get('webPages', {}).get('value', [])]
62
- return "\n".join(snippets)
63
-
64
  # NumPy Calculation function
65
  def numpy_calculate(code: str) -> str:
66
  try:
@@ -71,37 +62,6 @@ def numpy_calculate(code: str) -> str:
71
  except Exception as e:
72
  return f"An error occurred: {str(e)}"
73
 
74
- # Function to handle different input types
75
- def handle_input(user_prompt, image=None, video=None, audio=None, doc=None):
76
- messages = [{"role": "user", "content": user_prompt}]
77
-
78
- if audio:
79
- transcription = client.audio.transcriptions.create(
80
- file=(audio.name, audio.read()),
81
- model="whisper-large-v3"
82
- )
83
- user_prompt = transcription.text
84
-
85
- if doc:
86
- # RAG with Langchain
87
- response = use_langchain_rag(doc.name, doc.read(), user_prompt)
88
- elif image and not video:
89
- image = Image.open(image).convert('RGB')
90
- messages[0]['content'] = [image, user_prompt]
91
- response = text_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
92
- elif video:
93
- frames = encode_video(video.name)
94
- messages[0]['content'] = frames + [user_prompt]
95
- response = text_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
96
- else:
97
- response = client.chat.completions.create(
98
- model=MODEL,
99
- messages=messages,
100
- tools=initialize_tools()
101
- ).choices[0].message.content
102
-
103
- return response
104
-
105
  # Function to use Langchain for RAG
106
  def use_langchain_rag(file_name, file_content, query):
107
  # Split the document into chunks
@@ -130,64 +90,58 @@ def encode_video(video_path):
130
  frames = [Image.fromarray(v.astype('uint8')) for v in frames]
131
  return frames
132
 
133
- # Initialize tools with web search and NumPy calculation
134
- def initialize_tools():
135
- tools = [
136
- {
137
- "type": "function",
138
- "function": {
139
- "name": "calculate",
140
- "description": "Evaluate a mathematical expression",
141
- "parameters": {
142
- "type": "object",
143
- "properties": {
144
- "expression": {"type": "string", "description": "The mathematical expression to evaluate"}
145
- },
146
- "required": ["expression"]
147
- },
148
- }
149
- },
150
- {
151
- "type": "function",
152
- "function": {
153
- "name": "web_search",
154
- "description": "Perform a web search",
155
- "parameters": {
156
- "type": "object",
157
- "properties": {
158
- "query": {"type": "string", "description": "The search query"}
159
- },
160
- "required": ["query"]
161
- },
162
- "implementation": web_search
163
- }
164
- },
165
- {
166
- "type": "function",
167
- "function": {
168
- "name": "numpy_calculate",
169
- "description": "Execute NumPy-based Python code for calculations",
170
- "parameters": {
171
- "type": "object",
172
- "properties": {
173
- "code": {"type": "string", "description": "The Python code with NumPy operations"}
174
- },
175
- "required": ["code"]
176
- },
177
- "implementation": numpy_calculate
178
- }
179
- }
180
- ]
181
- return tools
182
 
183
  @spaces.GPU()
184
- def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, voice_only=False):
185
  text_model.to(device='cuda', dtype=torch.bfloat16)
186
  tts_model.to("cuda")
187
  unet.to("cuda", torch.float16)
188
  image_pipe.to("cuda")
189
 
190
- response = handle_input(user_prompt, image=image, video=video, audio=audio, doc=doc)
191
 
192
  if voice_only:
193
  audio_file = play_voice_output(response)
@@ -195,22 +149,46 @@ def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, vo
195
  else:
196
  return response, None # Return only the text output, no audio
197
 
198
- # Gradio App Setup
199
- with gr.Blocks() as demo:
200
- user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
201
- image_input = gr.Image(type="filepath", label="Upload an image")
202
- video_input = gr.Video(label="Upload a video")
203
- audio_input = gr.Audio(type="filepath", label="Upload audio")
204
- doc_input = gr.File(type="filepath", label="Upload a document")
205
- voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode")
206
- output_label = gr.Label(label="Output")
207
- audio_output = gr.Audio(label="Audio Output", visible=False)
208
-
209
- submit = gr.Button("Submit")
210
- submit.click(
211
- fn=main_interface,
212
- inputs=[user_prompt, image_input, video_input, audio_input, doc_input, voice_only_mode],
213
- outputs=[output_label, audio_output] # Expecting a string and audio file
214
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
 
 
216
  demo.launch(inline=False)
 
3
  import os
4
  import numpy as np
5
  from groq import Groq
 
6
  from transformers import AutoModel, AutoTokenizer
7
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
8
  from parler_tts import ParlerTTSForConditionalGeneration
 
11
  from langchain_community.vectorstores import Chroma
12
  from langchain.text_splitter import RecursiveCharacterTextSplitter
13
  from langchain.chains import RetrievalQA
 
14
  from PIL import Image
15
  from decord import VideoReader, cpu
16
+ from tavily import TavilyClient
17
  import requests
18
  from huggingface_hub import hf_hub_download
19
  from safetensors.torch import load_file
20
 
21
+ # Initialize models
22
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
23
  MODEL = 'llama3-groq-70b-8192-tool-use-preview'
24
 
 
39
  image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
  image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")
41
 
42
+ # Tavily Client
43
+ tavily_client = TavilyClient(api_key="tvly-YOUR_API_KEY")
44
+
45
+ # Voice output function
46
  def play_voice_output(response):
47
  description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
48
  input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
 
52
  sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
53
  return "output.wav"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # NumPy Calculation function
56
  def numpy_calculate(code: str) -> str:
57
  try:
 
62
  except Exception as e:
63
  return f"An error occurred: {str(e)}"
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # Function to use Langchain for RAG
66
  def use_langchain_rag(file_name, file_content, query):
67
  # Split the document into chunks
 
90
  frames = [Image.fromarray(v.astype('uint8')) for v in frames]
91
  return frames
92
 
93
+ # Web search function
94
+ def web_search(query):
95
+ answer = tavily_client.qna_search(query=query)
96
+ return answer
97
+
98
+ # Function to handle different input types
99
+ def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
100
+ # Voice input handling
101
+ if audio:
102
+ transcription = client.audio.transcriptions.create(
103
+ file=(audio.name, audio.read()),
104
+ model="whisper-large-v3"
105
+ )
106
+ user_prompt = transcription.text
107
+
108
+ # If user uploaded an image and text, use MiniCPM model
109
+ if image:
110
+ image = Image.open(image).convert('RGB')
111
+ messages = [{"role": "user", "content": [image, user_prompt]}]
112
+ response = text_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
113
+ return response
114
+
115
+ # Determine which tool to use
116
+ if doc:
117
+ file_content = doc.read().decode('utf-8')
118
+ response = use_langchain_rag(doc.name, file_content, user_prompt)
119
+ elif "calculate" in user_prompt.lower():
120
+ response = numpy_calculate(user_prompt)
121
+ elif "generate" in user_prompt.lower() and ("image" in user_prompt.lower() or "picture" in user_prompt.lower()):
122
+ response = image_pipe(prompt=user_prompt, num_inference_steps=20, guidance_scale=7.5)
123
+ elif websearch:
124
+ response = web_search(user_prompt)
125
+ else:
126
+ chat_completion = client.chat.completions.create(
127
+ messages=[
128
+ {"role": "system", "content": "You are a helpful assistant."},
129
+ {"role": "user", "content": user_prompt}
130
+ ],
131
+ model=MODEL,
132
+ )
133
+ response = chat_completion.choices[0].message.content
134
+
135
+ return response
 
 
 
 
 
 
136
 
137
  @spaces.GPU()
138
+ def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, voice_only=False, websearch=False):
139
  text_model.to(device='cuda', dtype=torch.bfloat16)
140
  tts_model.to("cuda")
141
  unet.to("cuda", torch.float16)
142
  image_pipe.to("cuda")
143
 
144
+ response = handle_input(user_prompt, image=image, video=video, audio=audio, doc=doc, websearch=websearch)
145
 
146
  if voice_only:
147
  audio_file = play_voice_output(response)
 
149
  else:
150
  return response, None # Return only the text output, no audio
151
 
152
+ # Gradio UI Setup
153
+ def create_ui():
154
+ with gr.Blocks() as demo:
155
+ gr.Markdown("# AI Assistant")
156
+ with gr.Row():
157
+ with gr.Column(scale=2):
158
+ user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
159
+ with gr.Column(scale=1):
160
+ image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
161
+ video_input = gr.Video(label="Upload a video", elem_id="video-icon")
162
+ audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
163
+ doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
164
+ voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
165
+ websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
166
+ with gr.Column(scale=1):
167
+ submit = gr.Button("Submit")
168
+
169
+ output_label = gr.Label(label="Output")
170
+ audio_output = gr.Audio(label="Audio Output", visible=False)
171
+
172
+ submit.click(
173
+ fn=main_interface,
174
+ inputs=[user_prompt, image_input, video_input, audio_input, doc_input, voice_only_mode, websearch_mode],
175
+ outputs=[output_label, audio_output] # Expecting a string and audio file
176
+ )
177
+
178
+ # Voice-only mode UI
179
+ voice_only_mode.change(
180
+ lambda x: gr.update(visible=not x),
181
+ inputs=voice_only_mode,
182
+ outputs=[user_prompt, image_input, video_input, doc_input, websearch_mode, submit]
183
+ )
184
+ voice_only_mode.change(
185
+ lambda x: gr.update(visible=x),
186
+ inputs=voice_only_mode,
187
+ outputs=[audio_input]
188
+ )
189
+
190
+ return demo
191
 
192
+ # Launch the app
193
+ demo = create_ui()
194
  demo.launch(inline=False)