"""Langchain memory wrapper (for LlamaIndex).""" from typing import Any, Dict, List, Optional from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import AIMessage from langchain.schema import BaseMemory as Memory from langchain.schema import BaseMessage, HumanMessage from pydantic import Field from gpt_index.indices.base import BaseGPTIndex from gpt_index.readers.schema.base import Document from gpt_index.utils import get_new_id def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: """Get prompt input key. Copied over from langchain. """ # "stop" is a special key that can be passed as input but is not used to # format the prompt. prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"])) if len(prompt_input_keys) != 1: raise ValueError(f"One input key expected got {prompt_input_keys}") return prompt_input_keys[0] class GPTIndexMemory(Memory): """Langchain memory wrapper (for LlamaIndex). Args: human_prefix (str): Prefix for human input. Defaults to "Human". ai_prefix (str): Prefix for AI output. Defaults to "AI". memory_key (str): Key for memory. Defaults to "history". index (BaseGPTIndex): LlamaIndex instance. query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query. input_key (Optional[str]): Input key. Defaults to None. output_key (Optional[str]): Output key. Defaults to None. """ human_prefix: str = "Human" ai_prefix: str = "AI" memory_key: str = "history" index: BaseGPTIndex query_kwargs: Dict = Field(default_factory=dict) output_key: Optional[str] = None input_key: Optional[str] = None @property def memory_variables(self) -> List[str]: """Return memory variables.""" return [self.memory_key] def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str: if self.input_key is None: prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) else: prompt_input_key = self.input_key return prompt_input_key def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return key-value pairs given the text input to the chain.""" prompt_input_key = self._get_prompt_input_key(inputs) query_str = inputs[prompt_input_key] # TODO: wrap in prompt # TODO: add option to return the raw text # NOTE: currently it's a hack response = self.index.query(query_str, **self.query_kwargs) return {self.memory_key: str(response)} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save the context of this model run to memory.""" prompt_input_key = self._get_prompt_input_key(inputs) if self.output_key is None: if len(outputs) != 1: raise ValueError(f"One output key expected, got {outputs.keys()}") output_key = list(outputs.keys())[0] else: output_key = self.output_key human = f"{self.human_prefix}: " + inputs[prompt_input_key] ai = f"{self.ai_prefix}: " + outputs[output_key] doc_text = "\n".join([human, ai]) doc = Document(text=doc_text) self.index.insert(doc) def clear(self) -> None: """Clear memory contents.""" pass def __repr__(self) -> str: """Return representation.""" return "GPTIndexMemory()" class GPTIndexChatMemory(BaseChatMemory): """Langchain chat memory wrapper (for LlamaIndex). Args: human_prefix (str): Prefix for human input. Defaults to "Human". ai_prefix (str): Prefix for AI output. Defaults to "AI". memory_key (str): Key for memory. Defaults to "history". index (BaseGPTIndex): LlamaIndex instance. query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query. input_key (Optional[str]): Input key. Defaults to None. output_key (Optional[str]): Output key. Defaults to None. """ human_prefix: str = "Human" ai_prefix: str = "AI" memory_key: str = "history" index: BaseGPTIndex query_kwargs: Dict = Field(default_factory=dict) output_key: Optional[str] = None input_key: Optional[str] = None return_source: bool = False id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict) @property def memory_variables(self) -> List[str]: """Return memory variables.""" return [self.memory_key] def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str: if self.input_key is None: prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) else: prompt_input_key = self.input_key return prompt_input_key def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return key-value pairs given the text input to the chain.""" prompt_input_key = self._get_prompt_input_key(inputs) query_str = inputs[prompt_input_key] response_obj = self.index.query(query_str, **self.query_kwargs) if self.return_source: source_nodes = response_obj.source_nodes if self.return_messages: # get source messages from ids source_ids = [sn.doc_id for sn in source_nodes] source_messages = [ m for id, m in self.id_to_message.items() if id in source_ids ] # NOTE: type List[BaseMessage] response: Any = source_messages else: source_texts = [sn.source_text for sn in source_nodes] response = "\n\n".join(source_texts) else: response = str(response_obj) return {self.memory_key: response} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save the context of this model run to memory.""" prompt_input_key = self._get_prompt_input_key(inputs) if self.output_key is None: if len(outputs) != 1: raise ValueError(f"One output key expected, got {outputs.keys()}") output_key = list(outputs.keys())[0] else: output_key = self.output_key # a bit different than existing langchain implementation # because we want to track id's for messages human_message = HumanMessage(content=inputs[prompt_input_key]) human_message_id = get_new_id(set(self.id_to_message.keys())) ai_message = AIMessage(content=outputs[output_key]) ai_message_id = get_new_id( set(self.id_to_message.keys()).union({human_message_id}) ) self.chat_memory.messages.append(human_message) self.chat_memory.messages.append(ai_message) self.id_to_message[human_message_id] = human_message self.id_to_message[ai_message_id] = ai_message human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key] ai_txt = f"{self.ai_prefix}: " + outputs[output_key] human_doc = Document(text=human_txt, doc_id=human_message_id) ai_doc = Document(text=ai_txt, doc_id=ai_message_id) self.index.insert(human_doc) self.index.insert(ai_doc) def clear(self) -> None: """Clear memory contents.""" pass def __repr__(self) -> str: """Return representation.""" return "GPTIndexMemory()"