Omar Solano commited on
Commit
1cb00df
Β·
1 Parent(s): 4639509

add multiple retrievers and ChatSummaryMemory

Browse files
Files changed (1) hide show
  1. scripts/gradio-ui.py +124 -22
scripts/gradio-ui.py CHANGED
@@ -10,7 +10,7 @@ from dotenv import load_dotenv
10
  from llama_index.agent.openai import OpenAIAgent
11
  from llama_index.core import VectorStoreIndex
12
  from llama_index.core.llms import MessageRole
13
- from llama_index.core.memory import ChatMemoryBuffer
14
  from llama_index.core.node_parser import SentenceSplitter
15
  from llama_index.core.retrievers import VectorIndexRetriever
16
  from llama_index.core.tools import RetrieverTool, ToolMetadata
@@ -32,23 +32,21 @@ logfire.configure()
32
  CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
33
  MONGODB_URI = os.getenv("MONGODB_URI")
34
 
35
- DB_PATH = os.getenv("DB_PATH", f"scripts/ai-tutor-vector-db")
36
- DB_COLLECTION = os.getenv("DB_NAME", "ai-tutor-vector-db")
37
 
38
- if not os.path.exists(DB_PATH):
39
  # Download the vector database from the Hugging Face Hub if it doesn't exist locally
40
  # https://huggingface.co/datasets/towardsai-buster/ai-tutor-db/tree/main
41
  logfire.warn(
42
- f"Vector database does not exist at {DB_PATH}, downloading from Hugging Face Hub"
43
  )
44
  from huggingface_hub import snapshot_download
45
 
46
  snapshot_download(
47
  repo_id="towardsai-buster/ai-tutor-vector-db",
48
- local_dir=DB_PATH,
49
  repo_type="dataset",
50
  )
51
- logfire.info(f"Downloaded vector database to {DB_PATH}")
52
 
