File size: 1,968 Bytes
2210481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from qdrant_client import QdrantClient
from qdrant_client.http import models
from typing import List, Dict, Optional
import os


class VectorDatabase:
    """Thin wrapper that binds a ``QdrantClient`` to a single named collection."""

    def __init__(
        self,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        collection_name: str = "testing_col",
    ):
        """
        Create the Qdrant client used by this wrapper.

        No collection is created here; this only stores the client and the
        collection name used by the other methods.

        Args:
            url: Qdrant server URL. When ``None``, falls back to the
                ``QDRANT_URL`` environment variable.
            api_key: Qdrant API key. When ``None``, falls back to the
                ``QDRANT_API_KEY`` environment variable.
            collection_name: Name of the collection every method operates on.
        """
        # Read the environment at construction time, not via default-argument
        # values: those are evaluated once when the class is defined, so
        # variables set later (e.g. by load_dotenv()) would never be seen.
        if url is None:
            url = os.getenv("QDRANT_URL")
        if api_key is None:
            api_key = os.getenv("QDRANT_API_KEY")
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection_name = collection_name

    def upsert_documents(self, texts: List[str]) -> None:
        """
        Insert raw text documents into the collection via ``QdrantClient.add``.

        Args:
            texts: Document strings to store.
        """
        self.client.add(
            collection_name=self.collection_name,
            documents=texts,
        )
        print(
            f"Inserted {len(texts)} documents into collection '{self.collection_name}'."
        )

    def search_similar(self, query_text: str) -> Optional[str]:
        """
        Return the single stored document closest to ``query_text``.

        Args:
            query_text: Text to search for.

        Returns:
            The ``document`` text of the top hit, or ``None`` when the query
            returns no hits (e.g. the collection holds no documents).
        """
        search_result = self.client.query(
            collection_name=self.collection_name,
            query_text=query_text,
            limit=1,
        )
        # Guard against an empty result list instead of raising IndexError.
        if not search_result:
            return None
        return search_result[0].document

    def delete_collection(self) -> None:
        """
        Delete this wrapper's collection from the Qdrant server.
        """
        self.client.delete_collection(self.collection_name)
        print(f"Deleted collection: {self.collection_name}")

    def list_collections(self) -> List[str]:
        """
        List all collections in the Qdrant database.

        Returns:
            Names of every collection on the server (not only this wrapper's).
        """
        collections = self.client.get_collections().collections
        return [collection.name for collection in collections]