Omar Solano
committed
Commit · 1cb00df
Parent(s): 4639509

add multiple retrievers and ChatSummaryMemory

- scripts/gradio-ui.py  +124 -22

scripts/gradio-ui.py  CHANGED
@@ -10,7 +10,7 @@ from dotenv import load_dotenv
 from llama_index.agent.openai import OpenAIAgent
 from llama_index.core import VectorStoreIndex
 from llama_index.core.llms import MessageRole
-from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.memory import ChatMemoryBuffer, ChatSummaryMemoryBuffer
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.tools import RetrieverTool, ToolMetadata
@@ -32,23 +32,21 @@ logfire.configure()
 CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
 MONGODB_URI = os.getenv("MONGODB_URI")

-DB_PATH = os.getenv("DB_PATH", f"scripts/ai-tutor-vector-db")
-DB_COLLECTION = os.getenv("DB_NAME", "ai-tutor-vector-db")

-if not os.path.exists(
+if not os.path.exists("data/chroma-db-transformers"):
     # Download the vector database from the Hugging Face Hub if it doesn't exist locally
     # https://huggingface.co/datasets/towardsai-buster/ai-tutor-db/tree/main
     logfire.warn(
-        f"Vector database does not exist at
+        f"Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
     )
     from huggingface_hub import snapshot_download

     snapshot_download(
         repo_id="towardsai-buster/ai-tutor-vector-db",
-        local_dir=
+        local_dir="data",
         repo_type="dataset",
     )
-    logfire.info(f"Downloaded vector database to
+    logfire.info(f"Downloaded vector database to 'data/chroma-db-transformers'")

 AVAILABLE_SOURCES_UI = [
     "HF Transformers",
@@ -77,6 +75,8 @@ AVAILABLE_SOURCES = [
 # else logfire.warn("No mongodb uri found, you will not be able to save data.")
 # )

+DB_PATH = os.getenv("DB_PATH", "data/chroma-db-transformers")
+DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-transformers")

 db2 = chromadb.PersistentClient(path=DB_PATH)
 chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
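A note on the guard above (an observation from this diff; the snippet itself is not part of the commit): the download is triggered only when data/chroma-db-transformers is missing, yet the rest of the file reads from four Chroma directories. A minimal sketch of a stricter startup check, using only the paths that appear later in this diff; the name EXPECTED_DB_PATHS is illustrative:

# Sketch only, not in the commit: fail fast if any of the four vector DBs
# used further down in scripts/gradio-ui.py is missing after the download.
import os

EXPECTED_DB_PATHS = [
    "data/chroma-db-transformers",
    "data/chroma-db-peft",
    "data/chroma-db-trl",
    "data/chroma-db-llama-index",
]

for path in EXPECTED_DB_PATHS:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Vector database not found at {path}")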
@@ -95,29 +95,104 @@ vector_retriever = VectorIndexRetriever(
     use_async=True,
     embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
 )
-with open("
+with open(f"{DB_PATH}/document_dict_tf.pkl", "rb") as f:
     document_dict = pickle.load(f)

-
+custom_retriever_tf = CustomRetriever(vector_retriever, document_dict)
+
+DB_PATH = os.getenv("DB_PATH", "data/chroma-db-peft")
+DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-peft")
+
+db2 = chromadb.PersistentClient(path=DB_PATH)
+chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
+vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+index = VectorStoreIndex.from_vector_store(
+    vector_store=vector_store,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+    transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
+    show_progress=True,
+    use_async=True,
+)
+vector_retriever = VectorIndexRetriever(
+    index=index,
+    similarity_top_k=10,
+    use_async=True,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+)
+with open(f"{DB_PATH}/document_dict_peft.pkl", "rb") as f:
+    document_dict = pickle.load(f)
+
+custom_retriever_peft = CustomRetriever(vector_retriever, document_dict)
+
+DB_PATH = os.getenv("DB_PATH", f"data/chroma-db-trl")
+DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-trl")
+
+db2 = chromadb.PersistentClient(path=DB_PATH)
+chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
+vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+index = VectorStoreIndex.from_vector_store(
+    vector_store=vector_store,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+    transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
+    show_progress=True,
+    use_async=True,
+)
+vector_retriever = VectorIndexRetriever(
+    index=index,
+    similarity_top_k=10,
+    use_async=True,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+)
+with open(f"{DB_PATH}/document_dict_trl.pkl", "rb") as f:
+    document_dict = pickle.load(f)
+
+custom_retriever_trl = CustomRetriever(vector_retriever, document_dict)
+
+DB_PATH = os.getenv("DB_PATH", "data/chroma-db-llama-index")
+DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-llama-index")
+
+db2 = chromadb.PersistentClient(path=DB_PATH)
+chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
+vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+index = VectorStoreIndex.from_vector_store(
+    vector_store=vector_store,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+    transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
+    show_progress=True,
+    use_async=True,
+)
+vector_retriever = VectorIndexRetriever(
+    index=index,
+    similarity_top_k=10,
+    use_async=True,
+    embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
+)
+with open(f"{DB_PATH}/document_dict_llamaindex.pkl", "rb") as f:
+    document_dict = pickle.load(f)
+
+custom_retriever_llamaindex = CustomRetriever(vector_retriever, document_dict)


 def format_sources(completion) -> str:
     if len(completion.sources) == 0:
         return ""

-    # Mapping of source system names to user-friendly names
     display_source_to_ui = {
         src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
     }

     documents_answer_template: str = (
-        "📝 Here are the sources I used to answer your question:\n
+        "📝 Here are the sources I used to answer your question:\n{documents}"
     )
     document_template: str = "[🔗 {source}: {title}]({url}), relevance: {score:2.2f}"

-
-
-
+    all_documents = []
+    for source in completion.sources:
+        for src in source.raw_output:
+            document = document_template.format(
                 title=src.metadata["title"],
                 score=src.score,
                 source=display_source_to_ui.get(
@@ -125,9 +200,9 @@ def format_sources(completion) -> str:
                 ),
                 url=src.metadata["url"],
             )
-
-
-    )
+            all_documents.append(document)
+
+    documents = "\n".join(all_documents)

     return documents_answer_template.format(documents=documents)

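Aside (not part of this commit): the four blocks above that build custom_retriever_tf, custom_retriever_peft, custom_retriever_trl, and custom_retriever_llamaindex differ only in the Chroma path, the collection name, and the pickle file. A minimal sketch of how they could be collapsed into one helper; the name build_retriever and its parameters are illustrative, the import paths assume standard llama-index 0.10 packages, and CustomRetriever is the project's own class referenced above:

# Hypothetical refactor, not from the commit: one helper per vector DB instead of
# four copy-pasted blocks. All llama-index / chromadb calls mirror the diff above.
import pickle

import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore


def build_retriever(db_path: str, collection_name: str, document_dict_file: str):
    # Open the persisted Chroma collection and wrap it as a llama-index vector store.
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(collection_name)
    vector_store = ChromaVectorStore(chroma_collection=collection)

    # Rebuild the index and a top-10 retriever with the same settings as the diff.
    embed_model = OpenAIEmbedding(model="text-embedding-3-large", mode="similarity")
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
        show_progress=True,
        use_async=True,
    )
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=10,
        use_async=True,
        embed_model=embed_model,
    )

    # Load the pickled document_dict consumed by CustomRetriever.
    with open(f"{db_path}/{document_dict_file}", "rb") as f:
        document_dict = pickle.load(f)

    # CustomRetriever is the project's own class, imported elsewhere in scripts/gradio-ui.py.
    return CustomRetriever(retriever, document_dict)


# Example usage mirroring two of the blocks above:
# custom_retriever_tf = build_retriever("data/chroma-db-transformers", "chroma-db-transformers", "document_dict_tf.pkl")
# custom_retriever_peft = build_retriever("data/chroma-db-peft", "chroma-db-peft", "document_dict_peft.pkl")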
@@ -199,12 +274,34 @@ def generate_completion(

     query_engine_tools = [
         RetrieverTool(
-            retriever=
+            retriever=custom_retriever_tf,
             metadata=ToolMetadata(
-                name="
-                description="""
+                name="Transformers_information",
+                description="""Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch general information on topics such as language models theory (transformer architectures), tips on prompting, models, quantization, etc.""",
             ),
-        )
+        ),
+        RetrieverTool(
+            retriever=custom_retriever_peft,
+            metadata=ToolMetadata(
+                name="PEFT_information",
+                description=" Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc."
+                "",
+            ),
+        ),
+        RetrieverTool(
+            retriever=custom_retriever_trl,
+            metadata=ToolMetadata(
+                name="TRL_information",
+                description="""Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
+            ),
+        ),
+        RetrieverTool(
+            retriever=custom_retriever_llamaindex,
+            metadata=ToolMetadata(
+                name="LlamaIndex_information",
+                description="""Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of the LlamaIndex framework, includes info about fine-tuning embedding models, building chatbots, and agents, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
+            ),
+        ),
     ]

     agent = OpenAIAgent.from_tools(
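The hunk ends just as the tool list is handed to the agent; the remaining arguments of OpenAIAgent.from_tools are outside the diff context. A sketch of how the call plausibly continues; the llm, memory, and system_prompt values are assumptions for illustration, not taken from the file:

# Sketch only: wiring the four RetrieverTools above into the agent.
# The model name, memory object, and system prompt below are assumed.
from llama_index.agent.openai import OpenAIAgent
from llama_index.llms.openai import OpenAI

agent = OpenAIAgent.from_tools(
    tools=query_engine_tools,                    # the four tools defined above
    llm=OpenAI(model="gpt-4o", temperature=0),   # assumed model choice
    memory=memory,                               # e.g. the ChatSummaryMemoryBuffer below
    system_prompt="Answer AI questions, using the retrieval tools when relevant.",  # assumed
    verbose=True,
)

# The agent routes each turn to a tool, e.g. PEFT_information for a LoRA question.
response = agent.chat("How does QLoRA differ from LoRA?")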
@@ -282,7 +379,12 @@ with gr.Blocks(
     title="Towards AI 🤖",
     analytics_enabled=True,
 ) as demo:
-
+
+    memory = gr.State(
+        ChatSummaryMemoryBuffer.from_defaults(
+            token_limit=120000,
+        )
+    )
     chatbot = gr.Chatbot(
         scale=1,
         placeholder="<strong>Towards AI 🤖: A Question-Answering Bot for anything AI-related</strong><br>",
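For context on the new memory object: ChatSummaryMemoryBuffer keeps the full history until the token limit is reached, then summarizes the oldest turns instead of dropping them. A minimal standalone sketch, assuming an OpenAI model as the summarizer; the commit only shows the buffer being created with token_limit=120000 and stored in gr.State:

# Standalone sketch, not from the commit: basic ChatSummaryMemoryBuffer usage.
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.llms.openai import OpenAI

memory = ChatSummaryMemoryBuffer.from_defaults(
    llm=OpenAI(model="gpt-4o-mini"),  # assumed summarizer model
    token_limit=120000,               # same limit as in the commit
)

# Record a couple of turns; once the history exceeds token_limit, older turns
# are replaced by an LLM-generated summary rather than being truncated away.
memory.put(ChatMessage(role=MessageRole.USER, content="What is LoRA?"))
memory.put(ChatMessage(role=MessageRole.ASSISTANT, content="LoRA adds low-rank adapter matrices to frozen weights."))

history = memory.get()  # messages (plus summary, if any) that fit within the limit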