Omar Solano committed
Commit 0cfc98f · 1 parent: e0aadb4

add metadata filtering

Files changed: scripts/gradio-ui.py (+45, -31)
scripts/gradio-ui.py CHANGED

@@ -8,6 +8,11 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
 from llama_index.core import VectorStoreIndex
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.llms.openai import OpenAI
+from llama_index.core.vector_stores import (
+    MetadataFilters,
+    MetadataFilter,
+    FilterCondition,
+)
 import gradio as gr
 from gradio.themes.utils import (
     fonts,
@@ -62,34 +67,9 @@ index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
 # Initialize query engine
 llm = OpenAI(temperature=0, model="gpt-3.5-turbo-0125", max_tokens=None)
 embeds = OpenAIEmbedding(model="text-embedding-3-large", mode="text_search")
-query_engine = index.as_query_engine(
-    llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
-)
+# query_engine = index.as_query_engine(
+#     llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
+# )
-
-
-AVAILABLE_SOURCES_UI = [
-    "Gen AI 360: LLMs",
-    "Gen AI 360: LangChain",
-    "Gen AI 360: Advanced RAG",
-    "Towards AI Blog",
-    "Activeloop Docs",
-    "HF Transformers Docs",
-    "Wikipedia",
-    "OpenAI Docs",
-    "LangChain Docs",
-]
-
-AVAILABLE_SOURCES = [
-    "llm_course",
-    "langchain_course",
-    "advanced_rag_course",
-    "towards_ai",
-    "activeloop",
-    "hf_transformers",
-    "wikipedia",
-    "openai",
-    "langchain_docs",
-]


 def save_completion(completion, history):
@@ -178,6 +158,8 @@ def format_sources(completion) -> str:


 def add_sources(history, completion):
+    if history[-1][1] == "No sources selected. Please select sources to search.":
+        return history

     formatted_sources = format_sources(completion)
     history.append([None, formatted_sources])
@@ -192,10 +174,35 @@ def user(user_input, history):

 def get_answer(history, sources: Optional[list[str]] = None):
     user_input = history[-1][0]
+    history[-1][1] = ""
+
+    if len(sources) == 0:
+        history[-1][1] = "No sources selected. Please select sources to search."
+        yield history, "No sources selected. Please select sources to search."
+        return

+    # Dynamically create filters list
+    display_ui_to_source = {
+        ui: src for ui, src in zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES)
+    }
+    sources_renamed = [display_ui_to_source[disp] for disp in sources]
+    dynamic_filters = [
+        MetadataFilter(key="source", value=source) for source in sources_renamed
+    ]
+
+    filters = MetadataFilters(
+        filters=dynamic_filters,
+        condition=FilterCondition.OR,
+    )
+    query_engine = index.as_query_engine(
+        llm=llm,
+        similarity_top_k=5,
+        embed_model=embeds,
+        streaming=True,
+        filters=filters,
+    )
     completion = query_engine.query(user_input)

-    history[-1][1] = ""
     for token in completion.response_gen:
         history[-1][1] += token
         yield history, completion
@@ -224,6 +231,13 @@ with gr.Blocks(

     latest_completion = gr.State()

+    source_selection = gr.Dropdown(
+        choices=AVAILABLE_SOURCES_UI,
+        label="Select Sources",
+        value=AVAILABLE_SOURCES_UI,
+        multiselect=True,
+    )
+
     chatbot = gr.Chatbot(
         elem_id="chatbot", show_copy_button=True, scale=2, likeable=True
     )
@@ -257,14 +271,14 @@
     completion = gr.State()

     submit.click(user, [question, chatbot], [question, chatbot], queue=False).then(
-        get_answer, inputs=[chatbot], outputs=[chatbot, completion]
+        get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
     ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
     # .then(
     #     save_completion, inputs=[completion, chatbot]
     # )

     question.submit(user, [question, chatbot], [question, chatbot], queue=False).then(
-        get_answer, inputs=[chatbot], outputs=[chatbot, completion]
+        get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
     ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
     # .then(
     #     save_completion, inputs=[completion, chatbot]
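
Note: the get_answer hunk above rebuilds the query engine on every request so that retrieval only searches the sources selected in the UI. Below is a minimal standalone sketch of the same llama_index metadata-filtering pattern; it is not part of the commit, the toy documents and names are illustrative, and it assumes each node carries a "source" metadata key plus OpenAI credentials for the default embedding and LLM settings.

from llama_index.core import Document, VectorStoreIndex
from llama_index.core.vector_stores import (
    MetadataFilter,
    MetadataFilters,
    FilterCondition,
)

# Toy corpus (illustrative): each document is tagged with the "source"
# metadata key that the filters match against.
docs = [
    Document(text="Composing chains with LCEL...", metadata={"source": "langchain_docs"}),
    Document(text="Loading a Transformers pipeline...", metadata={"source": "hf_transformers"}),
    Document(text="An encyclopedia article...", metadata={"source": "wikipedia"}),
]
index = VectorStoreIndex.from_documents(docs)

# One equality filter per selected source, OR-combined: a node is retrieved
# only if its "source" value equals any of the selected names.
selected = ["langchain_docs", "hf_transformers"]
filters = MetadataFilters(
    filters=[MetadataFilter(key="source", value=s) for s in selected],
    condition=FilterCondition.OR,
)

# Filters apply at retrieval time, so rebuilding the engine per request
# with the current selection is cheap; the index itself is reused.
query_engine = index.as_query_engine(similarity_top_k=2, filters=filters)
print(query_engine.query("How do I compose chains?"))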
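Note: on the UI side, the new source_selection dropdown is threaded through both event chains as a second input to get_answer. A stripped-down sketch of that Gradio wiring, with a hypothetical answer callback standing in for the real one (which streams tokens from the query engine instead of echoing the selection):

import gradio as gr

SOURCES_UI = ["LangChain Docs", "HF Transformers Docs", "Wikipedia"]

def answer(history, sources):
    # Mirrors the empty-selection guard added in get_answer.
    if not sources:
        history[-1][1] = "No sources selected. Please select sources to search."
        return history
    history[-1][1] = f"Searching: {', '.join(sources)}"
    return history

with gr.Blocks() as demo:
    source_selection = gr.Dropdown(
        choices=SOURCES_UI,
        value=SOURCES_UI,  # all sources selected by default
        multiselect=True,
        label="Select Sources",
    )
    chatbot = gr.Chatbot(value=[])
    question = gr.Textbox()
    # Append the user turn, then pass both the chat history and the
    # current dropdown selection into the answering callback.
    question.submit(
        lambda q, h: ("", h + [[q, None]]),
        [question, chatbot],
        [question, chatbot],
        queue=False,
    ).then(answer, inputs=[chatbot, source_selection], outputs=[chatbot])

demo.launch()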