File size: 2,873 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""ChatGPT Wrapper."""

from typing import Any, Dict, List, Optional

import openai

from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt

# Default system message prepended to every ChatGPT request when the caller
# does not supply `prepend_messages`.
DEFAULT_MESSAGE_PREPEND = [
    {"role": "system", "content": "You are a helpful assistant."},
]


class ChatGPTLLMPredictor(LLMPredictor):
    """LLM predictor backed by the OpenAI ChatCompletion API.

    Formats each prompt as a chat message, prepends the configured messages
    (a system prompt by default), and returns the assistant's reply text.

    Args:
        prepend_messages (Optional[List[Dict[str, str]]]): Messages to prepend
            to every request sent to the ChatGPT API. Defaults to
            ``DEFAULT_MESSAGE_PREPEND`` (a single system message).
        include_role_in_response (bool): Whether to prefix the returned text
            with ``"role: <role>\\n"``. Defaults to False.
        openai_kwargs (Any): Additional kwargs to pass to the OpenAI API
            (may include ``model``; defaults to ``"gpt-3.5-turbo"``).
            https://platform.openai.com/docs/api-reference/chat/create.

    """

    def __init__(
        self,
        prepend_messages: Optional[List[Dict[str, str]]] = None,
        include_role_in_response: bool = False,
        **openai_kwargs: Any
    ) -> None:
        """Initialize params."""
        self._prepend_messages = prepend_messages or DEFAULT_MESSAGE_PREPEND
        self._include_role_in_response = include_role_in_response
        # Pop `model` out of the kwargs so it is not passed twice to the API
        # call below; keep the previous hard-coded default for compatibility.
        self._model: str = openai_kwargs.pop("model", "gpt-3.5-turbo")
        # Default to deterministic output unless the caller overrides it.
        if "temperature" not in openai_kwargs:
            openai_kwargs["temperature"] = 0
        self._openai_kwargs = openai_kwargs
        self._total_tokens_used = 0
        self.flag = True
        self._last_token_usage: Optional[int] = None

    def _build_messages(
        self, prompt: Prompt, **prompt_args: Any
    ) -> List[Dict[str, str]]:
        """Format the prompt and append it as a user message."""
        prompt_str = prompt.format(**prompt_args)
        return self._prepend_messages + [{"role": "user", "content": prompt_str}]

    def _parse_response(self, raw_response: Any) -> str:
        """Extract the reply text (and, optionally, role) from an API response."""
        message_obj = raw_response["choices"][0]["message"]
        text = ""
        if self._include_role_in_response:
            text += "role: " + message_obj["role"] + "\n"
        text += message_obj["content"]
        return text

    def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Inner predict function: one synchronous ChatCompletion call."""
        raw_response = openai.ChatCompletion.create(
            model=self._model,
            messages=self._build_messages(prompt, **prompt_args),
            **self._openai_kwargs
        )
        return self._parse_response(raw_response)

    async def _apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Async inner predict function: one ChatCompletion.acreate call."""
        raw_response = await openai.ChatCompletion.acreate(
            model=self._model,
            messages=self._build_messages(prompt, **prompt_args),
            **self._openai_kwargs
        )
        return self._parse_response(raw_response)