"""ChatGPT Wrapper.""" from typing import Any, Dict, List, Optional import openai from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor from gpt_index.prompts.base import Prompt DEFAULT_MESSAGE_PREPEND = [ {"role": "system", "content": "You are a helpful assistant."}, ] class ChatGPTLLMPredictor(LLMPredictor): """LLM predictor class. Args: retry_on_throttling (bool): Whether to retry on rate limit errors. Defaults to true. prepend_messages (Optional[List[Dict[str, str]]]): Messages to prepend to the ChatGPT API. openai_kwargs (Any): Additional kwargs to pass to the OpenAI API. https://platform.openai.com/docs/api-reference/chat/create. """ def __init__( self, prepend_messages: Optional[List[Dict[str, str]]] = None, include_role_in_response: bool = False, **openai_kwargs: Any ) -> None: """Initialize params.""" self._prepend_messages = prepend_messages or DEFAULT_MESSAGE_PREPEND self._include_role_in_response = include_role_in_response # set openAI kwargs if "temperature" not in openai_kwargs: openai_kwargs["temperature"] = 0 self._openai_kwargs = openai_kwargs self._total_tokens_used = 0 self.flag = True self._last_token_usage: Optional[int] = None def _predict(self, prompt: Prompt, **prompt_args: Any) -> str: """Inner predict function. If retry_on_throttling is true, we will retry on rate limit errors. """ prompt_str = prompt.format(**prompt_args) messages = self._prepend_messages + [{"role": "user", "content": prompt_str}] response = openai.ChatCompletion.create( model="gpt-3.5-turbo", messages=messages, **self._openai_kwargs ) # hacky message_obj = response["choices"][0]["message"] response = "" if self._include_role_in_response: response += "role: " + message_obj["role"] + "\n" response += message_obj["content"] return response async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str: """Async inner predict function. If retry_on_throttling is true, we will retry on rate limit errors. """ prompt_str = prompt.format(**prompt_args) messages = self._prepend_messages + [{"role": "user", "content": prompt_str}] response = await openai.ChatCompletion.acreate( model="gpt-3.5-turbo", messages=messages, **self._openai_kwargs ) # hacky message_obj = response["choices"][0]["message"] response = "" if self._include_role_in_response: response += "role: " + message_obj["role"] + "\n" response += message_obj["content"] return response