"""Query for GPTKGTableIndex.""" | |
import logging | |
from collections import defaultdict | |
from enum import Enum | |
from typing import Any, Dict, List, Optional | |
from gpt_index.data_structs.data_structs import KG, Node | |
from gpt_index.indices.keyword_table.utils import extract_keywords_given_response | |
from gpt_index.indices.query.base import BaseGPTIndexQuery | |
from gpt_index.indices.query.embedding_utils import ( | |
SimilarityTracker, | |
get_top_k_embeddings, | |
) | |
from gpt_index.indices.query.schema import QueryBundle | |
from gpt_index.prompts.default_prompts import DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE | |
from gpt_index.prompts.prompts import QueryKeywordExtractPrompt | |
from gpt_index.utils import truncate_text | |
DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE | |
class KGQueryMode(str, Enum):
    """Query mode enum for Knowledge Graphs.

    Can be passed as the enum struct, or as the underlying string.

    Attributes:
        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
            similar triplets.
        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
            to find relevant triplets.

    """

    KEYWORD = "keyword"
    EMBEDDING = "embedding"
    HYBRID = "hybrid"
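    # Illustrative note: since KGQueryMode also subclasses str,
    # KGQueryMode("hybrid") resolves to KGQueryMode.HYBRID and
    # KGQueryMode.HYBRID == "hybrid" is True, so callers may pass either
    # the enum member or the plain string.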


class GPTKGTableQuery(BaseGPTIndexQuery[KG]):
    """Base GPT KG Table Index Query.

    Arguments are shared among subclasses.

    Args:
        query_keyword_extract_template (Optional[QueryKeywordExtractPrompt]): A Query
            Keyword Extraction Prompt (see :ref:`Prompt-Templates`).
        refine_template (Optional[RefinePrompt]): A Refinement Prompt
            (see :ref:`Prompt-Templates`).
        text_qa_template (Optional[QuestionAnswerPrompt]): A Question Answering Prompt
            (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract from query.
        num_chunks_per_query (int): Maximum number of text chunks to query.
        include_text (bool): Use the document text source from each relevant triplet
            during queries.
        embedding_mode (KGQueryMode): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Should be one of "keyword",
            "embedding", or "hybrid".
        similarity_top_k (int): The number of top embeddings to use
            (if embeddings are used).

    """

    def __init__(
        self,
        index_struct: KG,
        query_keyword_extract_template: Optional[QueryKeywordExtractPrompt] = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        include_text: bool = True,
        embedding_mode: Optional[KGQueryMode] = KGQueryMode.KEYWORD,
        similarity_top_k: int = 2,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        super().__init__(index_struct=index_struct, **kwargs)
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self.query_keyword_extract_template = query_keyword_extract_template or DQKET
        self.similarity_top_k = similarity_top_k
        self._include_text = include_text
        self._embedding_mode = KGQueryMode(embedding_mode)

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        response, _ = self._llm_predictor.predict(
            self.query_keyword_extract_template,
            max_keywords=self.max_keywords_per_query,
            question=query_str,
        )
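        # The keyword extraction prompt is expected to yield a response of the
        # form "KEYWORDS: foo, bar, baz"; extract_keywords_given_response parses
        # the comma-separated keywords that follow the start_token.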
        keywords = extract_keywords_given_response(
            response, start_token="KEYWORDS:", lowercase=False
        )
        return list(keywords)

    def _extract_rel_text_keywords(self, rel_texts: List[str]) -> List[str]:
        """Find the keywords for given rel text triplets."""
        keywords = []
        for rel_text in rel_texts:
            keyword = rel_text.split(",")[0]
            if keyword:
                keywords.append(keyword.strip("(\"'"))
        return keywords

    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Get nodes for response."""
        logging.info(f"> Starting query: {query_bundle.query_str}")
        keywords = self._get_keywords(query_bundle.query_str)
        logging.info(f"> Query keywords: {keywords}")
        rel_texts = []
        cur_rel_map = {}
        chunk_indices_count: Dict[str, int] = defaultdict(int)
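        # Keyword path (used in "keyword" and "hybrid" modes): look up triplets
        # and their source chunks via the keywords extracted from the query.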
        if self._embedding_mode != KGQueryMode.EMBEDDING:
            for keyword in keywords:
                cur_rel_texts = self.index_struct.get_rel_map_texts(keyword)
                rel_texts.extend(cur_rel_texts)
                cur_rel_map[keyword] = self.index_struct.get_rel_map_tuples(keyword)
                if self._include_text:
                    for node_id in self.index_struct.get_node_ids(keyword):
                        chunk_indices_count[node_id] += 1
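        # Embedding path (used in "embedding" and "hybrid" modes): retrieve the
        # most similar stored triplets by embedding similarity, provided the
        # index was built with embeddings.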
        if (
            self._embedding_mode != KGQueryMode.KEYWORD
            and len(self.index_struct.embedding_dict) > 0
        ):
            query_embedding = self._embed_model.get_text_embedding(
                query_bundle.query_str
            )
            all_rel_texts = list(self.index_struct.embedding_dict.keys())
            rel_text_embeddings = [
                self.index_struct.embedding_dict[_id] for _id in all_rel_texts
            ]
            similarities, top_rel_texts = get_top_k_embeddings(
                query_embedding,
                rel_text_embeddings,
                similarity_top_k=self.similarity_top_k,
                embedding_ids=all_rel_texts,
                similarity_cutoff=self.similarity_cutoff,
            )
            logging.debug(
                f"Found the following rel_texts+query similarities: {str(similarities)}"
            )
            logging.debug(f"Found the following top_k rel_texts: {str(top_rel_texts)}")
            rel_texts.extend(top_rel_texts)
            if self._include_text:
                keywords = self._extract_rel_text_keywords(top_rel_texts)
                nested_node_ids = [
                    self.index_struct.get_node_ids(keyword) for keyword in keywords
                ]
                # flatten list
                node_ids = [_id for ids in nested_node_ids for _id in ids]
                for node_id in node_ids:
                    chunk_indices_count[node_id] += 1
        elif len(self.index_struct.embedding_dict) == 0:
            logging.error(
                "Index was not constructed with embeddings, skipping embedding usage..."
            )

        # remove any duplicates from keyword + embedding queries
        if self._embedding_mode == KGQueryMode.HYBRID:
            rel_texts = list(set(rel_texts))
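        # Rank candidate text chunks by how many matched keywords / top triplets
        # reference them, then keep at most num_chunks_per_query of them.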
        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )
        sorted_chunk_indices = sorted_chunk_indices[: self.num_chunks_per_query]
        sorted_nodes = [
            self.index_struct.text_chunks[idx] for idx in sorted_chunk_indices
        ]
        # filter sorted nodes
        sorted_nodes = [node for node in sorted_nodes if self._should_use_node(node)]
        for chunk_idx, node in zip(sorted_chunk_indices, sorted_nodes):
            # nodes are found with keyword mapping, give high conf to avoid cutoff
            if similarity_tracker is not None:
                similarity_tracker.add(node, 1000.0)
            logging.info(
                f"> Querying with idx: {chunk_idx}: "
                f"{truncate_text(node.get_text(), 80)}"
            )

        # add relationships as Node
        # TODO: make initial text customizable
        rel_initial_text = (
            "The following are knowledge triplets "
            "in the form of (subject, predicate, object):"
        )
        rel_info = [rel_initial_text] + rel_texts
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        rel_text_node = Node(text="\n".join(rel_info), node_info=rel_node_info)
        # this node is constructed from rel_texts, give high confidence to avoid cutoff
        if similarity_tracker is not None:
            similarity_tracker.add(rel_text_node, 1000.0)
        rel_info_text = "\n".join(rel_info)
        logging.info(f"> Extracted relationships: {rel_info_text}")
        sorted_nodes.append(rel_text_node)
        return sorted_nodes

    def _get_extra_info_for_response(
        self, nodes: List[Node]
    ) -> Optional[Dict[str, Any]]:
        """Get extra info for response."""
        for node in nodes:
            if node.node_info is None or "kg_rel_map" not in node.node_info:
                continue
            return node.node_info
        raise ValueError("kg_rel_map must be found in at least one Node.")