Spaces:
Runtime error
Runtime error
"""Query transform.""" | |
from typing import Optional | |
from gpt_index.indices.query.schema import QueryBundle | |
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor | |
from gpt_index.prompts.base import Prompt | |
from gpt_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT | |
class BaseQueryTransform:
    """Base class for query transforms.

    A query transform takes a raw query string and packages it, together
    with associated embedding strings, into a ``QueryBundle`` to improve
    index querying. This base implementation applies no transformation:
    the query string itself is the only embedding string.
    """

    def __call__(self, query_str: str) -> QueryBundle:
        """Wrap the raw query string in a ``QueryBundle`` unchanged.

        Args:
            query_str (str): The raw query string.

        Returns:
            QueryBundle: Bundle whose ``query_str`` is the input and whose
                ``custom_embedding_strs`` is the one-element list
                ``[query_str]``.
        """
        embedding_strs = [query_str]
        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=embedding_strs,
        )
class HyDEQueryTransform(BaseQueryTransform):
    """Hypothetical Document Embeddings (HyDE) query transform.

    Uses an LLM to generate a hypothetical answer to a given query and
    uses the generated document as an embedding string.

    As described in `[Precise Zero-Shot Dense Retrieval without Relevance Labels]
    (https://arxiv.org/abs/2212.10496)`
    """

    def __init__(
        self,
        llm_predictor: Optional[LLMPredictor] = None,
        hyde_prompt: Optional[Prompt] = None,
        include_original: bool = True,
    ) -> None:
        """Initialize HyDEQueryTransform.

        Args:
            llm_predictor (Optional[LLMPredictor]): LLM used to generate the
                hypothetical document. A default-constructed
                ``LLMPredictor`` is used when a falsy value is passed.
            hyde_prompt (Optional[Prompt]): Custom prompt for HyDE.
                ``DEFAULT_HYDE_PROMPT`` is used when a falsy value is passed.
            include_original (bool): Whether to include the original query
                string as one of the embedding strings.
        """
        super().__init__()

        # Fall back to the defaults on any falsy argument, not only None.
        if not llm_predictor:
            llm_predictor = LLMPredictor()
        if not hyde_prompt:
            hyde_prompt = DEFAULT_HYDE_PROMPT

        self._llm_predictor = llm_predictor
        self._hyde_prompt = hyde_prompt
        self._include_original = include_original

    def __call__(self, query_str: str) -> QueryBundle:
        """Generate a hypothetical document and bundle it with the query.

        Args:
            query_str (str): The raw query string; passed to the prompt as
                ``context_str``.

        Returns:
            QueryBundle: Bundle whose ``query_str`` is the input and whose
                ``custom_embedding_strs`` holds the hypothetical document,
                followed by the original query when ``include_original``
                was set.
        """
        # TODO: support generating multiple hypothetical docs
        prediction = self._llm_predictor.predict(
            self._hyde_prompt,
            context_str=query_str,
        )
        # predict() yields a pair; only the first element is used here.
        hypothetical_doc, _ = prediction

        if self._include_original:
            embedding_strs = [hypothetical_doc, query_str]
        else:
            embedding_strs = [hypothetical_doc]

        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=embedding_strs,
        )