Spaces:
Runtime error
Runtime error
File size: 2,528 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
"""Base vector store index query."""
from typing import Any, List, Optional
from gpt_index.data_structs.data_structs import IndexDict, Node
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.utils import log_vector_store_query_result
from gpt_index.vector_stores.types import VectorStore
class GPTVectorStoreIndexQuery(BaseGPTIndexQuery[IndexDict]):
    """Query class that retrieves nodes through a vector store.

    The query strings of a bundle are turned into one aggregate embedding,
    which is handed to the vector store together with ``similarity_top_k``.

    Args:
        index_struct (IndexDict): index structure; used to look nodes up by
            id when the vector store result carries ids but no nodes.
        vector_store (Optional[VectorStore]): vector store to query. Typed as
            optional, but a ``ValueError`` is raised when it is ``None``.
        embed_model (Optional[BaseEmbedding]): embedding model, forwarded to
            the base class.
        similarity_top_k (int): number of top k results to return
        **kwargs (Any): extra keyword arguments forwarded to the base class.

    """

    def __init__(
        self,
        index_struct: IndexDict,
        vector_store: Optional[VectorStore] = None,
        embed_model: Optional[BaseEmbedding] = None,
        similarity_top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        """Initialize the base query, then store top-k and the vector store.

        Raises:
            ValueError: if ``vector_store`` is ``None``.

        """
        super().__init__(
            index_struct=index_struct,
            embed_model=embed_model,
            **kwargs,
        )
        self._similarity_top_k = similarity_top_k

        # The base initializer runs first on purpose: the missing-store error
        # is only raised after the base class has been set up.
        if vector_store is not None:
            self._vector_store = vector_store
        else:
            raise ValueError("Vector store is required for vector store query.")

    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Query the vector store and return the nodes of the result.

        Args:
            query_bundle (QueryBundle): bundle whose ``embedding_strs`` are
                embedded into a single aggregate query embedding.
            similarity_tracker (Optional[SimilarityTracker]): when given, and
                the result carries similarities, each (node, similarity) pair
                is recorded on it.

        Returns:
            List[Node]: the nodes of the vector store query result.

        Raises:
            ValueError: if the query result has neither nodes nor ids.

        """
        # NOTE(review): ``_embed_model`` and ``_doc_ids`` are not assigned in
        # this class; presumably the base class provides them - confirm.
        embedding = self._embed_model.get_agg_embedding_from_queries(
            query_bundle.embedding_strs
        )
        result = self._vector_store.query(
            embedding, self._similarity_top_k, self._doc_ids
        )

        if result.nodes is None:
            # The store gave no nodes; fall back to resolving its ids through
            # the index structure.
            if result.ids is None:
                raise ValueError(
                    "Vector store query result should return at "
                    "least one of nodes or ids."
                )
            assert isinstance(self._index_struct, IndexDict)
            # Written back onto the result so the logging call below sees the
            # resolved nodes as well.
            result.nodes = self._index_struct.get_nodes(result.ids)

        log_vector_store_query_result(result)

        # Nothing to record without both a tracker and similarity scores.
        if similarity_tracker is None or result.similarities is None:
            return result.nodes

        for found_node, score in zip(result.nodes, result.similarities):
            similarity_tracker.add(found_node, score)
        return result.nodes
|