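# NOTE(review): "Spaces: Runtime error" appeared here; it looks like a page-status banner captured along with the file, not part of the module.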
"""Query runner.""" | |
from typing import Any, Dict, List, Optional, Union, cast | |
from gpt_index.data_structs.data_structs import IndexStruct | |
from gpt_index.docstore import DocumentStore | |
from gpt_index.embeddings.base import BaseEmbedding | |
from gpt_index.indices.prompt_helper import PromptHelper | |
from gpt_index.indices.query.base import BaseGPTIndexQuery, BaseQueryRunner | |
from gpt_index.indices.query.query_transform import BaseQueryTransform | |
from gpt_index.indices.query.schema import QueryBundle, QueryConfig, QueryMode | |
from gpt_index.indices.registry import IndexRegistry | |
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor | |
from gpt_index.response.schema import Response | |
# TMP: refactor query config type
# A query config may be handed to QueryRunner either as a raw dict (converted
# with QueryConfig.from_dict) or as an already-built QueryConfig object.
QUERY_CONFIG_TYPE = Union[Dict, QueryConfig]
class QueryRunner(BaseQueryRunner):
    """Tool to take in a query request and perform a query with the right classes.

    Higher-level wrapper over a given query.

    For every index struct that is queried, the runner selects a
    ``QueryConfig`` (matched by index struct id first, then by index struct
    type, otherwise a fresh config using ``QueryMode.DEFAULT``), looks up the
    query class for that struct type and query mode in the index registry, and
    instantiates it with the shared LLM predictor, prompt helper, embedding
    model and docstore held by this runner.

    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        embed_model: BaseEmbedding,
        docstore: DocumentStore,
        index_registry: IndexRegistry,
        query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
        query_transform: Optional[BaseQueryTransform] = None,
        recursive: bool = False,
        use_async: bool = False,
    ) -> None:
        """Init params.

        Args:
            llm_predictor: Default ``llm_predictor`` kwarg for query objects.
            prompt_helper: Default ``prompt_helper`` kwarg for query objects.
            embed_model: Default ``embed_model`` kwarg for query objects.
            docstore: Document store passed to every query object.
            index_registry: Registry whose ``type_to_query`` mapping is used
                to resolve the query class for a struct type and query mode.
            query_configs: Optional configs, each either a dict or a
                ``QueryConfig``. Dicts and objects may be mixed freely.
            query_transform: Transform applied to plain-string queries;
                defaults to ``BaseQueryTransform()``.
            recursive: Forwarded unchanged to every query object.
            use_async: Forwarded unchanged to every query object.

        """
        # Normalize element by element: the declared type allows a list that
        # mixes dicts and QueryConfig objects, so the first element alone
        # cannot decide how the rest should be converted.
        query_config_objs: List[QueryConfig] = [
            QueryConfig.from_dict(qc)
            if isinstance(qc, dict)
            else cast(QueryConfig, qc)
            for qc in (query_configs or [])
        ]

        # Later configs overwrite earlier ones that share a type or an id.
        type_to_config_dict: Dict[str, QueryConfig] = {}
        id_to_config_dict: Dict[str, QueryConfig] = {}
        for config in query_config_objs:
            type_to_config_dict[config.index_struct_type] = config
            if config.index_struct_id is not None:
                id_to_config_dict[config.index_struct_id] = config

        self._type_to_config_dict = type_to_config_dict
        self._id_to_config_dict = id_to_config_dict
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._embed_model = embed_model
        self._docstore = docstore
        self._index_registry = index_registry
        self._query_transform = query_transform or BaseQueryTransform()
        self._recursive = recursive
        self._use_async = use_async

    def _get_query_kwargs(self, config: QueryConfig) -> Dict[str, Any]:
        """Get query kwargs.

        Returns a copy of ``config.query_kwargs`` so the config itself is not
        mutated. ``prompt_helper``, ``llm_predictor`` and ``embed_model`` are
        filled in from this runner only when the config does not already
        provide them.

        """
        query_kwargs: Dict[str, Any] = dict(config.query_kwargs)
        query_kwargs.setdefault("prompt_helper", self._prompt_helper)
        query_kwargs.setdefault("llm_predictor", self._llm_predictor)
        query_kwargs.setdefault("embed_model", self._embed_model)
        return query_kwargs

    def _get_query_obj(
        self,
        index_struct: IndexStruct,
    ) -> BaseGPTIndexQuery:
        """Get query object.

        Config precedence: a config registered for this struct's id, then one
        registered for its type, then a new ``QueryConfig`` with
        ``QueryMode.DEFAULT``.

        Raises:
            KeyError: If the index registry has no query class for the
                struct type / query mode combination.

        """
        index_struct_id = index_struct.get_doc_id()
        index_struct_type = index_struct.get_type()

        if index_struct_id in self._id_to_config_dict:
            config = self._id_to_config_dict[index_struct_id]
        elif index_struct_type in self._type_to_config_dict:
            config = self._type_to_config_dict[index_struct_type]
        else:
            config = QueryConfig(
                index_struct_type=index_struct_type, query_mode=QueryMode.DEFAULT
            )

        mode = config.query_mode
        query_cls = self._index_registry.type_to_query[index_struct_type][mode]

        # This runner is always handed to the query object as `query_runner`;
        # the `recursive` flag is forwarded alongside it.
        return query_cls(
            index_struct,
            **self._get_query_kwargs(config),
            query_runner=self,
            docstore=self._docstore,
            recursive=self._recursive,
            use_async=self._use_async,
        )

    def _to_query_bundle(
        self, query_str_or_bundle: Union[str, QueryBundle]
    ) -> QueryBundle:
        """Turn a plain query string into a bundle; pass bundles through as-is."""
        # NOTE: Currently, query transform is only run once
        # TODO: Consider refactor to support index-specific query transform
        if isinstance(query_str_or_bundle, str):
            return self._query_transform(query_str_or_bundle)
        return query_str_or_bundle

    def query(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query.

        A plain string is first passed through the query transform; a
        ``QueryBundle`` is used unchanged. The query object is then built for
        ``index_struct`` and its ``query`` result is returned.

        """
        query_bundle = self._to_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return query_obj.query(query_bundle)

    async def aquery(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query asynchronously.

        Same preparation as ``query``, but awaits the query object's
        ``aquery`` coroutine.

        """
        query_bundle = self._to_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return await query_obj.aquery(query_bundle)