# Source: SOAPAssist/gpt_index/indices/query/query_transform.py
# Commit 8a58cf3 by AbeerTrial: "Upload folder using huggingface_hub"
"""Query transform."""
from typing import Optional
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT
class BaseQueryTransform:
    """Base class for query transforms.

    A query transform turns a raw query string into a ``QueryBundle``,
    possibly augmenting it with extra strings that are used for embedding
    lookups, in order to improve index querying. This base implementation
    applies no augmentation: the query string itself is the only embedding
    string.
    """

    def __call__(self, query_str: str) -> QueryBundle:
        """Wrap ``query_str`` in a ``QueryBundle`` without transforming it.

        Args:
            query_str (str): The raw query string.

        Returns:
            QueryBundle: Bundle whose ``query_str`` is the input and whose
            ``custom_embedding_strs`` is a one-element list holding it.
        """
        embedding_strs = [query_str]
        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=embedding_strs,
        )
class HyDEQueryTransform(BaseQueryTransform):
    """Hypothetical Document Embeddings (HyDE) query transform.

    An LLM is prompted to write a hypothetical answer to the query, and that
    generated text is used as an embedding string (optionally together with
    the original query string).

    As described in `[Precise Zero-Shot Dense Retrieval without Relevance Labels]
    (https://arxiv.org/abs/2212.10496)`
    """

    def __init__(
        self,
        llm_predictor: Optional[LLMPredictor] = None,
        hyde_prompt: Optional[Prompt] = None,
        include_original: bool = True,
    ) -> None:
        """Create a HyDE query transform.

        Args:
            llm_predictor (Optional[LLMPredictor]): Predictor used to generate
                the hypothetical document. A fresh ``LLMPredictor()`` is
                constructed when this is falsy (e.g. ``None``).
            hyde_prompt (Optional[Prompt]): Prompt used for generation; falls
                back to ``DEFAULT_HYDE_PROMPT`` when falsy (e.g. ``None``).
            include_original (bool): If truthy, the original query string is
                appended after the hypothetical document in the embedding
                strings.
        """
        super().__init__()
        # Falsy-check (not ``is None``) mirrors ``value or default`` semantics.
        if not llm_predictor:
            llm_predictor = LLMPredictor()
        self._llm_predictor = llm_predictor
        if not hyde_prompt:
            hyde_prompt = DEFAULT_HYDE_PROMPT
        self._hyde_prompt = hyde_prompt
        self._include_original = include_original

    def __call__(self, query_str: str) -> QueryBundle:
        """Build a ``QueryBundle`` whose embedding strings come from HyDE.

        Args:
            query_str (str): The raw user query; it is passed to the prompt
                as ``context_str``.

        Returns:
            QueryBundle: ``query_str`` unchanged, with ``custom_embedding_strs``
            equal to ``[hypothetical_doc, query_str]`` when ``include_original``
            is set, otherwise ``[hypothetical_doc]``.
        """
        # TODO: support generating multiple hypothetical docs
        # ``predict`` yields a 2-tuple; only the first element (the generated
        # text) is used. NOTE(review): second element presumably the formatted
        # prompt — confirm against LLMPredictor.predict.
        hypothetical_doc, _unused = self._llm_predictor.predict(
            self._hyde_prompt, context_str=query_str
        )
        embedding_strs = (
            [hypothetical_doc, query_str]
            if self._include_original
            else [hypothetical_doc]
        )
        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=embedding_strs,
        )