Omar Solano committed on
Commit 9c1e8a7 · 1 Parent(s): 471ad41

update gradio chatbot to latest version

Files changed (6)
  1. README.md +1 -1
  2. requirements.txt +101 -169
  3. scripts/custom_retriever.py +237 -28
  4. scripts/main.py +82 -43
  5. scripts/prompts.py +116 -14
  6. scripts/setup.py +122 -34
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧑🏻‍🏫
 colorFrom: gray
 colorTo: pink
 sdk: gradio
-sdk_version: 4.42.0
+sdk_version: 4.44.0
 app_file: scripts/main.py
 pinned: false
 ---
requirements.txt CHANGED
@@ -1,280 +1,212 @@
 aiofiles==23.2.1
 aiohappyeyeballs==2.4.0
-aiohttp==3.10.5
+aiohttp==3.10.6
 aiosignal==1.3.1
+aiostream==0.5.2
 annotated-types==0.7.0
-anyio==4.4.0
+anyio==4.6.0
 appnope==0.1.4
 asgiref==3.8.1
 asttokens==2.4.1
 attrs==24.2.0
-automat==24.8.1
-azure-core==1.30.2
-azure-identity==1.17.1
 backoff==2.2.1
 bcrypt==4.2.0
 beautifulsoup4==4.12.3
-bleach==6.1.0
-boto3==1.35.5
-botocore==1.35.5
-build==1.2.1
+boto3==1.35.26
+botocore==1.35.26
+build==1.2.2
 cachetools==5.5.0
-certifi==2024.7.4
-cffi==1.17.0
+certifi==2024.8.30
 charset-normalizer==3.3.2
 chroma-hnswlib==0.7.6
-chromadb==0.5.5
+chromadb==0.5.7
 click==8.1.7
-cohere==5.8.1
+cohere==5.9.4
 coloredlogs==15.0.1
 comm==0.2.2
-constantly==23.10.4
-contourpy==1.2.1
-cryptography==43.0.0
-cssselect==1.2.0
+contourpy==1.3.0
 cycler==0.12.1
 dataclasses-json==0.6.7
 debugpy==1.8.5
 decorator==5.1.1
-defusedxml==0.7.1
 deprecated==1.2.14
 dirtyjson==1.0.8
 distro==1.9.0
 dnspython==2.6.1
-docstring-parser==0.16
-executing==2.0.1
-fastapi==0.112.2
-fastavro==1.9.5
-fastjsonschema==2.20.0
+durationpy==0.7
+executing==2.1.0
+fastapi==0.115.0
+fastavro==1.9.7
 ffmpy==0.4.0
-filelock==3.15.4
+filelock==3.16.1
 flatbuffers==24.3.25
-fonttools==4.53.1
+fonttools==4.54.1
 frozenlist==1.4.1
-fsspec==2024.6.1
+fsspec==2024.9.0
 google-ai-generativelanguage==0.6.4
-google-api-core==2.19.1
-google-api-python-client==2.142.0
-google-auth==2.34.0
+google-api-core==2.20.0
+google-api-python-client==2.146.0
+google-auth==2.35.0
 google-auth-httplib2==0.2.0
-google-cloud-aiplatform==1.63.0
-google-cloud-bigquery==3.25.0
-google-cloud-core==2.4.1
-google-cloud-resource-manager==1.12.5
-google-cloud-storage==2.18.2
-google-crc32c==1.5.0
 google-generativeai==0.5.4
-google-resumable-media==2.7.2
-googleapis-common-protos==1.64.0
-gradio==4.42.0
+googleapis-common-protos==1.65.0
+gradio==4.44.0
 gradio-client==1.3.0
-greenlet==3.0.3
-grpc-google-iam-v1==0.13.1
-grpcio==1.66.0
+greenlet==3.1.1
+grpcio==1.66.1
 grpcio-status==1.62.3
+grpclib==0.4.7
 h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
 httpcore==1.0.5
 httplib2==0.22.0
 httptools==0.6.1
-httpx==0.27.0
+httpx==0.27.2
 httpx-sse==0.4.0
-huggingface-hub==0.24.6
+huggingface-hub==0.25.1
 humanfriendly==10.0
-hyperlink==21.0.0
-idna==3.8
-importlib-metadata==8.0.0
-importlib-resources==6.4.4
-incremental==24.7.2
-instructor==1.3.4
+hyperframe==6.0.1
+idna==3.10
+importlib-metadata==8.4.0
+importlib-resources==6.4.5
 ipykernel==6.29.5
-ipython==8.26.0
-itemadapter==0.9.0
-itemloaders==1.3.1
+ipython==8.27.0
 jedi==0.19.1
 jinja2==3.1.4
-jiter==0.4.2
+jiter==0.5.0
 jmespath==1.0.1
 joblib==1.4.2
-jsonpatch==1.33
-jsonpath-python==1.0.6
-jsonpointer==3.0.0
-jsonschema==4.23.0
-jsonschema-specifications==2023.12.1
-jupyter-client==8.6.2
+jupyter-client==8.6.3
 jupyter-core==5.7.2
-jupyterlab-pygments==0.3.0
-kiwisolver==1.4.5
-kubernetes==30.1.0
-langchain==0.2.14
-langchain-chroma==0.1.2
-langchain-core==0.2.35
-langchain-openai==0.1.22
-langchain-text-splitters==0.2.2
-langsmith==0.1.104
-llama-cloud==0.0.15
-llama-index==0.11.1
-llama-index-agent-openai==0.3.0
-llama-index-cli==0.3.0
-llama-index-core==0.11.1
-llama-index-embeddings-adapter==0.2.1
-llama-index-embeddings-cohere==0.2.0
-llama-index-embeddings-huggingface==0.3.1
-llama-index-embeddings-openai==0.2.3
-llama-index-finetuning==0.2.0
-llama-index-indices-managed-llama-cloud==0.3.0
+kiwisolver==1.4.7
+kubernetes==31.0.0
+llama-cloud==0.1.0
+llama-index==0.11.13
+llama-index-agent-openai==0.3.4
+llama-index-cli==0.3.1
+llama-index-core==0.11.13.post1
+llama-index-embeddings-cohere==0.2.1
+llama-index-embeddings-openai==0.2.5
+llama-index-indices-managed-llama-cloud==0.4.0
 llama-index-legacy==0.9.48.post3
-llama-index-llms-azure-openai==0.2.0
-llama-index-llms-gemini==0.3.4
-llama-index-llms-mistralai==0.2.1
-llama-index-llms-openai==0.2.0
-llama-index-llms-replicate==0.2.0
-llama-index-multi-modal-llms-openai==0.2.0
-llama-index-postprocessor-cohere-rerank==0.2.0
+llama-index-llms-gemini==0.3.5
+llama-index-llms-openai==0.2.9
+llama-index-multi-modal-llms-openai==0.2.1
+llama-index-postprocessor-cohere-rerank==0.2.1
 llama-index-program-openai==0.2.0
 llama-index-question-gen-openai==0.2.0
-llama-index-readers-file==0.2.0
-llama-index-readers-llama-parse==0.2.0
+llama-index-readers-file==0.2.2
+llama-index-readers-llama-parse==0.3.0
 llama-index-vector-stores-chroma==0.2.0
-llama-parse==0.5.0
-logfire==0.51.0
-lxml==5.3.0
+llama-parse==0.5.6
+logfire==0.53.0
 markdown-it-py==3.0.0
 markupsafe==2.1.5
 marshmallow==3.22.0
 matplotlib==3.9.2
 matplotlib-inline==0.1.7
 mdurl==0.1.2
-minijinja==2.0.1
-mistralai==1.0.2
-mistune==3.0.2
-mmh3==4.1.0
+mmh3==5.0.1
+modal==0.64.136
 monotonic==1.6
 mpmath==1.3.0
-msal==1.30.0
-msal-extensions==1.2.0
-multidict==6.0.5
+multidict==6.1.0
 mypy-extensions==1.0.0
-nbclient==0.10.0
-nbconvert==7.16.4
-nbformat==5.10.4
 nest-asyncio==1.6.0
 networkx==3.3
 nltk==3.9.1
 numpy==1.26.4
 oauthlib==3.2.2
-onnxruntime==1.19.0
-openai==1.42.0
-opentelemetry-api==1.26.0
-opentelemetry-exporter-otlp-proto-common==1.26.0
-opentelemetry-exporter-otlp-proto-grpc==1.26.0
-opentelemetry-exporter-otlp-proto-http==1.26.0
-opentelemetry-instrumentation==0.47b0
-opentelemetry-instrumentation-asgi==0.47b0
-opentelemetry-instrumentation-fastapi==0.47b0
-opentelemetry-proto==1.26.0
-opentelemetry-sdk==1.26.0
-opentelemetry-semantic-conventions==0.47b0
-opentelemetry-util-http==0.47b0
+onnxruntime==1.19.2
+openai==1.47.1
+opentelemetry-api==1.27.0
+opentelemetry-exporter-otlp-proto-common==1.27.0
+opentelemetry-exporter-otlp-proto-grpc==1.27.0
+opentelemetry-exporter-otlp-proto-http==1.27.0
+opentelemetry-instrumentation==0.48b0
+opentelemetry-instrumentation-asgi==0.48b0
+opentelemetry-instrumentation-fastapi==0.48b0
+opentelemetry-proto==1.27.0
+opentelemetry-sdk==1.27.0
+opentelemetry-semantic-conventions==0.48b0
+opentelemetry-util-http==0.48b0
 orjson==3.10.7
 overrides==7.7.0
 packaging==24.1
-pandas==2.2.2
-pandocfilters==1.5.1
+pandas==2.2.3
 parameterized==0.9.0
-parsel==1.9.1
 parso==0.8.4
 pexpect==4.9.0
 pillow==10.4.0
-platformdirs==4.2.2
-portalocker==2.10.1
-posthog==3.5.2
+platformdirs==4.3.6
+posthog==3.6.6
 prompt-toolkit==3.0.47
-protego==0.3.1
 proto-plus==1.24.0
-protobuf==4.25.4
+protobuf==4.25.5
 psutil==6.0.0
 ptyprocess==0.7.0
 pure-eval==0.2.3
-pyasn1==0.6.0
-pyasn1-modules==0.4.0
-pycparser==2.22
-pydantic==2.8.2
-pydantic-core==2.20.1
-pydispatcher==2.0.7
+pyasn1==0.6.1
+pyasn1-modules==0.4.1
+pydantic==2.9.2
+pydantic-core==2.23.4
 pydub==0.25.1
 pygments==2.18.0
-pyjwt==2.9.0
-pymongo==4.8.0
-pyopenssl==24.2.1
+pymongo==4.9.1
 pyparsing==3.1.4
 pypdf==4.3.1
 pypika==0.48.9
 pyproject-hooks==1.1.0
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
-python-multipart==0.0.9
-pytz==2024.1
+python-multipart==0.0.10
+pytz==2024.2
 pyyaml==6.0.2
 pyzmq==26.2.0
-queuelib==1.7.0
-referencing==0.35.1
-regex==2024.7.24
+regex==2024.9.11
 requests==2.32.3
-requests-file==2.1.0
 requests-oauthlib==2.0.0
-rich==13.8.0
-rpds-py==0.20.0
+rich==13.8.1
 rsa==4.9
-ruff==0.6.2
+ruff==0.6.7
 s3transfer==0.10.2
-safetensors==0.4.4
-scikit-learn==1.5.1
-scipy==1.14.1
-scrapy==2.11.2
 semantic-version==2.10.0
-sentence-transformers==2.7.0
-service-identity==24.1.0
-setuptools==73.0.1
-shapely==2.0.6
+setuptools==75.1.0
 shellingham==1.5.4
+sigtools==4.0.1
 six==1.16.0
 sniffio==1.3.1
 soupsieve==2.6
-sqlalchemy==2.0.32
+sqlalchemy==2.0.35
 stack-data==0.6.3
-starlette==0.38.2
+starlette==0.38.6
 striprtf==0.0.26
-sympy==1.13.2
-tabulate==0.9.0
-tenacity==8.3.0
-threadpoolctl==3.5.0
+sympy==1.13.3
+synchronicity==0.7.6
+tenacity==8.5.0
 tiktoken==0.7.0
-tinycss2==1.3.0
-tldextract==5.1.2
-tokenizers==0.19.1
+tokenizers==0.20.0
+toml==0.10.2
 tomlkit==0.12.0
-torch==2.4.0
 tornado==6.4.1
 tqdm==4.66.5
 traitlets==5.14.3
-transformers==4.44.2
-twisted==24.7.0
 typer==0.12.5
-types-requests==2.32.0.20240712
+types-certifi==2021.10.8.3
+types-requests==2.32.0.20240914
+types-toml==0.10.8.20240310
 typing-extensions==4.12.2
 typing-inspect==0.9.0
-tzdata==2024.1
+tzdata==2024.2
 uritemplate==4.1.1
-urllib3==2.2.2
+urllib3==2.2.3
 uvicorn==0.30.6
 uvloop==0.20.0
-w3lib==2.2.1
-watchfiles==0.23.0
+watchfiles==0.24.0
 wcwidth==0.2.13
-webencodings==0.5.1
 websocket-client==1.8.0
 websockets==12.0
 wrapt==1.16.0
-yarl==1.9.4
-zipp==3.20.0
-zope-interface==7.0.1
+yarl==1.12.1
+zipp==3.20.2
scripts/custom_retriever.py CHANGED
@@ -1,11 +1,75 @@
+import asyncio
 import time
-from typing import List
+import traceback
+from typing import List, Optional
 
 import logfire
+import tiktoken
+from cohere import AsyncClient
 from llama_index.core import QueryBundle
+from llama_index.core.async_utils import run_async_tasks
+from llama_index.core.callbacks import CBEventType, EventPayload
 from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
-from llama_index.core.schema import NodeWithScore, TextNode
+from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
 from llama_index.postprocessor.cohere_rerank import CohereRerank
+from llama_index.postprocessor.cohere_rerank.base import CohereRerank
+
+
+class AsyncCohereRerank(CohereRerank):
+    def __init__(
+        self,
+        top_n: int = 5,
+        model: str = "rerank-english-v3.0",
+        api_key: Optional[str] = None,
+    ) -> None:
+        super().__init__(top_n=top_n, model=model, api_key=api_key)
+        self._api_key = api_key
+        self._model = model
+        self._top_n = top_n
+
+    async def apostprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+    ) -> List[NodeWithScore]:
+        if query_bundle is None:
+            raise ValueError("Query bundle must be provided.")
+
+        if len(nodes) == 0:
+            return []
+
+        async_client = AsyncClient(api_key=self._api_key)
+
+        with self.callback_manager.event(
+            CBEventType.RERANKING,
+            payload={
+                EventPayload.NODES: nodes,
+                EventPayload.MODEL_NAME: self._model,
+                EventPayload.QUERY_STR: query_bundle.query_str,
+                EventPayload.TOP_K: self._top_n,
+            },
+        ) as event:
+            texts = [
+                node.node.get_content(metadata_mode=MetadataMode.EMBED)
+                for node in nodes
+            ]
+
+            results = await async_client.rerank(
+                model=self._model,
+                top_n=self._top_n,
+                query=query_bundle.query_str,
+                documents=texts,
+            )
+
+            new_nodes = []
+            for result in results.results:
+                new_node_with_score = NodeWithScore(
+                    node=nodes[result.index].node, score=result.relevance_score
+                )
+                new_nodes.append(new_node_with_score)
+            event.on_end(payload={EventPayload.NODES: new_nodes})
+
+        return new_nodes
 
 
 class CustomRetriever(BaseRetriever):
@@ -15,41 +79,64 @@ class CustomRetriever(BaseRetriever):
         self,
         vector_retriever: VectorIndexRetriever,
         document_dict: dict,
+        keyword_retriever,
+        mode: str = "AND",
     ) -> None:
         """Init params."""
 
         self._vector_retriever = vector_retriever
         self._document_dict = document_dict
+
+        self._keyword_retriever = keyword_retriever
+        if mode not in ("AND", "OR"):
+            raise ValueError("Invalid mode.")
+        self._mode = mode
+
         super().__init__()
 
-    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve nodes given query."""
 
        # LlamaIndex adds "\ninput is " to the query string
         query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
         query_bundle.query_str = query_bundle.query_str.rstrip()
 
-        logfire.info(f"Retrieving 10 nodes with string: '{query_bundle}'")
+        # logfire.info(f"Retrieving nodes with string: '{query_bundle}'")
         start = time.time()
-        nodes = self._vector_retriever.retrieve(query_bundle)
+        nodes = await self._vector_retriever.aretrieve(query_bundle)
+        keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
 
-        duration = time.time() - start
-        logfire.info(f"Retrieving nodes took {duration:.2f}s")
+        # logfire.info(f"Number of vector nodes: {len(nodes)}")
+        # logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")
+
+        vector_ids = {n.node.node_id for n in nodes}
+        keyword_ids = {n.node.node_id for n in keyword_nodes}
+
+        combined_dict = {n.node.node_id: n for n in nodes}
+        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
+
+        if self._mode == "AND":
+            retrieve_ids = vector_ids.intersection(keyword_ids)
+        else:
+            retrieve_ids = vector_ids.union(keyword_ids)
+
+        nodes = [combined_dict[rid] for rid in retrieve_ids]
 
         # Filter out nodes with the same ref_doc_id
         def filter_nodes_by_unique_doc_id(nodes):
             unique_nodes = {}
             for node in nodes:
-                doc_id = node.node.ref_doc_id
+                # doc_id = node.node.ref_doc_id
+                doc_id = node.node.source_node.node_id
                 if doc_id is not None and doc_id not in unique_nodes:
                     unique_nodes[doc_id] = node
             return list(unique_nodes.values())
 
         nodes = filter_nodes_by_unique_doc_id(nodes)
-        logfire.info(
-            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
-        )
-        logfire.info(f"Nodes retrieved: {nodes}")
+        # logfire.info(
+        #     f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
+        # )
+        # logfire.info(f"Nodes retrieved: {nodes}")
 
         nodes_context = []
         for node in nodes:
@@ -59,32 +146,154 @@ class CustomRetriever(BaseRetriever):
             # print("Score\t", node.score)
             # print("Metadata\t", node.metadata)
             # print("-_" * 20)
-            if node.score < 0.2:
-                continue
+            doc_id = node.node.source_node.node_id  # type: ignore
             if node.metadata["retrieve_doc"] == True:
                 # print("This node will be replaced by the document")
-                doc = self._document_dict[node.node.ref_doc_id]
+                # doc = self._document_dict[node.node.ref_doc_id]
+                # print("retrieved doc == True")
+                doc = self._document_dict[doc_id]
                 # print(doc.text)
                 new_node = NodeWithScore(
-                    node=TextNode(text=doc.text, metadata=node.metadata),  # type: ignore
+                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),  # type: ignore
+                    score=node.score,
+                )
+                nodes_context.append(new_node)
+            else:
+                node.node.node_id = doc_id
+                nodes_context.append(node)
+
+        try:
+            reranker = AsyncCohereRerank(top_n=3, model="rerank-english-v3.0")
+            nodes_context = await reranker.apostprocess_nodes(
+                nodes_context, query_bundle
+            )
+
+        except Exception as e:
+            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
+            error_msg += "Traceback:\n"
+            error_msg += traceback.format_exc()
+            logfire.error(error_msg)
+
+        nodes_filtered = []
+        total_tokens = 0
+        enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        for node in nodes_context:
+            if node.score < 0.10:  # type: ignore
+                continue
+
+            # Count tokens
+            if "tokens" in node.node.metadata:
+                node_tokens = node.node.metadata["tokens"]
+            else:
+                node_tokens = len(enc.encode(node.node.text))  # type: ignore
+
+            if total_tokens + node_tokens > 100_000:
+                logfire.info("Skipping node due to token count exceeding 100k")
+                break
+
+            total_tokens += node_tokens
+            nodes_filtered.append(node)
+
+        # logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
+        # logfire.info(f"Total tokens: {total_tokens}")
+
+        # duration = time.time() - start
+        # logfire.info(f"Retrieving nodes took {duration:.2f}s")
+
+        return nodes_filtered[:3]
+
+    # def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+    #     return asyncio.run(self._aretrieve(query_bundle))
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Retrieve nodes given query."""
+
+        # LlamaIndex adds "\ninput is " to the query string
+        query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
+        query_bundle.query_str = query_bundle.query_str.rstrip()
+        logfire.info(f"Retrieving nodes with string: '{query_bundle}'")
+
+        start = time.time()
+        nodes = self._vector_retriever.retrieve(query_bundle)
+        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
+
+        logfire.info(f"Number of vector nodes: {len(nodes)}")
+        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")
+
+        vector_ids = {n.node.node_id for n in nodes}
+        keyword_ids = {n.node.node_id for n in keyword_nodes}
+
+        combined_dict = {n.node.node_id: n for n in nodes}
+        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
+
+        if self._mode == "AND":
+            retrieve_ids = vector_ids.intersection(keyword_ids)
+        else:
+            retrieve_ids = vector_ids.union(keyword_ids)
+
+        nodes = [combined_dict[rid] for rid in retrieve_ids]
+
+        def filter_nodes_by_unique_doc_id(nodes):
+            unique_nodes = {}
+            for node in nodes:
+                # doc_id = node.node.ref_doc_id
+                doc_id = node.node.source_node.node_id
+                if doc_id is not None and doc_id not in unique_nodes:
+                    unique_nodes[doc_id] = node
+            return list(unique_nodes.values())
+
+        nodes = filter_nodes_by_unique_doc_id(nodes)
+        logfire.info(
+            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
+        )
+        logfire.info(f"Nodes retrieved: {nodes}")
+
+        nodes_context = []
+        for node in nodes:
+            doc_id = node.node.source_node.node_id  # type: ignore
+            if node.metadata["retrieve_doc"] == True:
+                doc = self._document_dict[doc_id]
+                new_node = NodeWithScore(
+                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),  # type: ignore
                     score=node.score,
                 )
                 nodes_context.append(new_node)
             else:
+                node.node.node_id = doc_id
                 nodes_context.append(node)
 
         try:
-            reranker = CohereRerank(top_n=5, model="rerank-english-v3.0")
+            reranker = CohereRerank(top_n=3, model="rerank-english-v3.0")
             nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle)
-            nodes_filtered = []
-            for node in nodes_context:
-                if node.score < 0.10:  # type: ignore
-                    continue
-                else:
-                    nodes_filtered.append(node)
-            logfire.info(f"Cohere raranking to {len(nodes_filtered)} nodes")
-
-            return nodes_filtered
+
         except Exception as e:
-            logfire.error(f"Error reranking nodes with Cohere: {e}")
-            return nodes_context
+            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
+            error_msg += "Traceback:\n"
+            error_msg += traceback.format_exc()
+            logfire.error(error_msg)
+
+        nodes_filtered = []
+        total_tokens = 0
+        enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        for node in nodes_context:
+            if node.score < 0.10:  # type: ignore
+                continue
+            if "tokens" in node.node.metadata:
+                node_tokens = node.node.metadata["tokens"]
+            else:
+                node_tokens = len(enc.encode(node.node.text))  # type: ignore
+
+            if total_tokens + node_tokens > 100_000:
+                logfire.info("Skipping node due to token count exceeding 100k")
+                break
+
+            total_tokens += node_tokens
+            nodes_filtered.append(node)
+
+        logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
+        logfire.info(f"Total tokens: {total_tokens}")
+
+        duration = time.time() - start
+        logfire.info(f"Retrieving nodes took {duration:.2f}s")
+
+        return nodes_filtered[:3]
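
Note: the retriever now fuses vector and keyword hits (set union for "OR", intersection for "AND") before reranking and token-capping. A minimal sketch of how the pieces might be wired together, assuming the retrievers and document dictionary built in scripts/setup.py (the variable names here are illustrative, not from this commit):

    # Hypothetical wiring, mirroring what setup.py does.
    retriever = CustomRetriever(
        vector_retriever=vector_retriever,    # VectorIndexRetriever, similarity_top_k=15
        document_dict=document_dict,          # maps source-node id -> full Document
        keyword_retriever=keyword_retriever,  # KeywordTableSimpleRetriever
        mode="OR",                            # union of vector and keyword results
    )
    nodes = retriever.retrieve("How does LoRA fine-tuning work?")  # sync path
    # Inside an event loop, the async path awaits the Cohere rerank call:
    # nodes = await retriever.aretrieve(QueryBundle("How does LoRA fine-tuning work?"))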
scripts/main.py CHANGED
@@ -6,53 +6,59 @@ from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.llms import MessageRole
 from llama_index.core.memory import ChatSummaryMemoryBuffer
 from llama_index.core.tools import RetrieverTool, ToolMetadata
+from llama_index.core.vector_stores import (
+    FilterCondition,
+    FilterOperator,
+    MetadataFilter,
+    MetadataFilters,
+)
 from llama_index.llms.openai import OpenAI
 from prompts import system_message_openai_agent
-from setup import (
+from setup import (  # custom_retriever_langchain,; custom_retriever_llama_index,; custom_retriever_openai_cookbooks,; custom_retriever_peft,; custom_retriever_transformers,; custom_retriever_trl,
     AVAILABLE_SOURCES,
     AVAILABLE_SOURCES_UI,
     CONCURRENCY_COUNT,
-    custom_retriever_langchain,
-    custom_retriever_llama_index,
-    custom_retriever_openai_cookbooks,
-    custom_retriever_peft,
-    custom_retriever_transformers,
-    custom_retriever_trl,
+    custom_retriever_all_sources,
 )
 
 
 def update_query_engine_tools(selected_sources):
     tools = []
     source_mapping = {
-        "Transformers Docs": (
-            custom_retriever_transformers,
-            "Transformers_information",
-            """Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch information on topics such as language models (LLMs) models such as Llama3 and theory (transformer architectures), tips on prompting, quantization, etc.""",
-        ),
-        "PEFT Docs": (
-            custom_retriever_peft,
-            "PEFT_information",
-            """Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc.""",
-        ),
-        "TRL Docs": (
-            custom_retriever_trl,
-            "TRL_information",
-            """Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
-        ),
-        "LlamaIndex Docs": (
-            custom_retriever_llama_index,
-            "LlamaIndex_information",
-            """Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of a framework, includes info about fine-tuning embedding models, building chatbots, and agents with llms, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
-        ),
-        "OpenAI Cookbooks": (
-            custom_retriever_openai_cookbooks,
-            "openai_cookbooks_info",
-            """Useful for questions asking about accomplishing common tasks with the OpenAI API. Returns example code and guides stored in Jupyter notebooks, including info about ChatGPT GPT actions, OpenAI Assistants API, and How to fine-tune OpenAI's GPT-4o and GPT-4o-mini models with the OpenAI API.""",
-        ),
-        "LangChain Docs": (
-            custom_retriever_langchain,
-            "langchain_info",
-            """Useful for questions asking about the LangChain framework. It is the documentation of the LangChain framework, includes info about building chains, agents, and tools, using memory, prompts, callbacks, etc.""",
+        # "Transformers Docs": (
+        #     custom_retriever_transformers,
+        #     "Transformers_information",
+        #     """Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch information on topics such as language models (LLMs) models such as Llama3 and theory (transformer architectures), tips on prompting, quantization, etc.""",
+        # ),
+        # "PEFT Docs": (
+        #     custom_retriever_peft,
+        #     "PEFT_information",
+        #     """Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc.""",
+        # ),
+        # "TRL Docs": (
+        #     custom_retriever_trl,
+        #     "TRL_information",
+        #     """Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
+        # ),
+        # "LlamaIndex Docs": (
+        #     custom_retriever_llama_index,
+        #     "LlamaIndex_information",
+        #     """Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of a framework, includes info about fine-tuning embedding models, building chatbots, and agents with llms, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
+        # ),
+        # "OpenAI Cookbooks": (
+        #     custom_retriever_openai_cookbooks,
+        #     "openai_cookbooks_info",
+        #     """Useful for questions asking about accomplishing common tasks with the OpenAI API. Returns example code and guides stored in Jupyter notebooks, including info about ChatGPT GPT actions, OpenAI Assistants API, and How to fine-tune OpenAI's GPT-4o and GPT-4o-mini models with the OpenAI API.""",
+        # ),
+        # "LangChain Docs": (
+        #     custom_retriever_langchain,
+        #     "langchain_info",
+        #     """Useful for questions asking about the LangChain framework. It is the documentation of the LangChain framework, includes info about building chains, agents, and tools, using memory, prompts, callbacks, etc.""",
+        # ),
+        "All Sources": (
+            custom_retriever_all_sources,
+            "all_sources_info",
+            """Useful for questions asking about information in the field of AI.""",
         ),
     }
 
@@ -80,9 +86,7 @@ def generate_completion(
     memory,
 ):
     with logfire.span("Running query"):
-        logfire.info(f"query: {query}")
-        logfire.info(f"model: {model}")
-        logfire.info(f"sources: {sources}")
+        logfire.info(f"User query: {query}")
 
         chat_list = memory.get()
 
@@ -102,7 +106,34 @@ def generate_completion(
         client = llm._get_client()
         logfire.instrument_openai(client)
 
-        query_engine_tools = update_query_engine_tools(sources)
+        query_engine_tools = update_query_engine_tools(["All Sources"])
+
+        filter_list = []
+        source_mapping = {
+            "Transformers Docs": "transformers",
+            "PEFT Docs": "peft",
+            "TRL Docs": "trl",
+            "LlamaIndex Docs": "llama_index",
+            "LangChain Docs": "langchain",
+            "OpenAI Cookbooks": "openai_cookbooks",
+            "Towards AI Blog": "tai_blog",
+        }
+
+        for source in sources:
+            if source in source_mapping:
+                filter_list.append(
+                    MetadataFilter(
+                        key="source",
+                        operator=FilterOperator.EQ,
+                        value=source_mapping[source],
+                    )
+                )
+
+        filters = MetadataFilters(
+            filters=filter_list,
+            condition=FilterCondition.OR,
+        )
+        query_engine_tools[0].retriever._vector_retriever._filters = filters
 
         agent = OpenAIAgent.from_tools(
             llm=llm,
@@ -151,8 +182,16 @@ def format_sources(completion) -> str:
     )
     document_template: str = "[🔗 {source}: {title}]({url}), relevance: {score:2.2f}"
     all_documents = []
-    for source in completion.sources:
-        for src in source.raw_output:
+    for source in completion.sources:  # looping over list[ToolOutput]
+        if isinstance(source.raw_output, Exception):
+            logfire.error(f"Error in source output: {source.raw_output}")
+            # pdb.set_trace()
+            continue
+
+        if not isinstance(source.raw_output, list):
+            logfire.warn(f"Unexpected source output type: {type(source.raw_output)}")
+            continue
+        for src in source.raw_output:  # looping over list[NodeWithScore]
             document = document_template.format(
                 title=src.metadata["title"],
                 score=src.score,
@@ -189,13 +228,13 @@ sources = gr.CheckboxGroup(
         "LlamaIndex Docs",
         "LangChain Docs",
         "OpenAI Cookbooks",
+        # "All Sources",
     ],
     interactive=True,
 )
 model = gr.Dropdown(
     [
         "gpt-4o-mini",
-        "gpt-4o",
     ],
     label="Model",
     value="gpt-4o-mini",
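
Note: source selection now happens through metadata filters on the single all-sources collection instead of per-source retrievers. A small sketch of the filter that generate_completion builds when a user checks two sources (values follow the source_mapping in the diff above):

    from llama_index.core.vector_stores import (
        FilterCondition,
        FilterOperator,
        MetadataFilter,
        MetadataFilters,
    )

    filters = MetadataFilters(
        filters=[
            MetadataFilter(key="source", operator=FilterOperator.EQ, value="transformers"),
            MetadataFilter(key="source", operator=FilterOperator.EQ, value="peft"),
        ],
        condition=FilterCondition.OR,  # a node matches if its source is either value
    )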
scripts/prompts.py CHANGED
@@ -1,19 +1,123 @@
-system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs. Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback. Questions should be understood in this context.
-
-Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
-
-Use the available tools to gather insights pertinent to the field of AI. Always use two tools at the same time. These tools accept a string (a user query rewritten as a statement) and return informative content regarding the domain of AI.
-e.g:
-User question: 'How can I fine-tune an LLM?'
-Input to the tool: 'Fine-tuning an LLM'
-
-User question: How can quantize an LLM?
-Input to the tool: 'Quantization for LLMs'
-
-User question: 'Teach me how to build an AI agent"'
-Input to the tool: 'Building an AI Agent'
-
-Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
-
+# # Prompt 1
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+# Use the available tools to gather insights pertinent to the field of AI.
+
+# To answer student questions, always use the all_sources_info tool plus another one simultaneously. Meaning that should be using two tools in total.
+
+# Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+# Prompt 2
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+# Use the available tools to gather insights pertinent to the field of AI.
+
+# To answer student questions, always use the all_sources_info tool. For complex questions, you can decompose the user question into TWO sub questions (you are limited to two sub-questions) that can be answered by the tools.
+
+# These are the guidelines to consider if you decide to create sub questions:
+# * Be as specific as possible
+# * The two sub questions should be relevant to the user question
+# * The two sub questions should be answerable by the tools provided
+
+# Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+# # Prompt 3
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+# Use the available tools to gather insights pertinent to the field of AI.
+
+# To answer student questions, always use the all_sources_info tool. For each question, you should decompose the user question into TWO sub questions (you are limited to two sub-questions) that can be answered by the tools.
+
+# These are the guidelines to consider when creating sub questions:
+# * Be as specific as possible
+# * The two sub questions should be relevant to the user question
+# * The two sub questions should be answerable by the tools provided
+
+# Only some information returned by the tools might be relevant to the user question, so ignore the irrelevant part and answer the user question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+
+# Prompt 4 Trying to make it like #1
+system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+Use the available tools to gather insights pertinent to the field of AI.
+
+To answer student questions, always use the all_sources_info tool plus another one simultaneously.
+Decompose the user question into TWO sub questions (you are limited to two sub-questions) one for each tool.
+Meaning that should be using two tools in total for each user question.
+
+These are the guidelines to consider if you decide to create sub questions:
+* Be as specific as possible
+* The two sub questions should be relevant to the user question
+* The two sub questions should be answerable by the tools provided
+
+Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+
 Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
 
@@ -25,11 +129,9 @@ Should the tools repository lack information on the queried topic, politely info
 
 At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
 
-Do not refer to the documentation directly, but use the information provided within it to answer questions.
+Do not refer to the documentation directly, but use the information provided within it to answer questions.
 
 If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
 
 Make sure to format your answers in Markdown format, including code blocks and snippets.
-
-Politely reject questions not related to AI, while being cautious not to reject unfamiliar terms or acronyms too quickly. If a question seems unrelated but you suspect it might contain AI-related terminology.
 """
scripts/setup.py CHANGED
@@ -1,3 +1,5 @@
+import asyncio
+import json
 import logging
 import os
 import pickle
@@ -6,9 +8,16 @@ import chromadb
 import logfire
 from custom_retriever import CustomRetriever
 from dotenv import load_dotenv
-from llama_index.core import VectorStoreIndex
+from evaluate_rag_system import AsyncKeywordTableSimpleRetriever
+from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex
+from llama_index.core.ingestion import IngestionPipeline
 from llama_index.core.node_parser import SentenceSplitter
-from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.retrievers import (
+    KeywordTableSimpleRetriever,
+    VectorIndexRetriever,
+)
+from llama_index.core.schema import NodeWithScore, QueryBundle
+from llama_index.embeddings.cohere import CohereEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.vector_stores.chroma import ChromaVectorStore
 from utils import init_mongo_db
@@ -21,11 +30,11 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
 logfire.configure()
 
 
-if not os.path.exists("data/chroma-db-transformers"):
+if not os.path.exists("data/chroma-db-all_sources"):
     # Download the vector database from the Hugging Face Hub if it doesn't exist locally
     # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
     logfire.warn(
-        f"Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
+        f"Vector database does not exist at 'data/chroma-db-all_sources', downloading from Hugging Face Hub"
     )
     from huggingface_hub import snapshot_download
 
@@ -34,51 +43,127 @@ if not os.path.exists("data/chroma-db-transformers"):
         local_dir="data",
         repo_type="dataset",
     )
-    logfire.info(f"Downloaded vector database to 'data/chroma-db-transformers'")
+    logfire.info(f"Downloaded vector database to 'data/chroma-db-all_sources'")
+
+
+def create_docs(input_file: str) -> list[Document]:
+    with open(input_file, "r") as f:
+        documents = []
+        for line in f:
+            data = json.loads(line)
+            documents.append(
+                Document(
+                    doc_id=data["doc_id"],
+                    text=data["content"],
+                    metadata={  # type: ignore
+                        "url": data["url"],
+                        "title": data["name"],
+                        "tokens": data["tokens"],
+                        "retrieve_doc": data["retrieve_doc"],
+                        "source": data["source"],
+                    },
+                    excluded_llm_metadata_keys=[
+                        "title",
+                        "tokens",
+                        "retrieve_doc",
+                        "source",
+                    ],
+                    excluded_embed_metadata_keys=[
+                        "url",
+                        "tokens",
+                        "retrieve_doc",
+                        "source",
+                    ],
+                )
+            )
+    return documents
 
 
 def setup_database(db_collection, dict_file_name):
     db = chromadb.PersistentClient(path=f"data/{db_collection}")
     chroma_collection = db.get_or_create_collection(db_collection)
     vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+    embed_model = CohereEmbedding(
+        api_key=os.environ["COHERE_API_KEY"],
+        model_name="embed-english-v3.0",
+        input_type="search_query",
+    )
 
     index = VectorStoreIndex.from_vector_store(
         vector_store=vector_store,
-        embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
-        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
+        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
         use_async=True,
     )
     vector_retriever = VectorIndexRetriever(
         index=index,
         similarity_top_k=15,
+        embed_model=embed_model,
         use_async=True,
-        embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
     )
     with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
         document_dict = pickle.load(f)
 
-    return CustomRetriever(vector_retriever, document_dict)
+    with open("data/keyword_retriever_sync.pkl", "rb") as f:
+        keyword_retriever: KeywordTableSimpleRetriever = pickle.load(f)
+
+    # # Creating the keyword index and retriever
+    # logfire.info("Creating nodes from documents")
+    # documents = create_docs("data/all_sources_data.jsonl")
+    # pipeline = IngestionPipeline(
+    #     transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)]
+    # )
+    # all_nodes = pipeline.run(documents=documents, show_progress=True)
+    # # with open("data/all_nodes.pkl", "wb") as f:
+    # #     pickle.dump(all_nodes, f)
+
+    # # all_nodes = pickle.load(open("data/all_nodes.pkl", "rb"))
+    # logfire.info(f"Number of nodes: {len(all_nodes)}")
+
+    # keyword_index = SimpleKeywordTableIndex(
+    #     nodes=all_nodes, max_keywords_per_chunk=10, show_progress=True, use_async=False
+    # )
+    # # with open("data/keyword_index.pkl", "wb") as f:
+    # #     pickle.dump(keyword_index, f)
+
+    # # keyword_index = pickle.load(open("data/keyword_index.pkl", "rb"))
+
+    # logfire.info("Creating keyword retriever")
+    # keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
+
+    # with open("data/keyword_retriever_sync.pkl", "wb") as f:
+    #     pickle.dump(keyword_retriever, f)
+
+    return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")
 
 
 # Setup retrievers
-custom_retriever_transformers = setup_database(
-    "chroma-db-transformers",
-    "document_dict_transformers.pkl",
-)
-custom_retriever_peft = setup_database("chroma-db-peft", "document_dict_peft.pkl")
-custom_retriever_trl = setup_database("chroma-db-trl", "document_dict_trl.pkl")
-custom_retriever_llama_index = setup_database(
-    "chroma-db-llama_index",
-    "document_dict_llama_index.pkl",
-)
-custom_retriever_openai_cookbooks = setup_database(
-    "chroma-db-openai_cookbooks",
-    "document_dict_openai_cookbooks.pkl",
-)
-custom_retriever_langchain = setup_database(
-    "chroma-db-langchain",
-    "document_dict_langchain.pkl",
+# custom_retriever_transformers: CustomRetriever = setup_database(
+#     "chroma-db-transformers",
+#     "document_dict_transformers.pkl",
+# )
+# custom_retriever_peft: CustomRetriever = setup_database(
+#     "chroma-db-peft", "document_dict_peft.pkl"
+# )
+# custom_retriever_trl: CustomRetriever = setup_database(
+#     "chroma-db-trl", "document_dict_trl.pkl"
# )
+# custom_retriever_llama_index: CustomRetriever = setup_database(
+#     "chroma-db-llama_index",
+#     "document_dict_llama_index.pkl",
+# )
+# custom_retriever_openai_cookbooks: CustomRetriever = setup_database(
+#     "chroma-db-openai_cookbooks",
+#     "document_dict_openai_cookbooks.pkl",
+# )
+# custom_retriever_langchain: CustomRetriever = setup_database(
+#     "chroma-db-langchain",
+#     "document_dict_langchain.pkl",
+# )
+
+custom_retriever_all_sources: CustomRetriever = setup_database(
+    "chroma-db-all_sources",
+    "document_dict_all_sources.pkl",
 )
 
 # Constants
@@ -92,7 +177,8 @@ AVAILABLE_SOURCES_UI = [
     "LlamaIndex Docs",
     "LangChain Docs",
     "OpenAI Cookbooks",
-    # "Towards AI Blog",
+    "Towards AI Blog",
+    # "All Sources",
     # "RAG Course",
 ]
 
@@ -103,7 +189,8 @@ AVAILABLE_SOURCES = [
     "llama_index",
     "langchain",
     "openai_cookbooks",
-    # "towards_ai_blog",
+    "tai_blog",
+    # "all_sources",
     # "rag_course",
 ]
 
@@ -114,12 +201,13 @@ mongo_db = (
 )
 
 __all__ = [
-    "custom_retriever_transformers",
-    "custom_retriever_peft",
-    "custom_retriever_trl",
-    "custom_retriever_llama_index",
-    "custom_retriever_openai_cookbooks",
-    "custom_retriever_langchain",
+    # "custom_retriever_transformers",
+    # "custom_retriever_peft",
+    # "custom_retriever_trl",
+    # "custom_retriever_llama_index",
+    # "custom_retriever_openai_cookbooks",
+    # "custom_retriever_langchain",
+    "custom_retriever_all_sources",
    "mongo_db",
     "CONCURRENCY_COUNT",
     "AVAILABLE_SOURCES_UI",