AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
"""Langchain memory wrapper (for LlamaIndex)."""
from typing import Any, Dict, List, Optional
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import AIMessage
from langchain.schema import BaseMemory as Memory
from langchain.schema import BaseMessage, HumanMessage
from pydantic import Field
from gpt_index.indices.base import BaseGPTIndex
from gpt_index.readers.schema.base import Document
from gpt_index.utils import get_new_id
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""Get prompt input key.
Copied over from langchain.
"""
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
return prompt_input_keys[0]
class GPTIndexMemory(Memory):
"""Langchain memory wrapper (for LlamaIndex).
Args:
human_prefix (str): Prefix for human input. Defaults to "Human".
ai_prefix (str): Prefix for AI output. Defaults to "AI".
memory_key (str): Key for memory. Defaults to "history".
index (BaseGPTIndex): LlamaIndex instance.
query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
input_key (Optional[str]): Input key. Defaults to None.
output_key (Optional[str]): Output key. Defaults to None.
"""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history"
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
return prompt_input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
prompt_input_key = self._get_prompt_input_key(inputs)
query_str = inputs[prompt_input_key]
# TODO: wrap in prompt
# TODO: add option to return the raw text
# NOTE: currently it's a hack
response = self.index.query(query_str, **self.query_kwargs)
return {self.memory_key: str(response)}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
prompt_input_key = self._get_prompt_input_key(inputs)
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
doc_text = "\n".join([human, ai])
doc = Document(text=doc_text)
self.index.insert(doc)
def clear(self) -> None:
"""Clear memory contents."""
pass
def __repr__(self) -> str:
"""Return representation."""
return "GPTIndexMemory()"
class GPTIndexChatMemory(BaseChatMemory):
"""Langchain chat memory wrapper (for LlamaIndex).
Args:
human_prefix (str): Prefix for human input. Defaults to "Human".
ai_prefix (str): Prefix for AI output. Defaults to "AI".
memory_key (str): Key for memory. Defaults to "history".
index (BaseGPTIndex): LlamaIndex instance.
query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
input_key (Optional[str]): Input key. Defaults to None.
output_key (Optional[str]): Output key. Defaults to None.
"""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history"
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
output_key: Optional[str] = None
input_key: Optional[str] = None
return_source: bool = False
id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict)
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
return prompt_input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain."""
prompt_input_key = self._get_prompt_input_key(inputs)
query_str = inputs[prompt_input_key]
response_obj = self.index.query(query_str, **self.query_kwargs)
if self.return_source:
source_nodes = response_obj.source_nodes
if self.return_messages:
# get source messages from ids
source_ids = [sn.doc_id for sn in source_nodes]
source_messages = [
m for id, m in self.id_to_message.items() if id in source_ids
]
# NOTE: type List[BaseMessage]
response: Any = source_messages
else:
source_texts = [sn.source_text for sn in source_nodes]
response = "\n\n".join(source_texts)
else:
response = str(response_obj)
return {self.memory_key: response}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
prompt_input_key = self._get_prompt_input_key(inputs)
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
# a bit different than existing langchain implementation
# because we want to track id's for messages
human_message = HumanMessage(content=inputs[prompt_input_key])
human_message_id = get_new_id(set(self.id_to_message.keys()))
ai_message = AIMessage(content=outputs[output_key])
ai_message_id = get_new_id(
set(self.id_to_message.keys()).union({human_message_id})
)
self.chat_memory.messages.append(human_message)
self.chat_memory.messages.append(ai_message)
self.id_to_message[human_message_id] = human_message
self.id_to_message[ai_message_id] = ai_message
human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai_txt = f"{self.ai_prefix}: " + outputs[output_key]
human_doc = Document(text=human_txt, doc_id=human_message_id)
ai_doc = Document(text=ai_txt, doc_id=ai_message_id)
self.index.insert(human_doc)
self.index.insert(ai_doc)
def clear(self) -> None:
"""Clear memory contents."""
pass
def __repr__(self) -> str:
"""Return representation."""
return "GPTIndexMemory()"