53
  AVAILABLE_SOURCES_UI = [
54
  "HF Transformers",
@@ -77,6 +75,8 @@ AVAILABLE_SOURCES = [
77
  # else logfire.warn("No mongodb uri found, you will not be able to save data.")
78
  # )
79
 
 
 
80
 
81
  db2 = chromadb.PersistentClient(path=DB_PATH)
82
  chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
@@ -95,29 +95,104 @@ vector_retriever = VectorIndexRetriever(
95
  use_async=True,
96
  embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
97
  )
98
- with open("scripts/ai-tutor-vector-db/document_dict.pkl", "rb") as f:
99
  document_dict = pickle.load(f)
100
 
101
- custom_retriever = CustomRetriever(vector_retriever, document_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
  def format_sources(completion) -> str:
105
  if len(completion.sources) == 0:
106
  return ""
107
 
108
- # Mapping of source system names to user-friendly names
109
  display_source_to_ui = {
110
  src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
111
  }
112
 
113
  documents_answer_template: str = (
114
- "πŸ“ Here are the sources I used to answer your question:\n\n{documents}"
115
  )
116
  document_template: str = "[πŸ”— {source}: {title}]({url}), relevance: {score:2.2f}"
117
 
118
- documents = "\n".join(
119
- [
120
- document_template.format(
 
121
  title=src.metadata["title"],
122
  score=src.score,
123
  source=display_source_to_ui.get(
@@ -125,9 +200,9 @@ def format_sources(completion) -> str:
125
  ),
126
  url=src.metadata["url"],
127
  )
128
- for src in completion.sources[0].raw_output
129
- ]
130
- )
131
 
132
  return documents_answer_template.format(documents=documents)
133
 
@@ -199,12 +274,34 @@ def generate_completion(
199
 
200
  query_engine_tools = [
201
  RetrieverTool(
202
- retriever=custom_retriever,
203
  metadata=ToolMetadata(
204
- name="AI_information",
205
- description="""Only use this tool if necessary. The 'AI_information' tool returns information about the artificial intelligence (AI) field. When using this tool, the input should be the user's question rewritten as a statement. e.g. When the user asks 'How can I quantize a model?', the input should be 'Model quantization'. The input can also be adapted to focus on specific aspects or further details of the current topic under discussion. This dynamic input approach allows for a tailored exploration of AI subjects, ensuring that responses are relevant and informative. Employ this tool to fetch nuanced information on topics such as model training, fine-tuning, and LLM augmentation, thereby facilitating a rich, context-aware dialogue. """,
206
  ),
207
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  ]
209
 
210
  agent = OpenAIAgent.from_tools(
@@ -282,7 +379,12 @@ with gr.Blocks(
282
  title="Towards AI πŸ€–",
283
  analytics_enabled=True,
284
  ) as demo:
285
- memory = gr.State(ChatMemoryBuffer.from_defaults(token_limit=120000))
 
 
 
 
 
286
  chatbot = gr.Chatbot(
287
  scale=1,
288
  placeholder="<strong>Towards AI πŸ€–: A Question-Answering Bot for anything AI-related</strong><br>",
 
10
  from llama_index.agent.openai import OpenAIAgent
11
  from llama_index.core import VectorStoreIndex
12
  from llama_index.core.llms import MessageRole
13
+ from llama_index.core.memory import ChatMemoryBuffer, ChatSummaryMemoryBuffer
14
  from llama_index.core.node_parser import SentenceSplitter
15
  from llama_index.core.retrievers import VectorIndexRetriever
16
  from llama_index.core.tools import RetrieverTool, ToolMetadata
 
32
  CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
33
  MONGODB_URI = os.getenv("MONGODB_URI")
34
 
 
 
35
 
36
+ if not os.path.exists("data/chroma-db-transformers"):
37
  # Download the vector database from the Hugging Face Hub if it doesn't exist locally
38
  # https://huggingface.co/datasets/towardsai-buster/ai-tutor-db/tree/main
39
  logfire.warn(
40
+ f"Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
41
  )
42
  from huggingface_hub import snapshot_download
43
 
44
  snapshot_download(
45
  repo_id="towardsai-buster/ai-tutor-vector-db",
46
+ local_dir="data",
47
  repo_type="dataset",
48
  )
49
+ logfire.info(f"Downloaded vector database to 'data/chroma-db-transformers'")
50
 
51
  AVAILABLE_SOURCES_UI = [
52
  "HF Transformers",
 
75
  # else logfire.warn("No mongodb uri found, you will not be able to save data.")
76
  # )
77
 
78
+ DB_PATH = os.getenv("DB_PATH", "data/chroma-db-transformers")
79
+ DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-transformers")
80
 
81
  db2 = chromadb.PersistentClient(path=DB_PATH)
82
  chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
 
95
  use_async=True,
96
  embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
97
  )
98
+ with open(f"{DB_PATH}/document_dict_tf.pkl", "rb") as f:
99
  document_dict = pickle.load(f)
100
 
101
+ custom_retriever_tf = CustomRetriever(vector_retriever, document_dict)
102
+
103
+ DB_PATH = os.getenv("DB_PATH", "data/chroma-db-peft")
104
+ DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-peft")
105
+
106
+ db2 = chromadb.PersistentClient(path=DB_PATH)
107
+ chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
108
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
109
+
110
+ index = VectorStoreIndex.from_vector_store(
111
+ vector_store=vector_store,
112
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
113
+ transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
114
+ show_progress=True,
115
+ use_async=True,
116
+ )
117
+ vector_retriever = VectorIndexRetriever(
118
+ index=index,
119
+ similarity_top_k=10,
120
+ use_async=True,
121
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
122
+ )
123
+ with open(f"{DB_PATH}/document_dict_peft.pkl", "rb") as f:
124
+ document_dict = pickle.load(f)
125
+
126
+ custom_retriever_peft = CustomRetriever(vector_retriever, document_dict)
127
+
128
+ DB_PATH = os.getenv("DB_PATH", f"data/chroma-db-trl")
129
+ DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-trl")
130
+
131
+ db2 = chromadb.PersistentClient(path=DB_PATH)
132
+ chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
133
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
134
+
135
+ index = VectorStoreIndex.from_vector_store(
136
+ vector_store=vector_store,
137
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
138
+ transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
139
+ show_progress=True,
140
+ use_async=True,
141
+ )
142
+ vector_retriever = VectorIndexRetriever(
143
+ index=index,
144
+ similarity_top_k=10,
145
+ use_async=True,
146
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
147
+ )
148
+ with open(f"{DB_PATH}/document_dict_trl.pkl", "rb") as f:
149
+ document_dict = pickle.load(f)
150
+
151
+ custom_retriever_trl = CustomRetriever(vector_retriever, document_dict)
152
+
153
+ DB_PATH = os.getenv("DB_PATH", "data/chroma-db-llama-index")
154
+ DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-llama-index")
155
+
156
+ db2 = chromadb.PersistentClient(path=DB_PATH)
157
+ chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
158
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
159
+
160
+ index = VectorStoreIndex.from_vector_store(
161
+ vector_store=vector_store,
162
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
163
+ transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
164
+ show_progress=True,
165
+ use_async=True,
166
+ )
167
+ vector_retriever = VectorIndexRetriever(
168
+ index=index,
169
+ similarity_top_k=10,
170
+ use_async=True,
171
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
172
+ )
173
+ with open(f"{DB_PATH}/document_dict_llamaindex.pkl", "rb") as f:
174
+ document_dict = pickle.load(f)
175
+
176
+ custom_retriever_llamaindex = CustomRetriever(vector_retriever, document_dict)
177
 
178
 
179
  def format_sources(completion) -> str:
180
  if len(completion.sources) == 0:
181
  return ""
182
 
 
183
  display_source_to_ui = {
184
  src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
185
  }
186
 
187
  documents_answer_template: str = (
188
+ "πŸ“ Here are the sources I used to answer your question:\n{documents}"
189
  )
190
  document_template: str = "[πŸ”— {source}: {title}]({url}), relevance: {score:2.2f}"
191
 
192
+ all_documents = []
193
+ for source in completion.sources:
194
+ for src in source.raw_output:
195
+ document = document_template.format(
196
  title=src.metadata["title"],
197
  score=src.score,
198
  source=display_source_to_ui.get(
 
200
  ),
201
  url=src.metadata["url"],
202
  )
203
+ all_documents.append(document)
204
+
205
+ documents = "\n".join(all_documents)
206
 
207
  return documents_answer_template.format(documents=documents)
208
 
 
274
 
275
  query_engine_tools = [
276
  RetrieverTool(
277
+ retriever=custom_retriever_tf,
278
  metadata=ToolMetadata(
279
+ name="Transformers_information",
280
+ description="""Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch general information on topics such as language models theory (transformer architectures), tips on prompting, models, quantization, etc.""",
281
  ),
282
+ ),
283
+ RetrieverTool(
284
+ retriever=custom_retriever_peft,
285
+ metadata=ToolMetadata(
286
+ name="PEFT_information",
287
+ description=" Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc."
288
+ "",
289
+ ),
290
+ ),
291
+ RetrieverTool(
292
+ retriever=custom_retriever_trl,
293
+ metadata=ToolMetadata(
294
+ name="TRL_information",
295
+ description="""Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
296
+ ),
297
+ ),
298
+ RetrieverTool(
299
+ retriever=custom_retriever_llamaindex,
300
+ metadata=ToolMetadata(
301
+ name="LlamaIndex_information",
302
+ description="""Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of the LlamaIndex framework, includes info about fine-tuning embedding models, building chatbots, and agents, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
303
+ ),
304
+ ),
305
  ]
306
 
307
  agent = OpenAIAgent.from_tools(
 
379
  title="Towards AI πŸ€–",
380
  analytics_enabled=True,
381
  ) as demo:
382
+
383
+ memory = gr.State(
384
+ ChatSummaryMemoryBuffer.from_defaults(
385
+ token_limit=120000,
386
+ )
387
+ )
388
  chatbot = gr.Chatbot(
389
  scale=1,
390
  placeholder="<strong>Towards AI πŸ€–: A Question-Answering Bot for anything AI-related</strong><br>",