AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
"""Base query classes."""
import logging
import re
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar, cast
from gpt_index.data_structs.data_structs import IndexStruct, Node
from gpt_index.docstore import DocumentStore
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.embeddings.openai import OpenAIEmbedding
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.response.builder import (
RESPONSE_TEXT_TYPE,
ResponseBuilder,
ResponseMode,
TextChunk,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.default_prompt_selectors import DEFAULT_REFINE_PROMPT_SEL
from gpt_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt
from gpt_index.response.schema import RESPONSE_TYPE, Response, StreamingResponse
from gpt_index.token_counter.token_counter import llm_token_counter
from gpt_index.utils import truncate_text
# Type parameter for BaseGPTIndexQuery (Generic[IS]): the concrete IndexStruct
# subtype a query class operates on, as returned by its `index_struct` property.
IS = TypeVar("IS", bound=IndexStruct)
@dataclass
class BaseQueryRunner:
    """Interface for objects that execute a query against an index struct.

    Subclasses override :meth:`query`. This class does not derive from
    ``abc.ABC``, so the ``@abstractmethod`` marker is not enforced when the
    class is instantiated; calling the base implementation simply raises
    ``NotImplementedError``.
    """

    @abstractmethod
    def query(
        self,
        query_bundle: QueryBundle,
        index_struct: IndexStruct,
    ) -> Response:
        """Run ``query_bundle`` against ``index_struct`` and return a Response."""
        message = "Not implemented yet."
        raise NotImplementedError(message)
class BaseGPTIndexQuery(Generic[IS]):
    """Base LlamaIndex Query.

    Helper class that is used to query an index. Can be called within `query`
    method of a BaseGPTIndex object, or instantiated independently.

    Subclasses implement `_get_nodes_for_response`. This base class filters
    the returned nodes, feeds their text into a ResponseBuilder, and wraps the
    result in a Response (or a StreamingResponse when streaming).

    Args:
        index_struct (IS): The index struct to query. Must not be None.
        llm_predictor (LLMPredictor): Optional LLMPredictor object. If not
            provided, a default `LLMPredictor()` is constructed.
        prompt_helper (PromptHelper): PromptHelper object. Although typed as
            Optional, it is currently required: a ValueError is raised if it
            is None.
        embed_model (BaseEmbedding): Optional embedding model. Defaults to
            `OpenAIEmbedding()`. Not used by all indices.
        docstore (DocumentStore): Optional document store. During recursive
            queries it is consulted to resolve nodes that reference another
            index struct.
        query_runner (BaseQueryRunner): Optional runner. When `recursive` is
            True, it queries those referenced index structs.
        required_keywords (List[str]): Optional list of keywords that must be
            present in nodes. Keywords are matched exactly (case-sensitive)
            against the word tokens of the node text. Can be used to query
            most indices (tree index is an exception).
        exclude_keywords (List[str]): Optional list of keywords that must not
            be present in nodes. Matched the same way as `required_keywords`.
            Can be used to query most indices (tree index is an exception).
        response_mode (ResponseMode): Optional ResponseMode. If not provided,
            will use the default ResponseMode.
        text_qa_template (QuestionAnswerPrompt): Optional QuestionAnswerPrompt
            object. If not provided, will use the default QuestionAnswerPrompt.
        refine_template (RefinePrompt): Optional RefinePrompt object. If not
            provided, will use the default RefinePrompt.
        include_summary (bool): Optional bool. If True, will also use the
            summary text of the index when generating a response (the summary
            text can be set through `index.set_text("<text>")`).
        response_kwargs (Dict): Optional extra keyword arguments. They are
            forwarded to the ResponseBuilder's `get_response` /
            `aget_response`.
        similarity_cutoff (float): Optional float. If set, nodes are filtered
            out when their tracked similarity is below this cutoff, or when no
            similarity was tracked for them.
        use_async (bool): Passed through to the ResponseBuilder.
        recursive (bool): If True, and both `query_runner` and `docstore` are
            set, a node whose `ref_doc_id` resolves to an index struct in the
            docstore is answered by querying that index struct instead of
            using the node text.
        streaming (bool): Optional bool. If True, will return a
            StreamingResponse object. If False, will return a Response object.
        doc_ids (List[str]): Optional list of document ids. Stored on the
            instance but not used directly by this base class.

    Raises:
        ValueError: If `index_struct` or `prompt_helper` is None.
    """

    def __init__(
        self,
        index_struct: IS,
        # TODO: pass from superclass
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[BaseEmbedding] = None,
        docstore: Optional[DocumentStore] = None,
        query_runner: Optional[BaseQueryRunner] = None,
        required_keywords: Optional[List[str]] = None,
        exclude_keywords: Optional[List[str]] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        include_summary: bool = False,
        response_kwargs: Optional[Dict] = None,
        similarity_cutoff: Optional[float] = None,
        use_async: bool = True,
        recursive: bool = False,
        streaming: bool = False,
        doc_ids: Optional[List[str]] = None,
    ) -> None:
        """Initialize with parameters."""
        if index_struct is None:
            raise ValueError("index_struct must be provided.")
        self._validate_index_struct(index_struct)
        self._index_struct = index_struct
        self._llm_predictor = llm_predictor or LLMPredictor()
        # NOTE: the embed_model isn't used in all indices
        self._embed_model = embed_model or OpenAIEmbedding()
        self._docstore = docstore
        self._query_runner = query_runner
        # TODO: make this a required param
        if prompt_helper is None:
            raise ValueError("prompt_helper must be provided.")
        self._prompt_helper = cast(PromptHelper, prompt_helper)
        self._required_keywords = required_keywords
        self._exclude_keywords = exclude_keywords
        # Normalize to the enum so plain string modes are accepted as well.
        self._response_mode = ResponseMode(response_mode)
        self.text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
        self.refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL
        self._include_summary = include_summary
        self._response_kwargs = response_kwargs or {}
        self._use_async = use_async
        self.response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            self.text_qa_template,
            self.refine_template,
            use_async=use_async,
            streaming=streaming,
        )
        self.similarity_cutoff = similarity_cutoff
        self._recursive = recursive
        self._streaming = streaming
        self._doc_ids = doc_ids

    def _should_use_node(
        self, node: Node, similarity_tracker: Optional[SimilarityTracker] = None
    ) -> bool:
        """Run node through filters to determine if it should be used.

        A node is rejected if it is missing any required keyword or contains
        any excluded keyword. Keywords are compared exactly against the
        node's word tokens. When both a similarity tracker and a cutoff are
        present, the node is also rejected if it has no tracked similarity or
        its similarity is below the cutoff.
        """
        # Tokenize once into a set so each keyword check is O(1).
        words = set(re.findall(r"\w+", node.get_text()))
        if self._required_keywords is not None:
            if any(w not in words for w in self._required_keywords):
                return False
        if self._exclude_keywords is not None:
            if any(w in words for w in self._exclude_keywords):
                return False

        if similarity_tracker is not None and self.similarity_cutoff is not None:
            similarity = similarity_tracker.find(node)
            # Nodes without a recorded similarity cannot pass the cutoff.
            if similarity is None:
                return False
            if cast(float, similarity) < cast(float, self.similarity_cutoff):
                return False

        return True

    def _get_text_from_node(
        self,
        query_bundle: QueryBundle,
        node: Node,
        level: Optional[int] = None,
    ) -> Tuple[TextChunk, Optional[Response]]:
        """Query a given node.

        If node references a given document, then return the document.
        If node references a given index, then query the index.

        Returns:
            A (text chunk, response) pair. When a sub-index was queried, the
            chunk holds the sub-query answer (marked `is_answer=True`) and the
            response is the sub-query's Response. Otherwise it holds the node
            text and the response is None.
        """
        level_str = "" if level is None else f"[Level {level}]"
        fmt_text_chunk = truncate_text(node.get_text(), 50)
        logging.debug(f">{level_str} Searching in chunk: {fmt_text_chunk}")

        # if recursive and self._query_runner is not None,
        # assume we want to do a recursive
        # query. In order to not perform a recursive query, make sure
        # _query_runner is None.
        sub_index_struct: Optional[IndexStruct] = None
        if (
            self._recursive
            and self._query_runner is not None
            and node.ref_doc_id is not None
            and self._docstore is not None
        ):
            doc = self._docstore.get_document(node.ref_doc_id, raise_error=False)
            # NOTE: old version of the docstore contain both documents and
            # index_struct, whereas new versions of the docstore only contain
            # the index struct
            if isinstance(doc, IndexStruct):
                sub_index_struct = doc

        if sub_index_struct is not None:
            query_runner = cast(BaseQueryRunner, self._query_runner)
            response = query_runner.query(query_bundle, sub_index_struct)
            return TextChunk(str(response), is_answer=True), response

        return TextChunk(node.get_text()), None

    @property
    def index_struct(self) -> IS:
        """Get the index struct."""
        return self._index_struct

    def _validate_index_struct(self, index_struct: IS) -> None:
        """Validate the index struct (no-op here; subclasses may override)."""
        pass

    def _give_response_for_nodes(self, query_str: str) -> RESPONSE_TEXT_TYPE:
        """Synthesize a response from the text chunks already in the builder."""
        response = self.response_builder.get_response(
            query_str,
            mode=self._response_mode,
            **self._response_kwargs,
        )
        return response

    async def _agive_response_for_nodes(self, query_str: str) -> RESPONSE_TEXT_TYPE:
        """Async variant of `_give_response_for_nodes`."""
        response = await self.response_builder.aget_response(
            query_str,
            mode=self._response_mode,
            **self._response_kwargs,
        )
        return response

    def get_nodes_and_similarities_for_response(
        self, query_bundle: QueryBundle
    ) -> List[Tuple[Node, Optional[float]]]:
        """Get list of tuples of node and similarity for response.

        First part of the tuple is the node.
        Second part of tuple is the distance from query to the node.
        If not applicable, it's None.
        """
        similarity_tracker = SimilarityTracker()
        nodes = self._get_nodes_for_response(
            query_bundle, similarity_tracker=similarity_tracker
        )
        nodes = [
            node for node in nodes if self._should_use_node(node, similarity_tracker)
        ]

        # TODO: create a `display` method to allow subclasses to print the Node
        return similarity_tracker.get_zipped_nodes(nodes)

    @abstractmethod
    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Get nodes for response.

        Subclasses must override this. The class is not an ABC, so the
        abstract marker is not enforced at instantiation time.
        """

    def _get_extra_info_for_response(
        self,
        nodes: List[Node],
    ) -> Optional[Dict[str, Any]]:
        """Get extra info for response (None by default; subclasses may override)."""
        return None

    def _prepare_response_builder(
        self,
        response_builder: ResponseBuilder,
        query_bundle: QueryBundle,
        tuples: List[Tuple[Node, Optional[float]]],
    ) -> None:
        """Reset `response_builder` and load it with text and sources for `tuples`."""
        response_builder.reset()
        for node, similarity in tuples:
            text, response = self._get_text_from_node(query_bundle, node)
            response_builder.add_node_as_source(node, similarity=similarity)
            if response is not None:
                # these are source nodes from within this node (when it's an index)
                for source_node in response.source_nodes:
                    response_builder.add_source_node(source_node)
            response_builder.add_text_chunks([text])

    def _prepare_response_output(
        self,
        response_str: Optional[RESPONSE_TEXT_TYPE],
        tuples: List[Tuple[Node, Optional[float]]],
    ) -> RESPONSE_TYPE:
        """Prepare response object from response string.

        Returns a Response for a string (or None) and a StreamingResponse for a
        generator.

        Raises:
            ValueError: If `response_str` is neither a string, None, nor a
                generator.
        """
        response_extra_info = self._get_extra_info_for_response(
            [node for node, _ in tuples]
        )

        if response_str is None or isinstance(response_str, str):
            return Response(
                response_str,
                source_nodes=self.response_builder.get_sources(),
                extra_info=response_extra_info,
            )
        # None was already handled above, so only generators remain valid here.
        if isinstance(response_str, Generator):
            return StreamingResponse(
                response_str,
                source_nodes=self.response_builder.get_sources(),
                extra_info=response_extra_info,
            )
        raise ValueError("Response must be a string or a generator.")

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        # TODO: remove _query and just use query
        tuples = self.get_nodes_and_similarities_for_response(query_bundle)

        # prepare response builder
        self._prepare_response_builder(self.response_builder, query_bundle, tuples)

        if self._response_mode != ResponseMode.NO_TEXT:
            response_str = self._give_response_for_nodes(query_bundle.query_str)
        else:
            response_str = None

        return self._prepare_response_output(response_str, tuples)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously."""
        # TODO: remove _query and just use query
        tuples = self.get_nodes_and_similarities_for_response(query_bundle)

        # prepare response builder
        self._prepare_response_builder(self.response_builder, query_bundle, tuples)

        if self._response_mode != ResponseMode.NO_TEXT:
            response_str = await self._agive_response_for_nodes(query_bundle.query_str)
        else:
            response_str = None

        return self._prepare_response_output(response_str, tuples)

    def _get_summary_response_builder(self) -> ResponseBuilder:
        """Build a ResponseBuilder seeded with the index's summary text."""
        # TODO: refactor response builder to be in the __init__
        return ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            self.text_qa_template,
            self.refine_template,
            texts=[TextChunk(self._index_struct.get_text())],
            streaming=self._streaming,
        )

    @staticmethod
    def _drain_response_gen(response: StreamingResponse) -> Optional[str]:
        """Consume a streaming response's generator into one string.

        The refine step needs the first-pass answer as text. Passing
        `str(generator)` would only yield the generator's repr. This assumes
        the generator yields text chunks (NOTE(review): confirm for all LLM
        backends). Returns None if there is no generator.
        """
        if response.response_gen is None:
            return None
        return "".join(response.response_gen)

    @llm_token_counter("query")
    def query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        If `include_summary` is set and the response mode produces text, the
        answer is refined once more against the index summary text. The
        summary text is set through `set_text` on the underlying index.

        Raises:
            ValueError: If the intermediate response is neither a Response nor
                a StreamingResponse while applying the summary.
        """
        response = self._query(query_bundle)
        if self._response_mode == ResponseMode.NO_TEXT or not self._include_summary:
            return response

        response_builder = self._get_summary_response_builder()
        if isinstance(response, Response):
            # NOTE: use create and refine for now (default response mode)
            response_str = response_builder.get_response(
                query_bundle.query_str,
                mode=self._response_mode,
                prev_response=response.response,
            )
            response.response = cast(str, response_str)
        elif isinstance(response, StreamingResponse):
            response_gen = response_builder.get_response(
                query_bundle.query_str,
                mode=self._response_mode,
                prev_response=self._drain_response_gen(response),
            )
            response.response_gen = cast(Generator, response_gen)
        else:
            raise ValueError("Response must be a string or a generator.")

        return response

    @llm_token_counter("query")
    async def aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        Async counterpart of `query`, including the optional summary
        refinement when `include_summary` is set.

        Raises:
            ValueError: If the intermediate response is neither a Response nor
                a StreamingResponse while applying the summary.
        """
        response = await self._aquery(query_bundle)
        if self._response_mode == ResponseMode.NO_TEXT or not self._include_summary:
            return response

        response_builder = self._get_summary_response_builder()
        if isinstance(response, Response):
            # NOTE: use create and refine for now (default response mode)
            response_str = await response_builder.aget_response(
                query_bundle.query_str,
                mode=self._response_mode,
                prev_response=response.response,
            )
            response.response = cast(str, response_str)
        elif isinstance(response, StreamingResponse):
            response_gen = await response_builder.aget_response(
                query_bundle.query_str,
                mode=self._response_mode,
                prev_response=self._drain_response_gen(response),
            )
            response.response_gen = cast(Generator, response_gen)
        else:
            raise ValueError("Response must be a string or a generator.")

        return response