File size: 2,528 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Base vector store index query."""


from typing import Any, List, Optional

from gpt_index.data_structs.data_structs import IndexDict, Node
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.utils import log_vector_store_query_result
from gpt_index.vector_stores.types import VectorStore


class GPTVectorStoreIndexQuery(BaseGPTIndexQuery[IndexDict]):
    """Query implementation that retrieves nodes through a vector store.

    The query text is embedded, the embedding is handed to the configured
    vector store, and the nodes reported by the store (or looked up in the
    index struct from the returned ids) are returned.

    Args:
        index_struct (IndexDict): index struct forwarded to the base class;
            also used to look up nodes when the store returns only ids.
        vector_store (Optional[VectorStore]): vector store to query. Despite
            the ``None`` default it is mandatory; omitting it raises
            ``ValueError``.
        embed_model (Optional[BaseEmbedding]): embedding model, forwarded to
            the base class.
        similarity_top_k (int): top-k value passed to the vector store query.

    """

    def __init__(
        self,
        index_struct: IndexDict,
        vector_store: Optional[VectorStore] = None,
        embed_model: Optional[BaseEmbedding] = None,
        similarity_top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        """Set up the base query and remember the store and top-k setting.

        Raises:
            ValueError: if ``vector_store`` is ``None``.

        """
        super().__init__(index_struct=index_struct, embed_model=embed_model, **kwargs)
        self._similarity_top_k = similarity_top_k

        # The parameter is typed Optional only for signature compatibility;
        # a query cannot run without a backing store.
        if vector_store is None:
            raise ValueError("Vector store is required for vector store query.")
        self._vector_store = vector_store

    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Return the nodes the vector store reports for ``query_bundle``.

        Args:
            query_bundle (QueryBundle): query whose ``embedding_strs`` are
                aggregated into a single query embedding.
            similarity_tracker (Optional[SimilarityTracker]): if given, and
                the store returned similarities, each (node, similarity)
                pair is recorded on it.

        Returns:
            List[Node]: nodes from the query result, resolved through the
            index struct when the store returned only ids.

        Raises:
            ValueError: if the query result carries neither nodes nor ids.

        """
        # NOTE(review): _embed_model and _doc_ids are presumably set up by
        # the base class; they are not assigned anywhere in this class.
        aggregate_embedding = self._embed_model.get_agg_embedding_from_queries
        embedding = aggregate_embedding(query_bundle.embedding_strs)

        store = self._vector_store
        result = store.query(embedding, self._similarity_top_k, self._doc_ids)

        if result.nodes is None:
            # The store handed back ids only: resolve them via the index struct.
            node_ids = result.ids
            if node_ids is None:
                raise ValueError(
                    "Vector store query result should return "
                    "at least one of nodes or ids."
                )
            assert isinstance(self._index_struct, IndexDict)
            result.nodes = self._index_struct.get_nodes(node_ids)

        log_vector_store_query_result(result)

        # Nothing to record without both a tracker and similarity scores.
        if similarity_tracker is None or result.similarities is None:
            return result.nodes

        for found_node, score in zip(result.nodes, result.similarities):
            similarity_tracker.add(found_node, score)

        return result.nodes