"""Query for GPTKGTableIndex."""
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional
from gpt_index.data_structs.data_structs import KG, Node
from gpt_index.indices.keyword_table.utils import extract_keywords_given_response
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import (
SimilarityTracker,
get_top_k_embeddings,
)
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.prompts.default_prompts import DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE
from gpt_index.prompts.prompts import QueryKeywordExtractPrompt
from gpt_index.utils import truncate_text
DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE


class KGQueryMode(str, Enum):
    """Query mode enum for Knowledge Graphs.

    Can be passed as the enum struct, or as the underlying string.

    Attributes:
        KEYWORD ("keyword"): Default query mode, using keywords to find triplets.
        EMBEDDING ("embedding"): Embedding mode, using embeddings to find
            similar triplets.
        HYBRID ("hybrid"): Hybrid mode, combining both keywords and embeddings
            to find relevant triplets.
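
    Example:
        Because the enum subclasses ``str``, either the member or its raw
        string value is accepted (illustrative)::

            KGQueryMode("hybrid") is KGQueryMode.HYBRID  # True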
"""
KEYWORD = "keyword"
EMBEDDING = "embedding"
HYBRID = "hybrid"


class GPTKGTableQuery(BaseGPTIndexQuery[KG]):
    """Base GPT KG Table Index Query.

    Arguments are shared among subclasses.

    Args:
        query_keyword_extract_template (Optional[QueryKeywordExtractPrompt]):
            A Query Keyword Extraction Prompt (see :ref:`Prompt-Templates`).
        refine_template (Optional[RefinePrompt]): A Refinement Prompt
            (see :ref:`Prompt-Templates`).
        text_qa_template (Optional[QuestionAnswerPrompt]): A Question Answering
            Prompt (see :ref:`Prompt-Templates`).
        max_keywords_per_query (int): Maximum number of keywords to extract from
            the query.
        num_chunks_per_query (int): Maximum number of text chunks to query.
        include_text (bool): Use the document text source from each relevant
            triplet during queries.
        embedding_mode (KGQueryMode): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Should be one of
            "keyword", "embedding", or "hybrid".
        similarity_top_k (int): The number of top embeddings to use
            (if embeddings are used).
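
    Example:
        A hypothetical usage sketch (assumes a knowledge graph index built
        elsewhere; the keyword arguments are forwarded to this query class)::

            response = index.query(
                "Who founded Acme?",
                embedding_mode="hybrid",
                similarity_top_k=5,
            )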
"""

    def __init__(
        self,
        index_struct: KG,
        query_keyword_extract_template: Optional[QueryKeywordExtractPrompt] = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        include_text: bool = True,
        embedding_mode: KGQueryMode = KGQueryMode.KEYWORD,
        similarity_top_k: int = 2,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        super().__init__(index_struct=index_struct, **kwargs)
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self.query_keyword_extract_template = query_keyword_extract_template or DQKET
        self.similarity_top_k = similarity_top_k
        self._include_text = include_text
        self._embedding_mode = KGQueryMode(embedding_mode)

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        response, _ = self._llm_predictor.predict(
            self.query_keyword_extract_template,
            max_keywords=self.max_keywords_per_query,
            question=query_str,
        )
        keywords = extract_keywords_given_response(
            response, start_token="KEYWORDS:", lowercase=False
        )
        return list(keywords)
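
    # Illustrative note: the template asks the LLM to answer with a line such
    # as "KEYWORDS: climate, emissions, policy" (example values);
    # extract_keywords_given_response parses the text after the "KEYWORDS:"
    # token into the individual keywords used for graph lookup.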

    def _extract_rel_text_keywords(self, rel_texts: List[str]) -> List[str]:
        """Find the keywords for the given rel text triplets."""
        keywords = []
        for rel_text in rel_texts:
            keyword = rel_text.split(",")[0]
            if keyword:
                keywords.append(keyword.strip("(\"'"))
        return keywords
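
    # Example walk-through (illustrative values): for a serialized triplet
    # string such as '("Alice", "works at", "Acme")', split(",")[0] yields
    # '("Alice"' and strip("(\"'") trims the wrapping punctuation, leaving
    # the subject keyword 'Alice'.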

    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Get nodes for response."""
        logging.info(f"> Starting query: {query_bundle.query_str}")
        keywords = self._get_keywords(query_bundle.query_str)
        logging.info(f"> Query keywords: {keywords}")
        rel_texts = []
        cur_rel_map = {}
        chunk_indices_count: Dict[str, int] = defaultdict(int)
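        # chunk_indices_count acts as a simple vote tally: every keyword (or
        # embedding-derived keyword) that maps to a text chunk increments that
        # chunk's count, and chunks are later ranked by this count.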
        if self._embedding_mode != KGQueryMode.EMBEDDING:
            for keyword in keywords:
                cur_rel_texts = self.index_struct.get_rel_map_texts(keyword)
                rel_texts.extend(cur_rel_texts)
                cur_rel_map[keyword] = self.index_struct.get_rel_map_tuples(keyword)
                if self._include_text:
                    for node_id in self.index_struct.get_node_ids(keyword):
                        chunk_indices_count[node_id] += 1
        if (
            self._embedding_mode != KGQueryMode.KEYWORD
            and len(self.index_struct.embedding_dict) > 0
        ):
            query_embedding = self._embed_model.get_text_embedding(
                query_bundle.query_str
            )
            all_rel_texts = list(self.index_struct.embedding_dict.keys())
            rel_text_embeddings = [
                self.index_struct.embedding_dict[_id] for _id in all_rel_texts
            ]
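            # get_top_k_embeddings ranks each stored rel-text embedding against
            # the query embedding and keeps the similarity_top_k best matches,
            # subject to the similarity_cutoff provided by the base query.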
            similarities, top_rel_texts = get_top_k_embeddings(
                query_embedding,
                rel_text_embeddings,
                similarity_top_k=self.similarity_top_k,
                embedding_ids=all_rel_texts,
                similarity_cutoff=self.similarity_cutoff,
            )
            logging.debug(
                f"Found the following rel_texts+query similarities: {str(similarities)}"
            )
            logging.debug(f"Found the following top_k rel_texts: {str(top_rel_texts)}")
            rel_texts.extend(top_rel_texts)
            if self._include_text:
                keywords = self._extract_rel_text_keywords(top_rel_texts)
                nested_node_ids = [
                    self.index_struct.get_node_ids(keyword) for keyword in keywords
                ]
                # flatten list
                node_ids = [_id for ids in nested_node_ids for _id in ids]
                for node_id in node_ids:
                    chunk_indices_count[node_id] += 1
        elif self._embedding_mode != KGQueryMode.KEYWORD:
            logging.error(
                "Index was not constructed with embeddings, skipping embedding usage..."
            )

        # remove any duplicates from keyword + embedding queries
        if self._embedding_mode == KGQueryMode.HYBRID:
            rel_texts = list(set(rel_texts))

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )
        sorted_chunk_indices = sorted_chunk_indices[: self.num_chunks_per_query]
        sorted_nodes = [
            self.index_struct.text_chunks[idx] for idx in sorted_chunk_indices
        ]

        # filter sorted nodes, keeping chunk indices aligned with their nodes
        node_pairs = [
            (chunk_idx, node)
            for chunk_idx, node in zip(sorted_chunk_indices, sorted_nodes)
            if self._should_use_node(node)
        ]
        sorted_nodes = [node for _, node in node_pairs]
        for chunk_idx, node in node_pairs:
            # nodes are found with keyword mapping, give high conf to avoid cutoff
            if similarity_tracker is not None:
                similarity_tracker.add(node, 1000.0)
            logging.info(
                f"> Querying with idx: {chunk_idx}: "
                f"{truncate_text(node.get_text(), 80)}"
            )

        # add relationships as Node
        # TODO: make initial text customizable
        rel_initial_text = (
            "The following are knowledge triplets "
            "in the form of (subject, predicate, object):"
        )
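        # Illustrative example of the synthesized node text (values invented):
        #   The following are knowledge triplets in the form of
        #   (subject, predicate, object):
        #   ('Alice', 'works at', 'Acme')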
        rel_info = [rel_initial_text] + rel_texts
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        rel_info_text = "\n".join(rel_info)
        rel_text_node = Node(text=rel_info_text, node_info=rel_node_info)
        # this node is constructed from rel_texts, give high confidence to avoid cutoff
        if similarity_tracker is not None:
            similarity_tracker.add(rel_text_node, 1000.0)
        logging.info(f"> Extracted relationships: {rel_info_text}")
        sorted_nodes.append(rel_text_node)
        return sorted_nodes

    def _get_extra_info_for_response(
        self, nodes: List[Node]
    ) -> Optional[Dict[str, Any]]:
        """Get extra info for response."""
        for node in nodes:
            if node.node_info is None or "kg_rel_map" not in node.node_info:
                continue
            return node.node_info
        raise ValueError("kg_rel_map must be found in at least one Node.")