AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
"""Query runner."""
from typing import Any, Dict, List, Optional, Union, cast
from gpt_index.data_structs.data_structs import IndexStruct
from gpt_index.docstore import DocumentStore
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.base import BaseGPTIndexQuery, BaseQueryRunner
from gpt_index.indices.query.query_transform import BaseQueryTransform
from gpt_index.indices.query.schema import QueryBundle, QueryConfig, QueryMode
from gpt_index.indices.registry import IndexRegistry
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.response.schema import Response
# A query config may be supplied either as a plain dict or as an already-built
# QueryConfig object.
# TMP: refactor query config type
QUERY_CONFIG_TYPE = Union[Dict, QueryConfig]
class QueryRunner(BaseQueryRunner):
    """Tool to take in a query request and perform a query with the right classes.

    Higher-level wrapper over a given query. For each index struct it selects
    a ``QueryConfig`` (by struct id first, then by struct type, otherwise a
    default-mode config), looks up the matching query class in the index
    registry, and instantiates it with the shared defaults held by this
    runner.
    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        embed_model: BaseEmbedding,
        docstore: DocumentStore,
        index_registry: IndexRegistry,
        query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
        query_transform: Optional[BaseQueryTransform] = None,
        recursive: bool = False,
        use_async: bool = False,
    ) -> None:
        """Init params.

        Args:
            llm_predictor: Used as the ``llm_predictor`` query kwarg when a
                config does not provide its own.
            prompt_helper: Used as the ``prompt_helper`` query kwarg when a
                config does not provide its own.
            embed_model: Used as the ``embed_model`` query kwarg when a
                config does not provide its own.
            docstore: Passed to every query object as ``docstore``.
            index_registry: Its ``type_to_query`` mapping is indexed by
                index struct type and then query mode to find the query
                class.
            query_configs: Optional list of configs. Each entry may be a
                dict (converted with ``QueryConfig.from_dict``) or a
                ``QueryConfig``; entries are normalized one by one, so the
                two forms may be mixed. If several configs share an index
                struct type (or id), the last one wins.
            query_transform: Called on string queries to produce the query
                bundle. Defaults to ``BaseQueryTransform()``.
            recursive: Passed to every query object as ``recursive``.
            use_async: Passed to every query object as ``use_async``.
        """
        # Normalize each entry on its own rather than trusting the type of
        # the first element to describe the whole list.
        query_config_objs: List[QueryConfig] = [
            QueryConfig.from_dict(qc) if isinstance(qc, dict) else qc
            for qc in (query_configs or [])
        ]

        type_to_config_dict: Dict[str, QueryConfig] = {}
        id_to_config_dict: Dict[str, QueryConfig] = {}
        for qc in query_config_objs:
            type_to_config_dict[qc.index_struct_type] = qc
            # Id-specific configs are optional; only index those that set one.
            if qc.index_struct_id is not None:
                id_to_config_dict[qc.index_struct_id] = qc

        self._type_to_config_dict = type_to_config_dict
        self._id_to_config_dict = id_to_config_dict
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._embed_model = embed_model
        self._docstore = docstore
        self._index_registry = index_registry
        self._query_transform = query_transform or BaseQueryTransform()
        self._recursive = recursive
        self._use_async = use_async

    def _get_query_kwargs(self, config: QueryConfig) -> Dict[str, Any]:
        """Get query kwargs.

        Returns a copy of ``config.query_kwargs`` (the config itself is not
        mutated), filled in with this runner's ``prompt_helper``,
        ``llm_predictor`` and ``embed_model`` for any of those keys that the
        config did not set.
        """
        query_kwargs: Dict[str, Any] = dict(config.query_kwargs)
        query_kwargs.setdefault("prompt_helper", self._prompt_helper)
        query_kwargs.setdefault("llm_predictor", self._llm_predictor)
        query_kwargs.setdefault("embed_model", self._embed_model)
        return query_kwargs

    def _get_query_obj(
        self,
        index_struct: IndexStruct,
    ) -> BaseGPTIndexQuery:
        """Get query object.

        Config precedence: a config registered for this struct's id, then a
        config registered for this struct's type, then a fresh
        ``QueryMode.DEFAULT`` config for the struct's type.

        Raises:
            KeyError: If the index registry has no query class for the
                struct type / query mode combination.
        """
        index_struct_id = index_struct.get_doc_id()
        index_struct_type = index_struct.get_type()

        if index_struct_id in self._id_to_config_dict:
            config = self._id_to_config_dict[index_struct_id]
        elif index_struct_type in self._type_to_config_dict:
            config = self._type_to_config_dict[index_struct_type]
        else:
            config = QueryConfig(
                index_struct_type=index_struct_type, query_mode=QueryMode.DEFAULT
            )

        mode = config.query_mode
        query_cls = self._index_registry.type_to_query[index_struct_type][mode]
        query_kwargs = self._get_query_kwargs(config)

        # The runner always hands itself to the query object as
        # ``query_runner``; whether to recurse is conveyed separately through
        # the ``recursive`` flag.
        return query_cls(
            index_struct,
            **query_kwargs,
            query_runner=self,
            docstore=self._docstore,
            recursive=self._recursive,
            use_async=self._use_async,
        )

    def _get_query_bundle(
        self, query_str_or_bundle: Union[str, QueryBundle]
    ) -> QueryBundle:
        """Return a query bundle, transforming the input only if it is a string."""
        # NOTE: Currently, query transform is only run once
        # TODO: Consider refactor to support index-specific query transform
        if isinstance(query_str_or_bundle, str):
            return self._query_transform(query_str_or_bundle)
        return query_str_or_bundle

    def query(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query.

        Args:
            query_str_or_bundle: A raw query string (run through the query
                transform) or a query bundle (used as given).
            index_struct: Index struct to query; determines the query class
                and config.

        Returns:
            The result of the query object's ``query`` call.
        """
        query_bundle = self._get_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return query_obj.query(query_bundle)

    async def aquery(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query asynchronously.

        Same as ``query`` except that the query object's ``aquery``
        coroutine is awaited.
        """
        query_bundle = self._get_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return await query_obj.aquery(query_bundle)