VanguardAI committed (verified)
Commit 5db9d8c · Parent(s): 9748b93

Update app.py

Files changed (1):
  1. app.py (+39, -76)
app.py CHANGED
@@ -8,13 +8,11 @@ from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
-from langchain_community.embeddings import OpenAIEmbeddings
-from langchain_community.vectorstores import Chroma
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
-from langchain import LLMChain, PromptTemplate
-from langchain.agents import AgentExecutor, Tool, ZeroShotAgent
-from langchain.llms import OpenAI
+from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader, LLMPredictor, PromptHelper
+from llama_index.embeddings import GroqEmbedding
+from llama_index.llms import GroqLLM
+from llama_index.agent import ReActAgent
+from llama_index.tools import FunctionTool
 from PIL import Image
 from decord import VideoReader, cpu
 from tavily import TavilyClient
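
Note on the new imports: they mix two generations of the llama_index API. GPTSimpleVectorIndex, LLMPredictor, and PromptHelper come from the legacy (roughly pre-0.5) interface, while ReActAgent and FunctionTool belong to the newer one, and a GroqEmbedding or GroqLLM class is not something current llama_index releases are known to ship (Groq also does not expose an embeddings endpoint). A minimal sketch of equivalent imports, assuming llama-index 0.10+ with the llama-index-llms-groq and llama-index-embeddings-huggingface integration packages installed (an assumption, not what this commit installs):

# Sketch only, not part of the committed code; assumes llama-index>=0.10
# plus the Groq and HuggingFace integration packages.
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.groq import Groq  # Groq-hosted chat models
from llama_index.embeddings.huggingface import HuggingFaceEmbedding  # local embeddings, since Groq serves no embeddings API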
@@ -89,32 +87,29 @@ def image_generation(query):
 
 # Document Question Answering Tool
 def doc_question_answering(query, file_path):
-    with open(file_path, 'r') as f:
-        file_content = f.read()
-
-    # Split the document into smaller chunks
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    docs = text_splitter.create_documents([file_content])
-
-    # Create embeddings using the groq model
-    embeddings = OpenAIEmbeddings() # If you're using a custom embeddings model, replace this line with the corresponding embeddings model for groq
-
-    # Set up the Chroma database for document retrieval
-    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")
-
-    # Create a custom function to use groq for the question-answering step
-    def groq_llm(query):
-        response = client.chat.completions.create(
-            model=MODEL,
-            messages=[{"role": "user", "content": query}]
-        )
-        return response.choices[0].message.content
-
-    # Set up the RetrievalQA chain using the custom groq LLM function
-    qa = RetrievalQA.from_chain_type(llm=groq_llm, chain_type="stuff", retriever=db.as_retriever())
-
-    # Run the QA process with the groq model
-    return qa.run(query)
+    # Load documents
+    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
+
+    # Initialize Groq embedding model
+    embed_model = GroqEmbedding()
+
+    # Initialize Groq LLM
+    llm_predictor = LLMPredictor(llm=GroqLLM(model_name=MODEL))
+
+    # Initialize prompt helper
+    prompt_helper = PromptHelper()
+
+    # Create index
+    index = GPTSimpleVectorIndex.from_documents(
+        documents,
+        embed_model=embed_model,
+        llm_predictor=llm_predictor,
+        prompt_helper=prompt_helper
+    )
+
+    # Query the index
+    response = index.query(query)
+    return response.response
 
 # Function to handle different input types and choose the right tool
 def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
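
The rewritten doc_question_answering builds a GPTSimpleVectorIndex and calls index.query(), which follows the legacy API noted above. Against the newer query-engine interface, the same load-index-query flow would look roughly like the sketch below; it reuses MODEL from app.py and the hypothetical import set from the previous note, and the embedding model name is just an example.

# Sketch under the assumptions above; MODEL is the Groq model id already defined in app.py.
def doc_question_answering(query, file_path):
    # Load the uploaded file from disk
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()

    # Embed and index locally; answer generation goes through the Groq-hosted LLM
    Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    Settings.llm = Groq(model=MODEL)
    index = VectorStoreIndex.from_documents(documents)

    # Retrieve the relevant chunks and synthesize an answer
    query_engine = index.as_query_engine()
    response = query_engine.query(query)
    return str(response)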
@@ -128,53 +123,21 @@ def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, webs
         user_prompt = transcription.text
 
     tools = [
-        Tool(
-            name="Numpy Code Calculator",
-            func=numpy_code_calculator,
-            description="Useful for when you need to perform mathematical calculations using NumPy. Provide the calculation you want to perform.",
-        ),
-        Tool(
-            name="Web Search",
-            func=web_search,
-            description="Useful for when you need to find information from the real world.",
-        ),
-        Tool(
-            name="Image Generation",
-            func=image_generation,
-            description="Useful for when you need to generate an image based on a description.",
-        ),
+        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy Code Calculator"),
+        FunctionTool.from_defaults(fn=web_search, name="Web Search"),
+        FunctionTool.from_defaults(fn=image_generation, name="Image Generation"),
     ]
 
     if doc:
         tools.append(
-            Tool(
-                name="Document Question Answering",
-                func=lambda query: doc_question_answering(query, doc.name),
-                description="Useful for when you need to answer questions about the uploaded document.",
+            FunctionTool.from_defaults(
+                fn=lambda query: doc_question_answering(query, doc.name),
+                name="Document Question Answering"
             )
         )
 
-    # Add this new code block:
-    prefix = """You are an AI assistant. You have access to the following tools:"""
-    suffix = """Begin!"
-
-    {chat_history}
-    Human: {input}
-    AI: I will do my best to assist you. Let me think about this step-by-step:"""
-
-    prompt = ZeroShotAgent.create_prompt(
-        tools,
-        prefix=prefix,
-        suffix=suffix,
-        input_variables=["input", "chat_history"]
-    )
-
-    llm = Groq(model=MODEL)
-    llm_chain = LLMChain(llm=llm, prompt=prompt)
-
-    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
-
-    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
+    llm = GroqLLM(model_name=MODEL)
+    agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)
 
     if image:
         image = Image.open(image).convert('RGB')
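
One behavioural difference in this hunk: the removed LangChain Tool definitions each carried a description, while the replacement FunctionTool.from_defaults(...) calls keep only the name. ReAct-style agents choose tools from their name and description (falling back to the wrapped function's docstring), so dropping the descriptions may weaken tool selection; FunctionTool.from_defaults accepts a description argument. A sketch carrying the original descriptions over (with the same caveat as above about which Groq LLM class actually exists):

# Sketch: the same tool list with descriptions preserved; the wrapped
# functions are the ones already defined in app.py.
tools = [
    FunctionTool.from_defaults(
        fn=numpy_code_calculator,
        name="Numpy Code Calculator",
        description="Useful for when you need to perform mathematical calculations using NumPy.",
    ),
    FunctionTool.from_defaults(
        fn=web_search,
        name="Web Search",
        description="Useful for when you need to find information from the real world.",
    ),
    FunctionTool.from_defaults(
        fn=image_generation,
        name="Image Generation",
        description="Useful for when you need to generate an image based on a description.",
    ),
]

Under the newer API, the agent itself would be built the same way, e.g. agent = ReActAgent.from_tools(tools, llm=Groq(model=MODEL), verbose=True).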
@@ -183,9 +146,9 @@ def handle_input(user_prompt, image=None, audio=None, doc=None, websearch=False):
         return response
 
     if websearch:
-        response = agent_executor.run(f"{user_prompt} Use the Web Search tool if necessary.")
+        response = agent.chat(f"{user_prompt} Use the Web Search tool if necessary.")
     else:
-        response = agent_executor.run(user_prompt)
+        response = agent.chat(user_prompt)
 
     return response
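
A last note on the agent.chat() calls: in current llama-index releases, ReActAgent.chat() returns an AgentChatResponse object rather than a plain string, and the agent keeps its own chat history between calls (the removed ZeroShotAgent prompt threaded {chat_history} in explicitly). If the downstream Gradio code expects text, an explicit conversion is needed, for example:

# Sketch: extracting plain text from the agent reply inside handle_input
if websearch:
    chat_response = agent.chat(f"{user_prompt} Use the Web Search tool if necessary.")
else:
    chat_response = agent.chat(user_prompt)
response = str(chat_response)  # equivalently chat_response.response
return response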
 
@@ -245,4 +208,4 @@ def main_interface(user_prompt, image=None, audio=None, doc=None, voice_only=Fal
 
 # Launch the UI
 demo = create_ui()
-demo.launch()
+demo.launch()