"""Wrapper functions around an LLM chain.""" | |
import logging | |
from dataclasses import dataclass | |
from typing import Any, Generator, Optional, Tuple | |
import openai | |
from langchain import Cohere, LLMChain, OpenAI | |
from langchain.llms import AI21 | |
from langchain.llms.base import BaseLLM | |
from gpt_index.constants import MAX_CHUNK_SIZE, NUM_OUTPUTS | |
from gpt_index.prompts.base import Prompt | |
from gpt_index.utils import ( | |
ErrorToRetry, | |
globals_helper, | |
retry_on_exceptions_with_backoff, | |
) | |


@dataclass
class LLMMetadata:
    """LLM metadata.

    We extract this metadata to help with our prompts.

    """

    max_input_size: int = MAX_CHUNK_SIZE
    num_output: int = NUM_OUTPUTS
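
# Usage sketch (illustrative, not executed at import time): since LLMMetadata is
# a plain dataclass, defaults and per-field overrides both work, e.g.
#
#   metadata = LLMMetadata()                 # MAX_CHUNK_SIZE / NUM_OUTPUTS defaults
#   metadata = LLMMetadata(num_output=512)   # override just the output budget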


def _get_llm_metadata(llm: BaseLLM) -> LLMMetadata:
    """Get LLM metadata from llm."""
    if not isinstance(llm, BaseLLM):
        raise ValueError("llm must be an instance of langchain.llms.base.LLM")
    if isinstance(llm, OpenAI):
        return LLMMetadata(
            max_input_size=llm.modelname_to_contextsize(llm.model_name),
            num_output=llm.max_tokens,
        )
    elif isinstance(llm, Cohere):
        # TODO: figure out max input size for cohere
        return LLMMetadata(num_output=llm.max_tokens)
    elif isinstance(llm, AI21):
        # TODO: figure out max input size for AI21
        return LLMMetadata(num_output=llm.maxTokens)
    else:
        return LLMMetadata()
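
# Usage sketch: for a langchain OpenAI LLM the metadata is derived from the
# model's context window and its `max_tokens` setting; the exact numbers depend
# on the installed langchain version. Illustrative only:
#
#   llm = OpenAI(temperature=0, model_name="text-davinci-003")
#   metadata = _get_llm_metadata(llm)
#   # metadata.max_input_size -> 4097 for text-davinci-003 (per langchain's table)
#   # metadata.num_output     -> llm.max_tokens (256 by default)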


def _get_response_gen(openai_response_stream: Generator) -> Generator:
    """Get response generator from openai response stream."""
    for response in openai_response_stream:
        yield response["choices"][0]["text"]


class LLMPredictor:
    """LLM predictor class.

    Wrapper around an LLMChain from Langchain.

    Args:
        llm (Optional[langchain.llms.base.LLM]): LLM from Langchain to use
            for predictions. Defaults to OpenAI's text-davinci-003 model.
            Please see `Langchain's LLM Page
            <https://langchain.readthedocs.io/en/latest/modules/llms.html>`_
            for more details.
        retry_on_throttling (bool): Whether to retry on rate limit errors.
            Defaults to true.

    """

    def __init__(
        self, llm: Optional[BaseLLM] = None, retry_on_throttling: bool = True
    ) -> None:
        """Initialize params."""
        self._llm = llm or OpenAI(temperature=0, model_name="text-davinci-003")
        self.retry_on_throttling = retry_on_throttling
        self._total_tokens_used = 0
        self.flag = True
        self._last_token_usage: Optional[int] = None

    def get_llm_metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        # TODO: refactor mocks in unit tests, this is a stopgap solution
        if hasattr(self, "_llm") and self._llm is not None:
            return _get_llm_metadata(self._llm)
        else:
            return LLMMetadata()

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Inner predict function.

        If retry_on_throttling is true, we will retry on rate limit errors.

        """
        llm_chain = LLMChain(
            prompt=prompt.get_langchain_prompt(llm=self._llm), llm=self._llm
        )
        # Note: we don't pass formatted_prompt to llm_chain.predict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
        if self.retry_on_throttling:
            llm_prediction = retry_on_exceptions_with_backoff(
                lambda: llm_chain.predict(**full_prompt_args),
                [
                    ErrorToRetry(openai.error.RateLimitError),
                    ErrorToRetry(openai.error.ServiceUnavailableError),
                    ErrorToRetry(openai.error.TryAgain),
                    ErrorToRetry(
                        openai.error.APIConnectionError, lambda e: e.should_retry
                    ),
                ],
            )
        else:
            llm_prediction = llm_chain.predict(**full_prompt_args)
        return llm_prediction

    def predict(self, prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
        """Predict the answer to a query.

        Args:
            prompt (Prompt): Prompt to use for prediction.

        Returns:
            Tuple[str, str]: Tuple of the predicted answer and the formatted prompt.

        """
        formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
        llm_prediction = self._predict(prompt, **prompt_args)
        logging.debug(llm_prediction)
        # We assume that the value of formatted_prompt is exactly the thing
        # eventually sent to OpenAI, or whatever LLM downstream
        prompt_tokens_count = self._count_tokens(formatted_prompt)
        prediction_tokens_count = self._count_tokens(llm_prediction)
        self._total_tokens_used += prompt_tokens_count + prediction_tokens_count
        return llm_prediction, formatted_prompt
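
    # Usage sketch (illustrative; assumes an OpenAI API key is configured and a
    # Prompt instance is available, e.g. the QuestionAnswerPrompt helper from
    # this repo's prompts module):
    #
    #   predictor = LLMPredictor()
    #   answer, formatted = predictor.predict(qa_prompt, context_str=ctx, query_str=q)
    #   print(predictor.total_tokens_used)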

    def stream(self, prompt: Prompt, **prompt_args: Any) -> Tuple[Generator, str]:
        """Stream the answer to a query.

        NOTE: this is a beta feature. Will try to build or use
        better abstractions about response handling.

        Args:
            prompt (Prompt): Prompt to use for prediction.

        Returns:
            str: The predicted answer.

        """
        if not isinstance(self._llm, OpenAI):
            raise ValueError("stream is only supported for OpenAI LLMs")
        formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
        raw_response_gen = self._llm.stream(formatted_prompt)
        response_gen = _get_response_gen(raw_response_gen)
        # NOTE/TODO: token counting doesn't work with streaming
        return response_gen, formatted_prompt
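
    # Streaming usage sketch (illustrative; only works with a langchain OpenAI
    # LLM, per the isinstance check above):
    #
    #   response_gen, formatted = predictor.stream(qa_prompt, context_str=ctx, query_str=q)
    #   for text_delta in response_gen:
    #       print(text_delta, end="", flush=True)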

    @property
    def total_tokens_used(self) -> int:
        """Get the total tokens used so far."""
        return self._total_tokens_used

    def _count_tokens(self, text: str) -> int:
        """Count tokens in text using the global tokenizer."""
        tokens = globals_helper.tokenizer(text)
        return len(tokens)

    @property
    def last_token_usage(self) -> int:
        """Get the last token usage."""
        if self._last_token_usage is None:
            return 0
        return self._last_token_usage

    @last_token_usage.setter
    def last_token_usage(self, value: int) -> None:
        """Set the last token usage."""
        self._last_token_usage = value

    async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Async inner predict function.

        If retry_on_throttling is true, we will retry on rate limit errors.

        """
        llm_chain = LLMChain(
            prompt=prompt.get_langchain_prompt(llm=self._llm), llm=self._llm
        )
        # Note: we don't pass formatted_prompt to llm_chain.predict because
        # langchain does the same formatting under the hood
        full_prompt_args = prompt.get_full_format_args(prompt_args)
        # TODO: support retry on throttling
        llm_prediction = await llm_chain.apredict(**full_prompt_args)
        return llm_prediction

    async def apredict(self, prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
        """Async predict the answer to a query.

        Args:
            prompt (Prompt): Prompt to use for prediction.

        Returns:
            Tuple[str, str]: Tuple of the predicted answer and the formatted prompt.

        """
        formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
        llm_prediction = await self._apredict(prompt, **prompt_args)
        logging.debug(llm_prediction)
        # We assume that the value of formatted_prompt is exactly the thing
        # eventually sent to OpenAI, or whatever LLM downstream
        prompt_tokens_count = self._count_tokens(formatted_prompt)
        prediction_tokens_count = self._count_tokens(llm_prediction)
        self._total_tokens_used += prompt_tokens_count + prediction_tokens_count
        return llm_prediction, formatted_prompt
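

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the library API. Assumptions:
    # OPENAI_API_KEY is set in the environment, and QuestionAnswerPrompt is
    # importable from this repo's prompts module (as in the gpt_index docs);
    # swap in any other Prompt subclass if that helper differs in your version.
    from gpt_index.prompts.prompts import QuestionAnswerPrompt

    qa_prompt = QuestionAnswerPrompt(
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context, answer the question: {query_str}\n"
    )
    predictor = LLMPredictor(retry_on_throttling=True)

    # Blocking prediction: returns (answer, formatted_prompt) and updates the
    # running token count.
    answer, formatted_prompt = predictor.predict(
        qa_prompt,
        context_str="gpt_index wraps LLM calls behind LLMPredictor.",
        query_str="What does LLMPredictor wrap?",
    )
    print(answer)
    print("tokens used so far:", predictor.total_tokens_used)

    # Streaming prediction (OpenAI LLMs only): consume the generator of text deltas.
    response_gen, _ = predictor.stream(
        qa_prompt,
        context_str="gpt_index wraps LLM calls behind LLMPredictor.",
        query_str="What does LLMPredictor wrap?",
    )
    for text_delta in response_gen:
        print(text_delta, end="", flush=True)
    print()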