Omar Solano
committed on
Commit · 9c1e8a7
Parent(s): 471ad41

update gradio chatbot to latest version

Browse files:
- README.md +1 -1
- requirements.txt +101 -169
- scripts/custom_retriever.py +237 -28
- scripts/main.py +82 -43
- scripts/prompts.py +116 -14
- scripts/setup.py +122 -34
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🧑🏻‍🏫
 colorFrom: gray
 colorTo: pink
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.44.0
 app_file: scripts/main.py
 pinned: false
 ---
requirements.txt
CHANGED
@@ -1,280 +1,212 @@
 aiofiles==23.2.1
 aiohappyeyeballs==2.4.0
-aiohttp==3.10.
+aiohttp==3.10.6
 aiosignal==1.3.1
+aiostream==0.5.2
 annotated-types==0.7.0
-anyio==4.
+anyio==4.6.0
 appnope==0.1.4
 asgiref==3.8.1
 asttokens==2.4.1
 attrs==24.2.0
-automat==24.8.1
-azure-core==1.30.2
-azure-identity==1.17.1
 backoff==2.2.1
 bcrypt==4.2.0
 beautifulsoup4==4.12.3
-build==1.2.1
+boto3==1.35.26
+botocore==1.35.26
+build==1.2.2
 cachetools==5.5.0
-certifi==2024.
-cffi==1.17.0
+certifi==2024.8.30
 charset-normalizer==3.3.2
 chroma-hnswlib==0.7.6
-chromadb==0.5.
+chromadb==0.5.7
 click==8.1.7
-cohere==5.
+cohere==5.9.4
 coloredlogs==15.0.1
 comm==0.2.2
-contourpy==1.2.1
-cryptography==43.0.0
-cssselect==1.2.0
+contourpy==1.3.0
 cycler==0.12.1
 dataclasses-json==0.6.7
 debugpy==1.8.5
 decorator==5.1.1
-defusedxml==0.7.1
 deprecated==1.2.14
 dirtyjson==1.0.8
 distro==1.9.0
 dnspython==2.6.1
+durationpy==0.7
-executing==2.0
+executing==2.1.0
-fastapi==0.
+fastapi==0.115.0
-fastavro==1.9.
-fastjsonschema==2.20.0
+fastavro==1.9.7
 ffmpy==0.4.0
-filelock==3.
+filelock==3.16.1
 flatbuffers==24.3.25
-fonttools==4.
+fonttools==4.54.1
 frozenlist==1.4.1
-fsspec==2024.
+fsspec==2024.9.0
 google-ai-generativelanguage==0.6.4
-google-api-core==2.
+google-api-core==2.20.0
-google-api-python-client==2.
+google-api-python-client==2.146.0
-google-auth==2.
+google-auth==2.35.0
 google-auth-httplib2==0.2.0
-google-cloud-aiplatform==1.63.0
-google-cloud-bigquery==3.25.0
-google-cloud-core==2.4.1
-google-cloud-resource-manager==1.12.5
-google-cloud-storage==2.18.2
-google-crc32c==1.5.0
 google-generativeai==0.5.4
+googleapis-common-protos==1.65.0
-gradio==4.42.0
+gradio==4.44.0
 gradio-client==1.3.0
-greenlet==3.
+greenlet==3.1.1
-grpcio==1.66.0
+grpcio==1.66.1
 grpcio-status==1.62.3
+grpclib==0.4.7
 h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
 httpcore==1.0.5
 httplib2==0.22.0
 httptools==0.6.1
-httpx==0.27.
+httpx==0.27.2
 httpx-sse==0.4.0
-huggingface-hub==0.
+huggingface-hub==0.25.1
 humanfriendly==10.0
+hyperframe==6.0.1
-idna==3.
+idna==3.10
-importlib-metadata==8.
+importlib-metadata==8.4.0
-importlib-resources==6.4.
-incremental==24.7.2
-instructor==1.3.4
+importlib-resources==6.4.5
 ipykernel==6.29.5
-ipython==8.
-itemadapter==0.9.0
-itemloaders==1.3.1
+ipython==8.27.0
 jedi==0.19.1
 jinja2==3.1.4
-jiter==0.
+jiter==0.5.0
 jmespath==1.0.1
 joblib==1.4.2
-jsonpath-python==1.0.6
-jsonpointer==3.0.0
-jsonschema==4.23.0
-jsonschema-specifications==2023.12.1
-jupyter-client==8.6.2
+jupyter-client==8.6.3
 jupyter-core==5.7.2
+kiwisolver==1.4.7
+kubernetes==31.0.0
-llama-cloud==0.0
+llama-cloud==0.1.0
-llama-index==0.11.1
+llama-index==0.11.13
-llama-index-agent-openai==0.3.0
+llama-index-agent-openai==0.3.4
-llama-index-cli==0.3.0
+llama-index-cli==0.3.1
-llama-index-core==0.11.1
+llama-index-core==0.11.13.post1
-llama-index-embeddings-adapter==0.2.1
-llama-index-embeddings-cohere==0.2.0
+llama-index-embeddings-cohere==0.2.1
-llama-index-embeddings-huggingface==0.3.1
-llama-index-embeddings-openai==0.2.3
+llama-index-embeddings-openai==0.2.5
-llama-index-finetuning==0.2.0
-llama-index-indices-managed-llama-cloud==0.3.0
+llama-index-indices-managed-llama-cloud==0.4.0
 llama-index-legacy==0.9.48.post3
-llama-index-llms-
-llama-index-llms-
-llama-index-llms-
-llama-index-
-llama-index-llms-replicate==0.2.0
+llama-index-llms-gemini==0.3.5
+llama-index-llms-openai==0.2.9
-llama-index-multi-modal-llms-openai==0.2.0
+llama-index-multi-modal-llms-openai==0.2.1
-llama-index-postprocessor-cohere-rerank==0.2.0
+llama-index-postprocessor-cohere-rerank==0.2.1
 llama-index-program-openai==0.2.0
 llama-index-question-gen-openai==0.2.0
-llama-index-readers-file==0.2.
+llama-index-readers-file==0.2.2
-llama-index-readers-llama-parse==0.
+llama-index-readers-llama-parse==0.3.0
 llama-index-vector-stores-chroma==0.2.0
-llama-parse==0.5.
+llama-parse==0.5.6
-logfire==0.
-lxml==5.3.0
+logfire==0.53.0
 markdown-it-py==3.0.0
 markupsafe==2.1.5
 marshmallow==3.22.0
 matplotlib==3.9.2
 matplotlib-inline==0.1.7
 mdurl==0.1.2
-mistune==3.0.2
-mmh3==4.1.0
+mmh3==5.0.1
+modal==0.64.136
 monotonic==1.6
 mpmath==1.3.0
-msal-extensions==1.2.0
-multidict==6.0.5
+multidict==6.1.0
 mypy-extensions==1.0.0
-nbclient==0.10.0
-nbconvert==7.16.4
-nbformat==5.10.4
 nest-asyncio==1.6.0
 networkx==3.3
 nltk==3.9.1
 numpy==1.26.4
 oauthlib==3.2.2
-onnxruntime==1.19.
+onnxruntime==1.19.2
-openai==1.
+openai==1.47.1
-opentelemetry-api==1.
+opentelemetry-api==1.27.0
-opentelemetry-exporter-otlp-proto-common==1.
+opentelemetry-exporter-otlp-proto-common==1.27.0
-opentelemetry-exporter-otlp-proto-grpc==1.
+opentelemetry-exporter-otlp-proto-grpc==1.27.0
-opentelemetry-exporter-otlp-proto-http==1.
+opentelemetry-exporter-otlp-proto-http==1.27.0
-opentelemetry-instrumentation==0.
+opentelemetry-instrumentation==0.48b0
-opentelemetry-instrumentation-asgi==0.
+opentelemetry-instrumentation-asgi==0.48b0
-opentelemetry-instrumentation-fastapi==0.
+opentelemetry-instrumentation-fastapi==0.48b0
-opentelemetry-proto==1.
+opentelemetry-proto==1.27.0
-opentelemetry-sdk==1.
+opentelemetry-sdk==1.27.0
-opentelemetry-semantic-conventions==0.
+opentelemetry-semantic-conventions==0.48b0
-opentelemetry-util-http==0.
+opentelemetry-util-http==0.48b0
 orjson==3.10.7
 overrides==7.7.0
 packaging==24.1
-pandas==2.2.
-pandocfilters==1.5.1
+pandas==2.2.3
 parameterized==0.9.0
-parsel==1.9.1
 parso==0.8.4
 pexpect==4.9.0
 pillow==10.4.0
-platformdirs==4.
+platformdirs==4.3.6
-posthog==3.5.2
+posthog==3.6.6
 prompt-toolkit==3.0.47
-protego==0.3.1
 proto-plus==1.24.0
-protobuf==4.25.
+protobuf==4.25.5
 psutil==6.0.0
 ptyprocess==0.7.0
 pure-eval==0.2.3
-pyasn1==0.6.
+pyasn1==0.6.1
-pyasn1-modules==0.4.
+pyasn1-modules==0.4.1
-pydantic==2.
+pydantic==2.9.2
-pydantic-core==2.20.1
+pydantic-core==2.23.4
-pydispatcher==2.0.7
 pydub==0.25.1
 pygments==2.18.0
-pymongo==4.8.0
+pymongo==4.9.1
-pyopenssl==24.2.1
 pyparsing==3.1.4
 pypdf==4.3.1
 pypika==0.48.9
 pyproject-hooks==1.1.0
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
-python-multipart==0.0.
+python-multipart==0.0.10
-pytz==2024.
+pytz==2024.2
 pyyaml==6.0.2
 pyzmq==26.2.0
-referencing==0.35.1
-regex==2024.7.24
+regex==2024.9.11
 requests==2.32.3
-requests-file==2.1.0
 requests-oauthlib==2.0.0
-rich==13.8.
-rpds-py==0.20.0
+rich==13.8.1
 rsa==4.9
-ruff==0.6.
+ruff==0.6.7
 s3transfer==0.10.2
-safetensors==0.4.4
-scikit-learn==1.5.1
-scipy==1.14.1
-scrapy==2.11.2
 semantic-version==2.10.0
-service-identity==24.1.0
-setuptools==73.0.1
-shapely==2.0.6
+setuptools==75.1.0
 shellingham==1.5.4
+sigtools==4.0.1
 six==1.16.0
 sniffio==1.3.1
 soupsieve==2.6
-sqlalchemy==2.0.
+sqlalchemy==2.0.35
 stack-data==0.6.3
-starlette==0.38.
+starlette==0.38.6
 striprtf==0.0.26
-sympy==1.13.
+sympy==1.13.3
+synchronicity==0.7.6
-tenacity==8.
-threadpoolctl==3.5.0
+tenacity==8.5.0
 tiktoken==0.7.0
-tokenizers==0.19.1
+tokenizers==0.20.0
+toml==0.10.2
 tomlkit==0.12.0
-torch==2.4.0
 tornado==6.4.1
 tqdm==4.66.5
 traitlets==5.14.3
-transformers==4.44.2
-twisted==24.7.0
 typer==0.12.5
-types-
+types-certifi==2021.10.8.3
+types-requests==2.32.0.20240914
+types-toml==0.10.8.20240310
 typing-extensions==4.12.2
 typing-inspect==0.9.0
-tzdata==2024.
+tzdata==2024.2
 uritemplate==4.1.1
-urllib3==2.2.
+urllib3==2.2.3
 uvicorn==0.30.6
 uvloop==0.20.0
-watchfiles==0.23.0
+watchfiles==0.24.0
 wcwidth==0.2.13
-webencodings==0.5.1
 websocket-client==1.8.0
 websockets==12.0
 wrapt==1.16.0
-yarl==1.
+yarl==1.12.1
-zipp==3.20.
+zipp==3.20.2
-zope-interface==7.0.1
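The gradio pin moves from 4.42.0 to 4.44.0, and it has to stay in lockstep with the `sdk_version` field in the README front matter, since Hugging Face Spaces installs the Gradio SDK from that field. Below is a minimal consistency check one could run before pushing — a sketch, not part of this commit; the file paths and helper names are illustrative:

import re

def requirement_pin(package: str, path: str = "requirements.txt") -> str | None:
    # Return the version pinned as `package==X.Y.Z`, if any.
    with open(path) as f:
        for line in f:
            match = re.match(rf"^{re.escape(package)}==(\S+)$", line.strip())
            if match:
                return match.group(1)
    return None

def readme_sdk_version(path: str = "README.md") -> str | None:
    # Return the `sdk_version:` value from the Spaces front matter.
    with open(path) as f:
        for line in f:
            match = re.match(r"^sdk_version:\s*(\S+)$", line.strip())
            if match:
                return match.group(1)
    return None

if __name__ == "__main__":
    assert requirement_pin("gradio") == readme_sdk_version(), (
        "README sdk_version and the gradio pin in requirements.txt disagree"
    )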
scripts/custom_retriever.py
CHANGED
@@ -1,11 +1,75 @@
+import asyncio
 import time
+import traceback
+from typing import List, Optional
 
 import logfire
+import tiktoken
+from cohere import AsyncClient
 from llama_index.core import QueryBundle
+from llama_index.core.async_utils import run_async_tasks
+from llama_index.core.callbacks import CBEventType, EventPayload
 from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
-from llama_index.core.schema import NodeWithScore, TextNode
+from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
 from llama_index.postprocessor.cohere_rerank import CohereRerank
+from llama_index.postprocessor.cohere_rerank.base import CohereRerank
+
+
+class AsyncCohereRerank(CohereRerank):
+    def __init__(
+        self,
+        top_n: int = 5,
+        model: str = "rerank-english-v3.0",
+        api_key: Optional[str] = None,
+    ) -> None:
+        super().__init__(top_n=top_n, model=model, api_key=api_key)
+        self._api_key = api_key
+        self._model = model
+        self._top_n = top_n
+
+    async def apostprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+    ) -> List[NodeWithScore]:
+        if query_bundle is None:
+            raise ValueError("Query bundle must be provided.")
+
+        if len(nodes) == 0:
+            return []
+
+        async_client = AsyncClient(api_key=self._api_key)
+
+        with self.callback_manager.event(
+            CBEventType.RERANKING,
+            payload={
+                EventPayload.NODES: nodes,
+                EventPayload.MODEL_NAME: self._model,
+                EventPayload.QUERY_STR: query_bundle.query_str,
+                EventPayload.TOP_K: self._top_n,
+            },
+        ) as event:
+            texts = [
+                node.node.get_content(metadata_mode=MetadataMode.EMBED)
+                for node in nodes
+            ]
+
+            results = await async_client.rerank(
+                model=self._model,
+                top_n=self._top_n,
+                query=query_bundle.query_str,
+                documents=texts,
+            )
+
+            new_nodes = []
+            for result in results.results:
+                new_node_with_score = NodeWithScore(
+                    node=nodes[result.index].node, score=result.relevance_score
+                )
+                new_nodes.append(new_node_with_score)
+            event.on_end(payload={EventPayload.NODES: new_nodes})
+
+        return new_nodes
 
 
 class CustomRetriever(BaseRetriever):
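`CohereRerank.postprocess_nodes` is synchronous, so the hunk above adds an `AsyncCohereRerank` that awaits Cohere's `AsyncClient.rerank` instead, letting reranking run inside the async retrieval path without blocking the event loop. A standalone sketch of the same underlying call; the toy documents and the environment variable are assumptions, not part of the commit:

import asyncio
import os

from cohere import AsyncClient

async def rerank_texts(query: str, texts: list[str], top_n: int = 3) -> list[tuple[int, float]]:
    # Same call AsyncCohereRerank wraps: returns (index into `texts`, relevance score).
    client = AsyncClient(api_key=os.environ["COHERE_API_KEY"])
    response = await client.rerank(
        model="rerank-english-v3.0",
        query=query,
        documents=texts,
        top_n=top_n,
    )
    return [(r.index, r.relevance_score) for r in response.results]

if __name__ == "__main__":
    docs = ["LoRA is a fine-tuning method.", "Gradio builds web UIs."]  # toy corpus
    print(asyncio.run(rerank_texts("How do I fine-tune an LLM?", docs)))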
@@ -15,41 +79,64 @@ class CustomRetriever(BaseRetriever):
         self,
         vector_retriever: VectorIndexRetriever,
         document_dict: dict,
+        keyword_retriever,
+        mode: str = "AND",
     ) -> None:
         """Init params."""
 
         self._vector_retriever = vector_retriever
         self._document_dict = document_dict
+
+        self._keyword_retriever = keyword_retriever
+        if mode not in ("AND", "OR"):
+            raise ValueError("Invalid mode.")
+        self._mode = mode
+
         super().__init__()
 
-    def 
+    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve nodes given query."""
 
         # LlamaIndex adds "\ninput is " to the query string
         query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
         query_bundle.query_str = query_bundle.query_str.rstrip()
 
-        logfire.info(f"Retrieving 
+        # logfire.info(f"Retrieving nodes with string: '{query_bundle}'")
         start = time.time()
-        nodes = self._vector_retriever.
+        nodes = await self._vector_retriever.aretrieve(query_bundle)
+        keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
 
-        logfire.info(f"
+        # logfire.info(f"Number of vector nodes: {len(nodes)}")
+        # logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")
+
+        vector_ids = {n.node.node_id for n in nodes}
+        keyword_ids = {n.node.node_id for n in keyword_nodes}
+
+        combined_dict = {n.node.node_id: n for n in nodes}
+        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
+
+        if self._mode == "AND":
+            retrieve_ids = vector_ids.intersection(keyword_ids)
+        else:
+            retrieve_ids = vector_ids.union(keyword_ids)
+
+        nodes = [combined_dict[rid] for rid in retrieve_ids]
 
         # Filter out nodes with the same ref_doc_id
         def filter_nodes_by_unique_doc_id(nodes):
             unique_nodes = {}
             for node in nodes:
-                doc_id = node.node.ref_doc_id
+                # doc_id = node.node.ref_doc_id
+                doc_id = node.node.source_node.node_id
                 if doc_id is not None and doc_id not in unique_nodes:
                     unique_nodes[doc_id] = node
             return list(unique_nodes.values())
 
         nodes = filter_nodes_by_unique_doc_id(nodes)
-        logfire.info(
-        )
-        logfire.info(f"Nodes retrieved: {nodes}")
+        # logfire.info(
+        #     f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
+        # )
+        # logfire.info(f"Nodes retrieved: {nodes}")
 
         nodes_context = []
         for node in nodes:
@@ -59,32 +146,154 @@ class CustomRetriever(BaseRetriever):
             # print("Score\t", node.score)
             # print("Metadata\t", node.metadata)
             # print("-_" * 20)
-                continue
+            doc_id = node.node.source_node.node_id  # type: ignore
             if node.metadata["retrieve_doc"] == True:
                 # print("This node will be replaced by the document")
-                doc = self._document_dict[node.node.ref_doc_id]
+                # doc = self._document_dict[node.node.ref_doc_id]
+                # print("retrieved doc == True")
+                doc = self._document_dict[doc_id]
                 # print(doc.text)
                 new_node = NodeWithScore(
-                    node=TextNode(text=doc.text, metadata=node.metadata),  # type: ignore
+                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),  # type: ignore
+                    score=node.score,
+                )
+                nodes_context.append(new_node)
+            else:
+                node.node.node_id = doc_id
+                nodes_context.append(node)
+
+        try:
+            reranker = AsyncCohereRerank(top_n=3, model="rerank-english-v3.0")
+            nodes_context = await reranker.apostprocess_nodes(
+                nodes_context, query_bundle
+            )
+
+        except Exception as e:
+            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
+            error_msg += "Traceback:\n"
+            error_msg += traceback.format_exc()
+            logfire.error(error_msg)
+
+        nodes_filtered = []
+        total_tokens = 0
+        enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        for node in nodes_context:
+            if node.score < 0.10:  # type: ignore
+                continue
+
+            # Count tokens
+            if "tokens" in node.node.metadata:
+                node_tokens = node.node.metadata["tokens"]
+            else:
+                node_tokens = len(enc.encode(node.node.text))  # type: ignore
+
+            if total_tokens + node_tokens > 100_000:
+                logfire.info("Skipping node due to token count exceeding 100k")
+                break
+
+            total_tokens += node_tokens
+            nodes_filtered.append(node)
+
+        # logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
+        # logfire.info(f"Total tokens: {total_tokens}")
+
+        # duration = time.time() - start
+        # logfire.info(f"Retrieving nodes took {duration:.2f}s")
+
+        return nodes_filtered[:3]
+
+    # def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+    #     return asyncio.run(self._aretrieve(query_bundle))
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Retrieve nodes given query."""
+
+        # LlamaIndex adds "\ninput is " to the query string
+        query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
+        query_bundle.query_str = query_bundle.query_str.rstrip()
+        logfire.info(f"Retrieving nodes with string: '{query_bundle}'")
+
+        start = time.time()
+        nodes = self._vector_retriever.retrieve(query_bundle)
+        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
+
+        logfire.info(f"Number of vector nodes: {len(nodes)}")
+        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")
+
+        vector_ids = {n.node.node_id for n in nodes}
+        keyword_ids = {n.node.node_id for n in keyword_nodes}
+
+        combined_dict = {n.node.node_id: n for n in nodes}
+        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
+
+        if self._mode == "AND":
+            retrieve_ids = vector_ids.intersection(keyword_ids)
+        else:
+            retrieve_ids = vector_ids.union(keyword_ids)
+
+        nodes = [combined_dict[rid] for rid in retrieve_ids]
+
+        def filter_nodes_by_unique_doc_id(nodes):
+            unique_nodes = {}
+            for node in nodes:
+                # doc_id = node.node.ref_doc_id
+                doc_id = node.node.source_node.node_id
+                if doc_id is not None and doc_id not in unique_nodes:
+                    unique_nodes[doc_id] = node
+            return list(unique_nodes.values())
+
+        nodes = filter_nodes_by_unique_doc_id(nodes)
+        logfire.info(
+            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
+        )
+        logfire.info(f"Nodes retrieved: {nodes}")
+
+        nodes_context = []
+        for node in nodes:
+            doc_id = node.node.source_node.node_id  # type: ignore
+            if node.metadata["retrieve_doc"] == True:
+                doc = self._document_dict[doc_id]
+                new_node = NodeWithScore(
+                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),  # type: ignore
                     score=node.score,
                 )
                 nodes_context.append(new_node)
             else:
+                node.node.node_id = doc_id
                 nodes_context.append(node)
 
         try:
-            reranker = CohereRerank(top_n=
+            reranker = CohereRerank(top_n=3, model="rerank-english-v3.0")
             nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle)
-            for node in nodes_context:
-                if node.score < 0.10:  # type: ignore
-                    continue
-                else:
-                    nodes_filtered.append(node)
-            logfire.info(f"Cohere raranking to {len(nodes_filtered)} nodes")
-            return nodes_filtered
+
         except Exception as e:
+            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
+            error_msg += "Traceback:\n"
+            error_msg += traceback.format_exc()
+            logfire.error(error_msg)
+
+        nodes_filtered = []
+        total_tokens = 0
+        enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        for node in nodes_context:
+            if node.score < 0.10:  # type: ignore
+                continue
+            if "tokens" in node.node.metadata:
+                node_tokens = node.node.metadata["tokens"]
+            else:
+                node_tokens = len(enc.encode(node.node.text))  # type: ignore
+
+            if total_tokens + node_tokens > 100_000:
+                logfire.info("Skipping node due to token count exceeding 100k")
+                break
+
+            total_tokens += node_tokens
+            nodes_filtered.append(node)
+
+        logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
+        logfire.info(f"Total tokens: {total_tokens}")
+
+        duration = time.time() - start
+        logfire.info(f"Retrieving nodes took {duration:.2f}s")
+
+        return nodes_filtered[:3]
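The new `_retrieve`/`_aretrieve` pair implements hybrid retrieval: vector and keyword results are merged by node ID and then either intersected ("AND") or unioned ("OR"). A self-contained sketch of just that combination step, outside the retriever class — `vector_nodes` and `keyword_nodes` are assumed to come from any two LlamaIndex retrievers:

from llama_index.core.schema import NodeWithScore

def combine_results(
    vector_nodes: list[NodeWithScore],
    keyword_nodes: list[NodeWithScore],
    mode: str = "OR",
) -> list[NodeWithScore]:
    vector_ids = {n.node.node_id for n in vector_nodes}
    keyword_ids = {n.node.node_id for n in keyword_nodes}

    # Last writer wins: on overlap, the keyword-side NodeWithScore replaces the vector one.
    combined = {n.node.node_id: n for n in vector_nodes}
    combined.update({n.node.node_id: n for n in keyword_nodes})

    if mode == "AND":  # keep only nodes found by both retrievers
        keep = vector_ids & keyword_ids
    else:  # "OR": keep anything found by either retriever
        keep = vector_ids | keyword_ids
    return [combined[node_id] for node_id in keep]

setup.py below wires the retriever with "OR", so keyword matches can surface documents the embedding model misses; the rerank-plus-token-budget pass afterwards keeps the final context bounded.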
scripts/main.py
CHANGED
@@ -6,53 +6,59 @@ from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.llms import MessageRole
 from llama_index.core.memory import ChatSummaryMemoryBuffer
 from llama_index.core.tools import RetrieverTool, ToolMetadata
+from llama_index.core.vector_stores import (
+    FilterCondition,
+    FilterOperator,
+    MetadataFilter,
+    MetadataFilters,
+)
 from llama_index.llms.openai import OpenAI
 from prompts import system_message_openai_agent
-from setup import (
+from setup import (  # custom_retriever_langchain,; custom_retriever_llama_index,; custom_retriever_openai_cookbooks,; custom_retriever_peft,; custom_retriever_transformers,; custom_retriever_trl,
     AVAILABLE_SOURCES,
     AVAILABLE_SOURCES_UI,
     CONCURRENCY_COUNT,
-    custom_retriever_llama_index,
-    custom_retriever_openai_cookbooks,
-    custom_retriever_peft,
-    custom_retriever_transformers,
-    custom_retriever_trl,
+    custom_retriever_all_sources,
 )
 
 
 def update_query_engine_tools(selected_sources):
     tools = []
     source_mapping = {
-        "Transformers Docs": (
-        ),
-        "PEFT Docs": (
-        ),
-        "TRL Docs": (
-        ),
-        "LlamaIndex Docs": (
-        ),
-        "OpenAI Cookbooks": (
-        ),
-        "LangChain Docs": (
+        # "Transformers Docs": (
+        #     custom_retriever_transformers,
+        #     "Transformers_information",
+        #     """Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch information on topics such as language models (LLMs) models such as Llama3 and theory (transformer architectures), tips on prompting, quantization, etc.""",
+        # ),
+        # "PEFT Docs": (
+        #     custom_retriever_peft,
+        #     "PEFT_information",
+        #     """Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc.""",
+        # ),
+        # "TRL Docs": (
+        #     custom_retriever_trl,
+        #     "TRL_information",
+        #     """Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
+        # ),
+        # "LlamaIndex Docs": (
+        #     custom_retriever_llama_index,
+        #     "LlamaIndex_information",
+        #     """Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of a framework, includes info about fine-tuning embedding models, building chatbots, and agents with llms, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
+        # ),
+        # "OpenAI Cookbooks": (
+        #     custom_retriever_openai_cookbooks,
+        #     "openai_cookbooks_info",
+        #     """Useful for questions asking about accomplishing common tasks with the OpenAI API. Returns example code and guides stored in Jupyter notebooks, including info about ChatGPT GPT actions, OpenAI Assistants API, and How to fine-tune OpenAI's GPT-4o and GPT-4o-mini models with the OpenAI API.""",
+        # ),
+        # "LangChain Docs": (
+        #     custom_retriever_langchain,
+        #     "langchain_info",
+        #     """Useful for questions asking about the LangChain framework. It is the documentation of the LangChain framework, includes info about building chains, agents, and tools, using memory, prompts, callbacks, etc.""",
+        # ),
+        "All Sources": (
+            custom_retriever_all_sources,
+            "all_sources_info",
+            """Useful for questions asking about information in the field of AI.""",
         ),
     }
@@ -80,9 +86,7 @@ def generate_completion(
     memory,
 ):
     with logfire.span("Running query"):
-        logfire.info(f"query: {query}")
-        logfire.info(f"model: {model}")
-        logfire.info(f"sources: {sources}")
+        logfire.info(f"User query: {query}")
 
         chat_list = memory.get()
 
@@ -102,7 +106,34 @@
         client = llm._get_client()
         logfire.instrument_openai(client)
 
-        query_engine_tools = update_query_engine_tools(
+        query_engine_tools = update_query_engine_tools(["All Sources"])
+
+        filter_list = []
+        source_mapping = {
+            "Transformers Docs": "transformers",
+            "PEFT Docs": "peft",
+            "TRL Docs": "trl",
+            "LlamaIndex Docs": "llama_index",
+            "LangChain Docs": "langchain",
+            "OpenAI Cookbooks": "openai_cookbooks",
+            "Towards AI Blog": "tai_blog",
+        }
+
+        for source in sources:
+            if source in source_mapping:
+                filter_list.append(
+                    MetadataFilter(
+                        key="source",
+                        operator=FilterOperator.EQ,
+                        value=source_mapping[source],
+                    )
+                )
+
+        filters = MetadataFilters(
+            filters=filter_list,
+            condition=FilterCondition.OR,
+        )
+        query_engine_tools[0].retriever._vector_retriever._filters = filters
 
         agent = OpenAIAgent.from_tools(
             llm=llm,
@@ -151,8 +182,16 @@ def format_sources(completion) -> str:
     )
     document_template: str = "[🔗 {source}: {title}]({url}), relevance: {score:2.2f}"
     all_documents = []
-    for source in completion.sources:
+    for source in completion.sources:  # looping over list[ToolOutput]
+        if isinstance(source.raw_output, Exception):
+            logfire.error(f"Error in source output: {source.raw_output}")
+            # pdb.set_trace()
+            continue
+
+        if not isinstance(source.raw_output, list):
+            logfire.warn(f"Unexpected source output type: {type(source.raw_output)}")
+            continue
+        for src in source.raw_output:  # looping over list[NodeWithScore]
             document = document_template.format(
                 title=src.metadata["title"],
                 score=src.score,
@@ -189,13 +228,13 @@ sources = gr.CheckboxGroup(
         "LlamaIndex Docs",
         "LangChain Docs",
         "OpenAI Cookbooks",
+        # "All Sources",
     ],
     interactive=True,
 )
 model = gr.Dropdown(
     [
         "gpt-4o-mini",
-        "gpt-4o",
     ],
     label="Model",
    value="gpt-4o-mini",
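Instead of swapping one retriever tool per documentation source, main.py now exposes a single "All Sources" tool and narrows it per request with Chroma metadata filters. A condensed sketch of that pattern using the same filter classes; the `selected` list and mapping below are illustrative values, not the app's state:

from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

selected = ["Transformers Docs", "TRL Docs"]  # e.g. the UI checkbox state
ui_to_source = {"Transformers Docs": "transformers", "TRL Docs": "trl"}

filters = MetadataFilters(
    filters=[
        MetadataFilter(key="source", operator=FilterOperator.EQ, value=ui_to_source[s])
        for s in selected
        if s in ui_to_source
    ],
    condition=FilterCondition.OR,  # a node matches if it comes from any selected source
)
# As in the diff, the filters are then assigned to the tool's underlying
# vector retriever before the agent runs:
# query_engine_tools[0].retriever._vector_retriever._filters = filters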
scripts/prompts.py
CHANGED
@@ -1,19 +1,123 @@
+# # Prompt 1
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
 
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
 
-e.g:
-User question: 'How can I fine-tune an LLM?'
-Input to the tool: 'Fine-tuning an LLM'
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
 
-Input to the tool: 'Quantization for LLMs'
+# Use the available tools to gather insights pertinent to the field of AI.
 
-Input to the tool: 'Building an AI Agent'
+# To answer student questions, always use the all_sources_info tool plus another one simultaneously. Meaning that should be using two tools in total.
 
-Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+# Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+# Prompt 2
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+# Use the available tools to gather insights pertinent to the field of AI.
+
+# To answer student questions, always use the all_sources_info tool. For complex questions, you can decompose the user question into TWO sub questions (you are limited to two sub-questions) that can be answered by the tools.
+
+# These are the guidelines to consider if you decide to create sub questions:
+# * Be as specific as possible
+# * The two sub questions should be relevant to the user question
+# * The two sub questions should be answerable by the tools provided
+
+# Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+# # Prompt 3
+# system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+# Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+# Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+# Use the available tools to gather insights pertinent to the field of AI.
+
+# To answer student questions, always use the all_sources_info tool. For each question, you should decompose the user question into TWO sub questions (you are limited to two sub-questions) that can be answered by the tools.
+
+# These are the guidelines to consider when creating sub questions:
+# * Be as specific as possible
+# * The two sub questions should be relevant to the user question
+# * The two sub questions should be answerable by the tools provided
+
+# Only some information returned by the tools might be relevant to the user question, so ignore the irrelevant part and answer the user question with what you have.
+
+# Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
+
+# When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
+
+# Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
+
+# Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
+
+# At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
+
+# Do not refer to the documentation directly, but use the information provided within it to answer questions.
+
+# If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
+
+# Make sure to format your answers in Markdown format, including code blocks and snippets.
+# """
+
+
+# Prompt 4 Trying to make it like #1
+system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs.
+
+Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks such as Langchain and LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback (RLHF). Questions should be understood in this context.
+
+Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
+
+Use the available tools to gather insights pertinent to the field of AI.
+
+To answer student questions, always use the all_sources_info tool plus another one simultaneously.
+Decompose the user question into TWO sub questions (you are limited to two sub-questions) one for each tool.
+Meaning that should be using two tools in total for each user question.
+
+These are the guidelines to consider if you decide to create sub questions:
+* Be as specific as possible
+* The two sub questions should be relevant to the user question
+* The two sub questions should be answerable by the tools provided
+
+Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
 
 Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
 
@@ -25,11 +129,9 @@ Should the tools repository lack information on the queried topic, politely info
 
 At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
 
-Do not refer to the documentation directly, but use the information provided within it to answer questions.
+Do not refer to the documentation directly, but use the information provided within it to answer questions.
 
 If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
 
 Make sure to format your answers in Markdown format, including code blocks and snippets.
-
-Politely reject questions not related to AI, while being cautious not to reject unfamiliar terms or acronyms too quickly. If a question seems unrelated but you suspect it might contain AI-related terminology.
 """
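The file keeps three earlier system-prompt variants as comments and activates "Prompt 4". The exported `system_message_openai_agent` string is imported by scripts/main.py and attached to the OpenAI agent. A sketch of the typical wiring — the `system_prompt` keyword and the empty tool list are assumptions, since the exact `OpenAIAgent.from_tools(...)` arguments are outside this diff:

from llama_index.agent.openai import OpenAIAgent
from llama_index.llms.openai import OpenAI
from prompts import system_message_openai_agent

agent = OpenAIAgent.from_tools(
    tools=[],  # the RetrieverTool list built by update_query_engine_tools() in main.py
    llm=OpenAI(model="gpt-4o-mini"),
    system_prompt=system_message_openai_agent,
)
print(agent.chat("What is RLHF?"))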
scripts/setup.py
CHANGED
@@ -1,3 +1,5 @@
+import asyncio
+import json
 import logging
 import os
 import pickle
@@ -6,9 +8,16 @@ import chromadb
 import logfire
 from custom_retriever import CustomRetriever
 from dotenv import load_dotenv
-from 
+from evaluate_rag_system import AsyncKeywordTableSimpleRetriever
+from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex
+from llama_index.core.ingestion import IngestionPipeline
 from llama_index.core.node_parser import SentenceSplitter
-from llama_index.core.retrievers import 
+from llama_index.core.retrievers import (
+    KeywordTableSimpleRetriever,
+    VectorIndexRetriever,
+)
+from llama_index.core.schema import NodeWithScore, QueryBundle
+from llama_index.embeddings.cohere import CohereEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.vector_stores.chroma import ChromaVectorStore
 from utils import init_mongo_db
@@ -21,11 +30,11 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
 logfire.configure()
 
 
-if not os.path.exists("data/chroma-db-
+if not os.path.exists("data/chroma-db-all_sources"):
     # Download the vector database from the Hugging Face Hub if it doesn't exist locally
     # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
     logfire.warn(
-        f"Vector database does not exist at 'data/chroma-db-
+        f"Vector database does not exist at 'data/chroma-db-all_sources', downloading from Hugging Face Hub"
    )
    from huggingface_hub import snapshot_download
 
@@ -34,51 +43,127 @@ if not os.path.exists("data/chroma-db-transformers"):
         local_dir="data",
         repo_type="dataset",
     )
-    logfire.info(f"Downloaded vector database to 'data/chroma-db-
+    logfire.info(f"Downloaded vector database to 'data/chroma-db-all_sources'")
+
+
+def create_docs(input_file: str) -> list[Document]:
+    with open(input_file, "r") as f:
+        documents = []
+        for line in f:
+            data = json.loads(line)
+            documents.append(
+                Document(
+                    doc_id=data["doc_id"],
+                    text=data["content"],
+                    metadata={  # type: ignore
+                        "url": data["url"],
+                        "title": data["name"],
+                        "tokens": data["tokens"],
+                        "retrieve_doc": data["retrieve_doc"],
+                        "source": data["source"],
+                    },
+                    excluded_llm_metadata_keys=[
+                        "title",
+                        "tokens",
+                        "retrieve_doc",
+                        "source",
+                    ],
+                    excluded_embed_metadata_keys=[
+                        "url",
+                        "tokens",
+                        "retrieve_doc",
+                        "source",
+                    ],
+                )
+            )
+    return documents
 
 
 def setup_database(db_collection, dict_file_name):
     db = chromadb.PersistentClient(path=f"data/{db_collection}")
     chroma_collection = db.get_or_create_collection(db_collection)
     vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+    embed_model = CohereEmbedding(
+        api_key=os.environ["COHERE_API_KEY"],
+        model_name="embed-english-v3.0",
+        input_type="search_query",
+    )
 
     index = VectorStoreIndex.from_vector_store(
         vector_store=vector_store,
-        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
+        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
         show_progress=True,
         use_async=True,
     )
    vector_retriever = VectorIndexRetriever(
         index=index,
         similarity_top_k=15,
+        embed_model=embed_model,
         use_async=True,
-        embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
     )
     with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
         document_dict = pickle.load(f)
 
+    with open("data/keyword_retriever_sync.pkl", "rb") as f:
+        keyword_retriever: KeywordTableSimpleRetriever = pickle.load(f)
+
+    # # Creating the keyword index and retriever
+    # logfire.info("Creating nodes from documents")
+    # documents = create_docs("data/all_sources_data.jsonl")
+    # pipeline = IngestionPipeline(
+    #     transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)]
+    # )
+    # all_nodes = pipeline.run(documents=documents, show_progress=True)
+    # # with open("data/all_nodes.pkl", "wb") as f:
+    # #     pickle.dump(all_nodes, f)
+
+    # # all_nodes = pickle.load(open("data/all_nodes.pkl", "rb"))
+    # logfire.info(f"Number of nodes: {len(all_nodes)}")
+
+    # keyword_index = SimpleKeywordTableIndex(
+    #     nodes=all_nodes, max_keywords_per_chunk=10, show_progress=True, use_async=False
+    # )
+    # # with open("data/keyword_index.pkl", "wb") as f:
+    # #     pickle.dump(keyword_index, f)
+
+    # # keyword_index = pickle.load(open("data/keyword_index.pkl", "rb"))
+
+    # logfire.info("Creating keyword retriever")
+    # keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
+
+    # with open("data/keyword_retriever_sync.pkl", "wb") as f:
+    #     pickle.dump(keyword_retriever, f)
+
+    return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")
 
 
 # Setup retrievers
-custom_retriever_transformers = setup_database(
-)
-custom_retriever_peft = setup_database(
-)
+# custom_retriever_transformers: CustomRetriever = setup_database(
+#     "chroma-db-transformers",
+#     "document_dict_transformers.pkl",
+# )
+# custom_retriever_peft: CustomRetriever = setup_database(
+#     "chroma-db-peft", "document_dict_peft.pkl"
+# )
+# custom_retriever_trl: CustomRetriever = setup_database(
+#     "chroma-db-trl", "document_dict_trl.pkl"
+# )
+# custom_retriever_llama_index: CustomRetriever = setup_database(
+#     "chroma-db-llama_index",
+#     "document_dict_llama_index.pkl",
+# )
+# custom_retriever_openai_cookbooks: CustomRetriever = setup_database(
+#     "chroma-db-openai_cookbooks",
+#     "document_dict_openai_cookbooks.pkl",
+# )
+# custom_retriever_langchain: CustomRetriever = setup_database(
+#     "chroma-db-langchain",
+#     "document_dict_langchain.pkl",
+# )
+
+custom_retriever_all_sources: CustomRetriever = setup_database(
+    "chroma-db-all_sources",
+    "document_dict_all_sources.pkl",
 )
 
 # Constants
@@ -92,7 +177,8 @@ AVAILABLE_SOURCES_UI = [
     "LlamaIndex Docs",
     "LangChain Docs",
     "OpenAI Cookbooks",
+    "Towards AI Blog",
+    # "All Sources",
     # "RAG Course",
 ]
 
@@ -103,7 +189,8 @@ AVAILABLE_SOURCES = [
     "llama_index",
     "langchain",
     "openai_cookbooks",
+    "tai_blog",
+    # "all_sources",
     # "rag_course",
 ]
 
@@ -114,12 +201,13 @@ mongo_db = (
 )
 
 __all__ = [
-    "custom_retriever_transformers",
-    "custom_retriever_peft",
-    "custom_retriever_trl",
-    "custom_retriever_llama_index",
-    "custom_retriever_openai_cookbooks",
-    "custom_retriever_langchain",
+    # "custom_retriever_transformers",
+    # "custom_retriever_peft",
+    # "custom_retriever_trl",
+    # "custom_retriever_llama_index",
+    # "custom_retriever_openai_cookbooks",
+    # "custom_retriever_langchain",
+    "custom_retriever_all_sources",
     "mongo_db",
     "CONCURRENCY_COUNT",
     "AVAILABLE_SOURCES_UI",
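`setup_database()` rebuilds the retriever from the persisted Chroma collection rather than re-indexing on startup. A condensed sketch of that wiring with the same collection name and embedding settings as the commit; the sample query is illustrative, and running it requires the downloaded database plus the Cohere API key:

import os

import chromadb
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

# Open the persisted collection downloaded from the Hugging Face Hub dataset.
db = chromadb.PersistentClient(path="data/chroma-db-all_sources")
collection = db.get_or_create_collection("chroma-db-all_sources")

# Wrap it in a LlamaIndex index without re-embedding anything.
index = VectorStoreIndex.from_vector_store(
    vector_store=ChromaVectorStore(chroma_collection=collection),
)
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=15,
    embed_model=CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",  # query-side embeddings must match the indexed type
    ),
)
nodes = retriever.retrieve("How does LoRA fine-tuning work?")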