Spaces:
Running
Running
fix retrieval
Browse files- aimakerspace/vectordatabase.py +9 -31
- app.py +1 -2
- uv.lock +0 -22
aimakerspace/vectordatabase.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from qdrant_client import QdrantClient
|
2 |
-
from qdrant_client.
|
3 |
from typing import List, Dict, Optional
|
4 |
import os
|
5 |
|
@@ -34,22 +34,6 @@ class VectorDatabase:
|
|
34 |
f"Inserted {len(texts)} documents into collection '{self.collection_name}'."
|
35 |
)
|
36 |
|
37 |
-
def create_collection(self):
|
38 |
-
# Define the vector size and distance metric
|
39 |
-
vector_size = 128 # Size of the vectors you plan to store
|
40 |
-
distance_metric = (
|
41 |
-
models.Distance.COSINE
|
42 |
-
) # Distance metric (e.g., COSINE, EUCLID, DOT)
|
43 |
-
|
44 |
-
# Create the collection
|
45 |
-
self.client.create_collection(
|
46 |
-
collection_name=self.collection_name,
|
47 |
-
vectors_config=models.VectorParams(
|
48 |
-
size=vector_size,
|
49 |
-
distance=distance_metric,
|
50 |
-
),
|
51 |
-
)
|
52 |
-
|
53 |
def _delete_collection(self):
|
54 |
"""
|
55 |
Delete a collection from the Qdrant database.
|
@@ -60,31 +44,25 @@ class VectorDatabase:
|
|
60 |
# Check if the collection exists
|
61 |
collections = self.client.get_collections()
|
62 |
collection_names = [collection.name for collection in collections.collections]
|
|
|
63 |
|
64 |
if self.collection_name in collection_names:
|
65 |
# Delete the collection
|
66 |
-
self.client.delete_collection(collection_name)
|
67 |
-
print(f"Collection '{collection_name}' deleted.")
|
68 |
else:
|
69 |
-
print(f"Collection '{collection_name}' does not exist.")
|
70 |
|
71 |
def search_similar(self, query_text: str):
|
72 |
search_result = self.client.query(
|
73 |
collection_name=self.collection_name,
|
74 |
query_text=query_text,
|
75 |
-
limit=
|
76 |
)
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
return
|
81 |
-
|
82 |
-
def delete_collection(self):
|
83 |
-
"""
|
84 |
-
Delete the Qdrant collection.
|
85 |
-
"""
|
86 |
-
self.client.delete_collection(self.collection_name)
|
87 |
-
print(f"Deleted collection: {self.collection_name}")
|
88 |
|
89 |
def list_collections(self):
|
90 |
"""
|
|
|
1 |
from qdrant_client import QdrantClient
|
2 |
+
from qdrant_client.models import VectorParams, Distance
|
3 |
from typing import List, Dict, Optional
|
4 |
import os
|
5 |
|
|
|
34 |
f"Inserted {len(texts)} documents into collection '{self.collection_name}'."
|
35 |
)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def _delete_collection(self):
|
38 |
"""
|
39 |
Delete a collection from the Qdrant database.
|
|
|
44 |
# Check if the collection exists
|
45 |
collections = self.client.get_collections()
|
46 |
collection_names = [collection.name for collection in collections.collections]
|
47 |
+
print(f"Existing collections: {collection_names}")
|
48 |
|
49 |
if self.collection_name in collection_names:
|
50 |
# Delete the collection
|
51 |
+
self.client.delete_collection(self.collection_name)
|
52 |
+
print(f"Collection '{self.collection_name}' deleted.")
|
53 |
else:
|
54 |
+
print(f"Collection '{self.collection_name}' does not exist.")
|
55 |
|
56 |
def search_similar(self, query_text: str):
|
57 |
search_result = self.client.query(
|
58 |
collection_name=self.collection_name,
|
59 |
query_text=query_text,
|
60 |
+
limit=4,
|
61 |
)
|
62 |
|
63 |
+
documents = [item.document for item in search_result]
|
64 |
+
result = " ".join(document for document in documents if document)
|
65 |
+
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
def list_collections(self):
|
68 |
"""
|
app.py
CHANGED
@@ -31,8 +31,6 @@ class RetrievalAugmentedQAPipeline:
|
|
31 |
def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
|
32 |
self.llm = llm
|
33 |
self.vector_db_retriever = vector_db_retriever
|
34 |
-
self.vector_db_retriever.delete_collection()
|
35 |
-
self.vector_db_retriever.create_collection()
|
36 |
|
37 |
async def arun_pipeline(self, user_query: str):
|
38 |
context_data = self.vector_db_retriever.search_similar(user_query)
|
@@ -115,6 +113,7 @@ async def on_chat_start():
|
|
115 |
|
116 |
# Create a dict vector store
|
117 |
vector_db = VectorDatabase()
|
|
|
118 |
vector_db.upsert_documents(texts)
|
119 |
chat_openai = ChatOpenAI()
|
120 |
|
|
|
31 |
def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
|
32 |
self.llm = llm
|
33 |
self.vector_db_retriever = vector_db_retriever
|
|
|
|
|
34 |
|
35 |
async def arun_pipeline(self, user_query: str):
|
36 |
context_data = self.vector_db_retriever.search_similar(user_query)
|
|
|
113 |
|
114 |
# Create a dict vector store
|
115 |
vector_db = VectorDatabase()
|
116 |
+
vector_db._delete_collection()
|
117 |
vector_db.upsert_documents(texts)
|
118 |
chat_openai = ChatOpenAI()
|
119 |
|
uv.lock
CHANGED
@@ -7,7 +7,6 @@ version = "0.1.0"
|
|
7 |
source = { virtual = "." }
|
8 |
dependencies = [
|
9 |
{ name = "chainlit" },
|
10 |
-
{ name = "maturin" },
|
11 |
{ name = "numpy" },
|
12 |
{ name = "openai" },
|
13 |
{ name = "pydantic" },
|
@@ -19,7 +18,6 @@ dependencies = [
|
|
19 |
[package.metadata]
|
20 |
requires-dist = [
|
21 |
{ name = "chainlit", specifier = ">=2.0.4" },
|
22 |
-
{ name = "maturin", specifier = ">=1.8.1" },
|
23 |
{ name = "numpy", specifier = ">=2.2.2" },
|
24 |
{ name = "openai", specifier = ">=1.59.9" },
|
25 |
{ name = "pydantic", specifier = "==2.10.1" },
|
@@ -535,26 +533,6 @@ wheels = [
|
|
535 |
{ url = "https://files.pythonhosted.org/packages/8e/25/5b300f0400078d9783fbe44d30fedd849a130fc3aff01f18278c12342b6f/marshmallow-3.25.1-py3-none-any.whl", hash = "sha256:ec5d00d873ce473b7f2ffcb7104286a376c354cab0c2fa12f5573dab03e87210", size = 49624 },
|
536 |
]
|
537 |
|
538 |
-
[[package]]
|
539 |
-
name = "maturin"
|
540 |
-
version = "1.8.1"
|
541 |
-
source = { registry = "https://pypi.org/simple" }
|
542 |
-
sdist = { url = "https://files.pythonhosted.org/packages/9a/08/ccb0f917722a35ab0d758be9bb5edaf645c3a3d6170061f10d396ecd273f/maturin-1.8.1.tar.gz", hash = "sha256:49cd964aabf59f8b0a6969f9860d2cdf194ac331529caae14c884f5659568857", size = 197397 }
|
543 |
-
wheels = [
|
544 |
-
{ url = "https://files.pythonhosted.org/packages/4c/00/f34077315f34db8ad2ccf6bfe11b864ca27baab3a1320634da8e3cf89a48/maturin-1.8.1-py3-none-linux_armv6l.whl", hash = "sha256:7e590a23d9076b8a994f2e67bc63dc9a2d1c9a41b1e7b45ac354ba8275254e89", size = 7568415 },
|
545 |
-
{ url = "https://files.pythonhosted.org/packages/5c/07/9219976135ce0cb32d2fa6ea5c6d0ad709013d9a17967312e149b98153a6/maturin-1.8.1-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:8d8251a95682c83ea60988c804b620c181911cd824aa107b4a49ac5333c92968", size = 14527816 },
|
546 |
-
{ url = "https://files.pythonhosted.org/packages/e6/04/fa009a00903acdd1785d58322193140bfe358595347c39f315112dabdf9e/maturin-1.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9fc1a4354cac5e32c190410208039812ea88c4a36bd2b6499268ec49ef5de00", size = 7580446 },
|
547 |
-
{ url = "https://files.pythonhosted.org/packages/9b/d4/414b2aab9bbfe88182b734d3aa1b4fef7d7701e50f6be48500378b8c8721/maturin-1.8.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.musllinux_1_1_i686.whl", hash = "sha256:621e171c6b39f95f1d0df69a118416034fbd59c0f89dcaea8c2ea62019deecba", size = 7650535 },
|
548 |
-
{ url = "https://files.pythonhosted.org/packages/f0/64/879418a8a0196013ec1fb19eada0781c04a30e8d6d9227e80f91275a4f5b/maturin-1.8.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.musllinux_1_1_x86_64.whl", hash = "sha256:98f638739a5132962347871b85c91f525c9246ef4d99796ae98a2031e3df029f", size = 8006702 },
|
549 |
-
{ url = "https://files.pythonhosted.org/packages/39/c2/605829324f8371294f70303aca130682df75318958efed246873d3d604ab/maturin-1.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.musllinux_1_1_aarch64.whl", hash = "sha256:f9f5c47521924b6e515cbc652a042fe5f17f8747445be9d931048e5d8ddb50a4", size = 7368164 },
|
550 |
-
{ url = "https://files.pythonhosted.org/packages/be/6c/30e136d397bb146b94b628c0ef7f17708281611b97849e2cf37847025ac7/maturin-1.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.musllinux_1_1_armv7l.whl", hash = "sha256:0f4407c7353c31bfbb8cdeb82bc2170e474cbfb97b5ba27568f440c9d6c1fdd4", size = 7450889 },
|
551 |
-
{ url = "https://files.pythonhosted.org/packages/1b/50/e1f5023512696d4e56096f702e2f68d6d9a30afe0a4eec82b0e27b8eb4e4/maturin-1.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.musllinux_1_1_ppc64le.whl", hash = "sha256:ec49cd70cad3c389946c6e2bc0bd50772a7fcb463040dd800720345897eec9bf", size = 9585819 },
|
552 |
-
{ url = "https://files.pythonhosted.org/packages/b7/80/b24b5248d89d2e5982553900237a337ea098ca9297b8369ca2aa95549e0f/maturin-1.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c08767d794de8f8a11c5c8b1b47a4ff9fb6ae2d2d97679e27030f2f509c8c2a0", size = 10920801 },
|
553 |
-
{ url = "https://files.pythonhosted.org/packages/6e/f4/8ede7a662fabf93456b44390a5ad22630e25fb5ddaecf787251071b2e143/maturin-1.8.1-py3-none-win32.whl", hash = "sha256:d678407713f3e10df33c5b3d7a343ec0551eb7f14d8ad9ba6febeb96f4e4c75c", size = 6873556 },
|
554 |
-
{ url = "https://files.pythonhosted.org/packages/9c/22/757f093ed0e319e9648155b8c9d716765442bea5bc98ebc58ad4ad5b0524/maturin-1.8.1-py3-none-win_amd64.whl", hash = "sha256:a526f90fe0e5cb59ffb81f4ff547ddc42e823bbdeae4a31012c0893ca6dcaf46", size = 7823153 },
|
555 |
-
{ url = "https://files.pythonhosted.org/packages/a4/f5/051413e04f6da25069db5e76759ecdb8cd2a8ab4a94045b5a3bf548c66fa/maturin-1.8.1-py3-none-win_arm64.whl", hash = "sha256:e95f077fd2ddd2f048182880eed458c308571a534be3eb2add4d3dac55bf57f4", size = 6552131 },
|
556 |
-
]
|
557 |
-
|
558 |
[[package]]
|
559 |
name = "mmh3"
|
560 |
version = "4.1.0"
|
|
|
7 |
source = { virtual = "." }
|
8 |
dependencies = [
|
9 |
{ name = "chainlit" },
|
|
|
10 |
{ name = "numpy" },
|
11 |
{ name = "openai" },
|
12 |
{ name = "pydantic" },
|
|
|
18 |
[package.metadata]
|
19 |
requires-dist = [
|
20 |
{ name = "chainlit", specifier = ">=2.0.4" },
|
|
|
21 |
{ name = "numpy", specifier = ">=2.2.2" },
|
22 |
{ name = "openai", specifier = ">=1.59.9" },
|
23 |
{ name = "pydantic", specifier = "==2.10.1" },
|
|
|
533 |
{ url = "https://files.pythonhosted.org/packages/8e/25/5b300f0400078d9783fbe44d30fedd849a130fc3aff01f18278c12342b6f/marshmallow-3.25.1-py3-none-any.whl", hash = "sha256:ec5d00d873ce473b7f2ffcb7104286a376c354cab0c2fa12f5573dab03e87210", size = 49624 },
|
534 |
]
|
535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
536 |
[[package]]
|
537 |
name = "mmh3"
|
538 |
version = "4.1.0"
|