Spaces:
Runtime error
Runtime error
File size: 7,565 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
"""Langchain memory wrapper (for LlamaIndex)."""
from typing import Any, Dict, List, Optional
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import AIMessage
from langchain.schema import BaseMemory as Memory
from langchain.schema import BaseMessage, HumanMessage
from pydantic import Field
from gpt_index.indices.base import BaseGPTIndex
from gpt_index.readers.schema.base import Document
from gpt_index.utils import get_new_id
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """Find the single input key that is not reserved by memory or control.

    Ported from the equivalent helper in langchain.

    Args:
        inputs: Mapping of chain input names to values.
        memory_variables: Keys supplied by the memory object; these are
            excluded from the candidate prompt keys.

    Returns:
        The one key of ``inputs`` left after the exclusions.

    Raises:
        ValueError: If zero or more than one candidate key remains.
    """
    # "stop" may be passed alongside the prompt input but is never used to
    # format the prompt, so it is excluded together with the memory keys.
    reserved = memory_variables + ["stop"]
    candidates = list(set(inputs).difference(reserved))
    if len(candidates) == 1:
        return candidates[0]
    raise ValueError(f"One input key expected got {candidates}")
class GPTIndexMemory(Memory):
    """Langchain memory wrapper (for LlamaIndex).

    Each saved exchange is inserted into ``index`` as one document, and
    loading memory runs ``index.query`` with the current prompt input.

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseGPTIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.
    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    @property
    def memory_variables(self) -> List[str]:
        """Return the single memory variable exposed to the chain."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Resolve which entry of ``inputs`` holds the prompt text.

        An explicitly configured ``input_key`` wins; otherwise the key is
        inferred with ``get_prompt_input_key``.
        """
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Query the index with the prompt input and return the result text."""
        query_str = inputs[self._get_prompt_input_key(inputs)]
        # TODO: wrap in prompt
        # TODO: add option to return the raw text
        # NOTE: currently it's a hack -- the stringified query response is
        # used directly as the memory contents.
        response = self.index.query(query_str, **self.query_kwargs)
        return {self.memory_key: str(response)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Insert one human/AI exchange into the index as a single document.

        Raises:
            ValueError: If ``output_key`` is unset and ``outputs`` does not
                contain exactly one entry.
        """
        prompt_key = self._get_prompt_input_key(inputs)
        output_key = self.output_key
        if output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = next(iter(outputs))

        exchange = [
            f"{self.human_prefix}: " + inputs[prompt_key],
            f"{self.ai_prefix}: " + outputs[output_key],
        ]
        self.index.insert(Document(text="\n".join(exchange)))

    def clear(self) -> None:
        """Clear memory contents.

        NOTE(review): this is currently a no-op; documents already inserted
        into ``index`` are not removed.
        """
        pass

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexMemory()"
class GPTIndexChatMemory(BaseChatMemory):
    """Langchain chat memory wrapper (for LlamaIndex).

    Each saved exchange is recorded in two places:

    - as ``HumanMessage``/``AIMessage`` objects appended to
      ``chat_memory.messages``, each tracked by a fresh id in
      ``id_to_message``;
    - as two separate documents inserted into ``index`` under those same
      ids, so that nodes retrieved from the index can be mapped back to
      the original messages.

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseGPTIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.
        return_source (bool): If True, return the retrieved source content
            (messages or texts) instead of the stringified query response.
            Defaults to False.
        id_to_message (Dict[str, BaseMessage]): Document id -> message saved
            under that id.
    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None
    return_source: bool = False
    id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict)

    @property
    def memory_variables(self) -> List[str]:
        """Return the single memory variable exposed to the chain."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Resolve which entry of ``inputs`` holds the prompt text.

        An explicitly configured ``input_key`` wins; otherwise the key is
        inferred with ``get_prompt_input_key``.
        """
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Query the index with the prompt input and return memory contents.

        The value stored under ``memory_key`` is:

        - ``str(response)`` when ``return_source`` is False;
        - the list of saved ``BaseMessage`` objects whose ids appear among
          the retrieved source nodes, when ``return_source`` and
          ``return_messages`` (not declared here, presumably inherited from
          ``BaseChatMemory``) are both True;
        - otherwise the retrieved source texts joined by blank lines.
        """
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]
        response_obj = self.index.query(query_str, **self.query_kwargs)
        if not self.return_source:
            return {self.memory_key: str(response_obj)}

        source_nodes = response_obj.source_nodes
        if self.return_messages:
            # Map retrieved documents back to saved messages by id. A set keeps
            # each membership test O(1). Messages come back in the insertion
            # order of ``id_to_message``, not in retrieval order.
            source_ids = {sn.doc_id for sn in source_nodes}
            source_messages: List[BaseMessage] = [
                message
                for message_id, message in self.id_to_message.items()
                if message_id in source_ids
            ]
            return {self.memory_key: source_messages}

        source_texts = [sn.source_text for sn in source_nodes]
        return {self.memory_key: "\n\n".join(source_texts)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save one human/AI exchange as messages and as index documents.

        Raises:
            ValueError: If ``output_key`` is unset and ``outputs`` does not
                contain exactly one entry.
        """
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = next(iter(outputs))
        else:
            output_key = self.output_key

        # This differs from langchain's default implementation because each
        # message gets its own id, so that index results can be traced back
        # to the message they came from.
        human_message = HumanMessage(content=inputs[prompt_input_key])
        human_message_id = get_new_id(set(self.id_to_message.keys()))
        ai_message = AIMessage(content=outputs[output_key])
        # Exclude the id just allocated for the human message as well, since
        # it is not yet registered in ``id_to_message``.
        ai_message_id = get_new_id(
            set(self.id_to_message.keys()).union({human_message_id})
        )

        self.chat_memory.messages.append(human_message)
        self.chat_memory.messages.append(ai_message)
        self.id_to_message[human_message_id] = human_message
        self.id_to_message[ai_message_id] = ai_message

        human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai_txt = f"{self.ai_prefix}: " + outputs[output_key]
        human_doc = Document(text=human_txt, doc_id=human_message_id)
        ai_doc = Document(text=ai_txt, doc_id=ai_message_id)
        self.index.insert(human_doc)
        self.index.insert(ai_doc)

    def clear(self) -> None:
        """Clear memory contents.

        NOTE(review): this is currently a no-op. It overrides the base-class
        ``clear`` without calling it, so ``chat_memory``, ``id_to_message`` and
        the documents in ``index`` are all left intact.
        """
        pass

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexChatMemory()"
|