Omar Solano committed on
Commit
0cfc98f
·
1 Parent(s): e0aadb4

add metadata filtering

Browse files
Files changed (1) hide show
  1. scripts/gradio-ui.py +45 -31
scripts/gradio-ui.py CHANGED
@@ -8,6 +8,11 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
8
  from llama_index.core import VectorStoreIndex
9
  from llama_index.embeddings.openai import OpenAIEmbedding
10
  from llama_index.llms.openai import OpenAI
 
 
 
 
 
11
  import gradio as gr
12
  from gradio.themes.utils import (
13
  fonts,
@@ -62,34 +67,9 @@ index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
62
  # Initialize query engine
63
  llm = OpenAI(temperature=0, model="gpt-3.5-turbo-0125", max_tokens=None)
64
  embeds = OpenAIEmbedding(model="text-embedding-3-large", mode="text_search")
65
- query_engine = index.as_query_engine(
66
- llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
67
- )
68
-
69
-
70
- AVAILABLE_SOURCES_UI = [
71
- "Gen AI 360: LLMs",
72
- "Gen AI 360: LangChain",
73
- "Gen AI 360: Advanced RAG",
74
- "Towards AI Blog",
75
- "Activeloop Docs",
76
- "HF Transformers Docs",
77
- "Wikipedia",
78
- "OpenAI Docs",
79
- "LangChain Docs",
80
- ]
81
-
82
- AVAILABLE_SOURCES = [
83
- "llm_course",
84
- "langchain_course",
85
- "advanced_rag_course",
86
- "towards_ai",
87
- "activeloop",
88
- "hf_transformers",
89
- "wikipedia",
90
- "openai",
91
- "langchain_docs",
92
- ]
93
 
94
 
95
  def save_completion(completion, history):
@@ -178,6 +158,8 @@ def format_sources(completion) -> str:
178
 
179
 
180
  def add_sources(history, completion):
 
 
181
 
182
  formatted_sources = format_sources(completion)
183
  history.append([None, formatted_sources])
@@ -192,10 +174,35 @@ def user(user_input, history):
192
 
193
  def get_answer(history, sources: Optional[list[str]] = None):
194
  user_input = history[-1][0]
 
 
 
 
 
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  completion = query_engine.query(user_input)
197
 
198
- history[-1][1] = ""
199
  for token in completion.response_gen:
200
  history[-1][1] += token
201
  yield history, completion
@@ -224,6 +231,13 @@ with gr.Blocks(
224
 
225
  latest_completion = gr.State()
226
 
 
 
 
 
 
 
 
227
  chatbot = gr.Chatbot(
228
  elem_id="chatbot", show_copy_button=True, scale=2, likeable=True
229
  )
@@ -257,14 +271,14 @@ with gr.Blocks(
257
  completion = gr.State()
258
 
259
  submit.click(user, [question, chatbot], [question, chatbot], queue=False).then(
260
- get_answer, inputs=[chatbot], outputs=[chatbot, completion]
261
  ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
262
  # .then(
263
  # save_completion, inputs=[completion, chatbot]
264
  # )
265
 
266
  question.submit(user, [question, chatbot], [question, chatbot], queue=False).then(
267
- get_answer, inputs=[chatbot], outputs=[chatbot, completion]
268
  ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
269
  # .then(
270
  # save_completion, inputs=[completion, chatbot]
 
8
  from llama_index.core import VectorStoreIndex
9
  from llama_index.embeddings.openai import OpenAIEmbedding
10
  from llama_index.llms.openai import OpenAI
11
+ from llama_index.core.vector_stores import (
12
+ MetadataFilters,
13
+ MetadataFilter,
14
+ FilterCondition,
15
+ )
16
  import gradio as gr
17
  from gradio.themes.utils import (
18
  fonts,
 
67
  # Initialize query engine
68
  llm = OpenAI(temperature=0, model="gpt-3.5-turbo-0125", max_tokens=None)
69
  embeds = OpenAIEmbedding(model="text-embedding-3-large", mode="text_search")
70
+ # query_engine = index.as_query_engine(
71
+ # llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
72
+ # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  def save_completion(completion, history):
 
158
 
159
 
160
  def add_sources(history, completion):
161
+ if history[-1][1] == "No sources selected. Please select sources to search.":
162
+ return history
163
 
164
  formatted_sources = format_sources(completion)
165
  history.append([None, formatted_sources])
 
174
 
175
  def get_answer(history, sources: Optional[list[str]] = None):
176
  user_input = history[-1][0]
177
+ history[-1][1] = ""
178
+
179
+ if len(sources) == 0:
180
+ history[-1][1] = "No sources selected. Please select sources to search."
181
+ yield history, "No sources selected. Please select sources to search."
182
+ return
183
 
184
+ # Dynamically create filters list
185
+ display_ui_to_source = {
186
+ ui: src for ui, src in zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES)
187
+ }
188
+ sources_renamed = [display_ui_to_source[disp] for disp in sources]
189
+ dynamic_filters = [
190
+ MetadataFilter(key="source", value=source) for source in sources_renamed
191
+ ]
192
+
193
+ filters = MetadataFilters(
194
+ filters=dynamic_filters,
195
+ condition=FilterCondition.OR,
196
+ )
197
+ query_engine = index.as_query_engine(
198
+ llm=llm,
199
+ similarity_top_k=5,
200
+ embed_model=embeds,
201
+ streaming=True,
202
+ filters=filters,
203
+ )
204
  completion = query_engine.query(user_input)
205
 
 
206
  for token in completion.response_gen:
207
  history[-1][1] += token
208
  yield history, completion
 
231
 
232
  latest_completion = gr.State()
233
 
234
+ source_selection = gr.Dropdown(
235
+ choices=AVAILABLE_SOURCES_UI,
236
+ label="Select Sources",
237
+ value=AVAILABLE_SOURCES_UI,
238
+ multiselect=True,
239
+ )
240
+
241
  chatbot = gr.Chatbot(
242
  elem_id="chatbot", show_copy_button=True, scale=2, likeable=True
243
  )
 
271
  completion = gr.State()
272
 
273
  submit.click(user, [question, chatbot], [question, chatbot], queue=False).then(
274
+ get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
275
  ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
276
  # .then(
277
  # save_completion, inputs=[completion, chatbot]
278
  # )
279
 
280
  question.submit(user, [question, chatbot], [question, chatbot], queue=False).then(
281
+ get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
282
  ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
283
  # .then(
284
  # save_completion, inputs=[completion, chatbot]