"""Langchain memory wrapper (for LlamaIndex)."""

from typing import Any, Dict, List, Optional

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import AIMessage
from langchain.schema import BaseMemory as Memory
from langchain.schema import BaseMessage, HumanMessage
from pydantic import Field

from gpt_index.indices.base import BaseGPTIndex
from gpt_index.readers.schema.base import Document
from gpt_index.utils import get_new_id


def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """Get prompt input key.

    Copied over from langchain.

    """
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt.
    prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
    if len(prompt_input_keys) != 1:
        raise ValueError(f"One input key expected got {prompt_input_keys}")
    return prompt_input_keys[0]
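
# For example, given inputs {"input": "hi", "stop": ["\n"]} and memory
# variables ["history"], the set difference leaves only "input", which is
# returned as the prompt input key.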


class GPTIndexMemory(Memory):
    """Langchain memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseGPTIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        return prompt_input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return key-value pairs given the text input to the chain."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]

        # TODO: wrap in prompt
        # TODO: add option to return the raw text
        # NOTE: currently it's a hack
        response = self.index.query(query_str, **self.query_kwargs)
        return {self.memory_key: str(response)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = list(outputs.keys())[0]
        else:
            output_key = self.output_key
        human = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai = f"{self.ai_prefix}: " + outputs[output_key]
        doc_text = "\n".join([human, ai])
        doc = Document(text=doc_text)
        self.index.insert(doc)

    def clear(self) -> None:
        """Clear memory contents.

        NOTE: this is currently a no-op; documents already inserted into the
        underlying index are left in place.

        """
        pass

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexMemory()"


class GPTIndexChatMemory(BaseChatMemory):
    """Langchain chat memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseGPTIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    return_source: bool = False
    id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict)

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        return prompt_input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]

        response_obj = self.index.query(query_str, **self.query_kwargs)
        if self.return_source:
            source_nodes = response_obj.source_nodes
            if self.return_messages:
                # get source messages from ids
                source_ids = [sn.doc_id for sn in source_nodes]
                source_messages = [
                    msg
                    for msg_id, msg in self.id_to_message.items()
                    if msg_id in source_ids
                ]
                # NOTE: type List[BaseMessage]
                response: Any = source_messages
            else:
                source_texts = [sn.source_text for sn in source_nodes]
                response = "\n\n".join(source_texts)
        else:
            response = str(response_obj)
        return {self.memory_key: response}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = list(outputs.keys())[0]
        else:
            output_key = self.output_key

        # a bit different than existing langchain implementation
        # because we want to track id's for messages
        human_message = HumanMessage(content=inputs[prompt_input_key])
        human_message_id = get_new_id(set(self.id_to_message.keys()))
        ai_message = AIMessage(content=outputs[output_key])
        ai_message_id = get_new_id(
            set(self.id_to_message.keys()).union({human_message_id})
        )

        self.chat_memory.messages.append(human_message)
        self.chat_memory.messages.append(ai_message)

        self.id_to_message[human_message_id] = human_message
        self.id_to_message[ai_message_id] = ai_message

        human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai_txt = f"{self.ai_prefix}: " + outputs[output_key]
        human_doc = Document(text=human_txt, doc_id=human_message_id)
        ai_doc = Document(text=ai_txt, doc_id=ai_message_id)
        self.index.insert(human_doc)
        self.index.insert(ai_doc)

    def clear(self) -> None:
        """Clear memory contents.

        NOTE: this is currently a no-op; documents already inserted into the
        underlying index are left in place.

        """
        pass

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexChatMemory()"