File size: 5,937 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Base module for prompts."""
from copy import deepcopy
from string import Formatter
from typing import Any, Dict, List, Optional, Type, TypeVar

from langchain import BasePromptTemplate as BaseLangchainPrompt
from langchain import PromptTemplate as LangchainPrompt
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain.schema import BaseLanguageModel

from gpt_index.prompts.prompt_type import PromptType

# Type variable bound to ``Prompt``: used to annotate the alternate constructors
# and ``partial_format`` so they are typed as returning the class they are
# invoked on (e.g. a ``Prompt`` subclass) rather than plain ``Prompt``.
PMT = TypeVar("PMT", bound="Prompt")


class Prompt:
    """Prompt class for LlamaIndex.

    Wrapper around langchain's prompt class. Adds ability to:
        - enforce certain prompt types
        - partially fill values

    """

    # Required template variables. This class only annotates it (no value is
    # assigned here) while ``__init__`` reads it, so it must be provided, e.g.
    # by a subclass, before an instance is constructed.
    input_variables: List[str]
    prompt_type: PromptType = PromptType.CUSTOM

    def __init__(
        self,
        template: Optional[str] = None,
        langchain_prompt: Optional[BaseLangchainPrompt] = None,
        langchain_prompt_selector: Optional[ConditionalPromptSelector] = None,
        **prompt_kwargs: Any,
    ) -> None:
        """Init params.

        The prompt source is chosen in this order: ``langchain_prompt_selector``
        if given; otherwise ``template`` when ``langchain_prompt`` is None;
        otherwise ``langchain_prompt``.

        Args:
            template: Format-string template (``{name}`` fields) whose field
                names must equal ``input_variables``.
            langchain_prompt: Pre-built LangChain prompt. Must not be combined
                with a non-empty ``template``.
            langchain_prompt_selector: Selector whose default prompt and all
                conditional prompts are validated. When it is given,
                ``template`` and ``langchain_prompt`` are ignored.
            **prompt_kwargs: Forwarded to LangChain's ``PromptTemplate`` when
                building from ``template``; always stored on
                ``self.prompt_kwargs``.

        Raises:
            ValueError: If neither ``template`` nor a LangChain prompt is
                given, if both ``template`` and ``langchain_prompt`` are given,
                or if any prompt's variables differ from ``input_variables``.
        """
        # first check if langchain_prompt_selector is provided
        # TODO: self.prompt is deprecated, switch to prompt_selector under the hood
        if langchain_prompt_selector is not None:
            self.prompt_selector = langchain_prompt_selector
            self.prompt: BaseLangchainPrompt = self.prompt_selector.default_prompt
        # then check if template is provided
        elif langchain_prompt is None:
            if template is None:
                raise ValueError(
                    "`template` must be specified if `langchain_prompt` is None"
                )
            # validate
            tmpl_vars = {
                v for _, v, _, _ in Formatter().parse(template) if v is not None
            }
            if tmpl_vars != set(self.input_variables):
                raise ValueError(
                    f"Invalid template: {template}, variables do not match the "
                    f"required input_variables: {self.input_variables}"
                )

            self.prompt = LangchainPrompt(
                input_variables=self.input_variables, template=template, **prompt_kwargs
            )
            self.prompt_selector = ConditionalPromptSelector(default_prompt=self.prompt)
        # finally, check if langchain_prompt is provided
        else:
            if template:
                raise ValueError(
                    f"Both template ({template}) and langchain_prompt "
                    f"({langchain_prompt}) are provided, only one should be."
                )
            self.prompt = langchain_prompt
            self.prompt_selector = ConditionalPromptSelector(default_prompt=self.prompt)

        # Validate every prompt the selector can return, not only the default.
        required_vars = set(self.input_variables)
        all_lc_prompts = [self.prompt_selector.default_prompt]
        all_lc_prompts.extend(
            prompt for _, prompt in self.prompt_selector.conditionals
        )
        for lc_prompt in all_lc_prompts:
            if set(lc_prompt.input_variables) != required_vars:
                # Report the prompt that actually failed; ``langchain_prompt``
                # is None whenever the prompt came from a template or selector.
                raise ValueError(
                    f"Invalid prompt: {lc_prompt}, variables do not match the "
                    f"required input_variables: {self.input_variables}"
                )
        self.partial_dict: Dict[str, Any] = {}
        self.prompt_kwargs = prompt_kwargs

    @classmethod
    def from_langchain_prompt(
        cls: Type[PMT], prompt: BaseLangchainPrompt, **kwargs: Any
    ) -> PMT:
        """Load prompt from LangChain prompt.

        Extra ``kwargs`` are passed to the constructor.
        """
        return cls(langchain_prompt=prompt, **kwargs)

    @classmethod
    def from_langchain_prompt_selector(
        cls: Type[PMT], prompt_selector: ConditionalPromptSelector, **kwargs: Any
    ) -> PMT:
        """Load prompt from a LangChain conditional prompt selector.

        Extra ``kwargs`` are passed to the constructor.
        """
        return cls(langchain_prompt_selector=prompt_selector, **kwargs)

    def partial_format(self: PMT, **kwargs: Any) -> PMT:
        """Format the prompt partially.

        Returns a deep copy of this prompt with ``kwargs`` merged into its
        partial values; ``self`` is left unmodified.

        Raises:
            ValueError: If any key is not one of ``input_variables``.
        """
        for k in kwargs:
            if k not in self.input_variables:
                raise ValueError(
                    f"Invalid input variable: {k}, not found in input_variables"
                )

        # Deep copy so the original prompt's ``partial_dict`` is not shared.
        copy_obj = deepcopy(self)
        copy_obj.partial_dict.update(kwargs)
        return copy_obj

    @classmethod
    def from_prompt(
        cls: Type[PMT], prompt: "Prompt", llm: Optional[BaseLanguageModel] = None
    ) -> PMT:
        """Create a prompt from an existing prompt.

        Use case: If the existing prompt is already partially filled,
        and the remaining fields satisfy the requirements of the
        prompt class, then we can create a new prompt from the existing
        partially filled prompt.

        Partially-filled values are baked into the new template as literal
        text, while unfilled variables are kept as ``{var}`` placeholders.

        Args:
            prompt: The (possibly partially filled) source prompt.
            llm: Optional LLM used to pick the prompt from the selector.

        Returns:
            A new instance of ``cls`` built from the resulting template and
            the source prompt's ``prompt_kwargs``.
        """
        lc_prompt = prompt.get_langchain_prompt(llm=llm)
        format_dict: Dict[str, Any] = {}
        for var in lc_prompt.input_variables:
            if var in prompt.partial_dict:
                # Escape braces so filled-in text containing "{" / "}" stays
                # literal in the new template instead of being re-parsed as
                # (invalid or unintended) format fields.
                value_str = f"{prompt.partial_dict[var]}"
                format_dict[var] = value_str.replace("{", "{{").replace("}", "}}")
            else:
                # Re-emit the placeholder so it survives into the new template.
                format_dict[var] = f"{{{var}}}"

        template_str = lc_prompt.format(**format_dict)
        cls_obj: PMT = cls(template_str, **prompt.prompt_kwargs)
        return cls_obj

    def get_langchain_prompt(
        self, llm: Optional[BaseLanguageModel] = None
    ) -> BaseLangchainPrompt:
        """Get langchain prompt.

        Returns the selector's default prompt when ``llm`` is None, otherwise
        the prompt the selector chooses for ``llm``.
        """
        if llm is None:
            return self.prompt_selector.default_prompt
        return self.prompt_selector.get_prompt(llm=llm)

    def format(self, llm: Optional[BaseLanguageModel] = None, **kwargs: Any) -> str:
        """Format the prompt.

        Partially-filled values are merged into ``kwargs`` and take precedence
        over caller-supplied values of the same name.
        """
        kwargs.update(self.partial_dict)
        lc_prompt = self.get_langchain_prompt(llm=llm)
        return lc_prompt.format(**kwargs)

    def get_full_format_args(self, kwargs: Dict) -> Dict[str, Any]:
        """Get dict of all format args.

        Hack to pass into Langchain to pass validation.

        Returns a new dict of ``kwargs`` merged with the partially-filled
        values (partial values win on key collisions). The caller's dict is
        not modified.
        """
        return {**kwargs, **self.partial_dict}