"""Query for GPTKeywordTableIndex.""" import logging from abc import abstractmethod from collections import defaultdict from typing import Any, Dict, List, Optional from gpt_index.data_structs.data_structs import KeywordTable, Node from gpt_index.indices.keyword_table.utils import ( extract_keywords_given_response, rake_extract_keywords, simple_extract_keywords, ) from gpt_index.indices.query.base import BaseGPTIndexQuery from gpt_index.indices.query.embedding_utils import SimilarityTracker from gpt_index.indices.query.schema import QueryBundle from gpt_index.prompts.default_prompts import ( DEFAULT_KEYWORD_EXTRACT_TEMPLATE, DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE, ) from gpt_index.prompts.prompts import KeywordExtractPrompt, QueryKeywordExtractPrompt from gpt_index.utils import truncate_text DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE class BaseGPTKeywordTableQuery(BaseGPTIndexQuery[KeywordTable]): """Base GPT Keyword Table Index Query. Arguments are shared among subclasses. Args: keyword_extract_template (Optional[KeywordExtractPrompt]): A Keyword Extraction Prompt (see :ref:`Prompt-Templates`). query_keyword_extract_template (Optional[QueryKeywordExtractPrompt]): A Query Keyword Extraction Prompt (see :ref:`Prompt-Templates`). refine_template (Optional[RefinePrompt]): A Refinement Prompt (see :ref:`Prompt-Templates`). text_qa_template (Optional[QuestionAnswerPrompt]): A Question Answering Prompt (see :ref:`Prompt-Templates`). max_keywords_per_query (int): Maximum number of keywords to extract from query. num_chunks_per_query (int): Maximum number of text chunks to query. """ def __init__( self, index_struct: KeywordTable, keyword_extract_template: Optional[KeywordExtractPrompt] = None, query_keyword_extract_template: Optional[QueryKeywordExtractPrompt] = None, max_keywords_per_query: int = 10, num_chunks_per_query: int = 10, **kwargs: Any, ) -> None: """Initialize params.""" super().__init__(index_struct=index_struct, **kwargs) self.max_keywords_per_query = max_keywords_per_query self.num_chunks_per_query = num_chunks_per_query self.keyword_extract_template = ( keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE ) self.query_keyword_extract_template = query_keyword_extract_template or DQKET @abstractmethod def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" def _get_nodes_for_response( self, query_bundle: QueryBundle, similarity_tracker: Optional[SimilarityTracker] = None, ) -> List[Node]: """Get nodes for response.""" logging.info(f"> Starting query: {query_bundle.query_str}") keywords = self._get_keywords(query_bundle.query_str) logging.info(f"query keywords: {keywords}") # go through text chunks in order of most matching keywords chunk_indices_count: Dict[int, int] = defaultdict(int) keywords = [k for k in keywords if k in self.index_struct.keywords] logging.info(f"> Extracted keywords: {keywords}") for k in keywords: for text_chunk_idx in self.index_struct.table[k]: chunk_indices_count[text_chunk_idx] += 1 sorted_chunk_indices = sorted( list(chunk_indices_count.keys()), key=lambda x: chunk_indices_count[x], reverse=True, ) sorted_chunk_indices = sorted_chunk_indices[: self.num_chunks_per_query] sorted_nodes = [ self.index_struct.text_chunks[idx] for idx in sorted_chunk_indices ] # filter sorted nodes sorted_nodes = [node for node in sorted_nodes if self._should_use_node(node)] if logging.getLogger(__name__).getEffectiveLevel() == logging.DEBUG: for chunk_idx, node in zip(sorted_chunk_indices, sorted_nodes): logging.debug( f"> Querying with idx: {chunk_idx}: " f"{truncate_text(node.get_text(), 50)}" ) return sorted_nodes class GPTKeywordTableGPTQuery(BaseGPTKeywordTableQuery): """GPT Keyword Table Index Query. Extracts keywords using GPT. Set when `mode="default"` in `query` method of `GPTKeywordTableIndex`. .. code-block:: python response = index.query("", mode="default") See BaseGPTKeywordTableQuery for arguments. """ def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" response, _ = self._llm_predictor.predict( self.query_keyword_extract_template, max_keywords=self.max_keywords_per_query, question=query_str, ) keywords = extract_keywords_given_response(response, start_token="KEYWORDS:") return list(keywords) class GPTKeywordTableSimpleQuery(BaseGPTKeywordTableQuery): """GPT Keyword Table Index Simple Query. Extracts keywords using simple regex-based keyword extractor. Set when `mode="simple"` in `query` method of `GPTKeywordTableIndex`. .. code-block:: python response = index.query("", mode="simple") See BaseGPTKeywordTableQuery for arguments. """ def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" return list( simple_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) ) class GPTKeywordTableRAKEQuery(BaseGPTKeywordTableQuery): """GPT Keyword Table Index RAKE Query. Extracts keywords using RAKE keyword extractor. Set when `mode="rake"` in `query` method of `GPTKeywordTableIndex`. .. code-block:: python response = index.query("", mode="rake") See BaseGPTKeywordTableQuery for arguments. """ def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" return list( rake_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) )