ai-virtual-assistant / rag / runnable_and_memory.py
import logging
from operator import itemgetter

from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai import ChatOpenAI

from rag.retrievers import RetrieversConfig

from .helpers import (
    DocumentFormatter,
    create_langsmith_client,
    get_datetime,
    get_reranker,
)
from .prompt_template import generate_prompt_template

logging.basicConfig(level=logging.ERROR)


def retrievers_setup(retrievers_config, reranker: bool = False) -> tuple:
    """Set up the retrievers, optionally wrapped with a re-ranker.

    Args:
        retrievers_config (RetrieversConfig): Configuration object providing the base retrievers.
        reranker (bool, optional): Wrap each retriever with a re-ranking compressor. Defaults to False.

    Returns:
        tuple: Practitioners retriever and documents retriever (re-ranked when `reranker` is True).
    """
    # Practitioners
    practitioners_retriever = retrievers_config.get_practitioners_retriever(k=10)

    # Tall Tree documents
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Re-ranking (optional): improves result quality and acts as a filter
    if reranker:
        practitioners_retriever_reranker = ContextualCompressionRetriever(
            base_compressor=get_reranker(top_n=10),
            base_retriever=practitioners_retriever,
        )
        documents_retriever_reranker = ContextualCompressionRetriever(
            base_compressor=get_reranker(top_n=8),
            base_retriever=documents_retriever,
        )
        return practitioners_retriever_reranker, documents_retriever_reranker
    else:
        return practitioners_retriever, documents_retriever


# Set retrievers as module-level globals (loading times in Streamlit are better this way)
practitioners_retriever, documents_retriever = retrievers_setup(
    retrievers_config=RetrieversConfig(), reranker=True
)
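
# Illustrative sketch (not part of the module's runtime path): both retrievers follow
# the standard LangChain retriever/Runnable interface, so they can be queried directly
# with a plain string. The queries below are only examples.
#
#   docs = documents_retriever.invoke("What services does Tall Tree offer?")
#   practitioners = practitioners_retriever.invoke("registered massage therapist")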


# Set up runnable and chat memory
def get_runnable_and_memory(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable and the chat memory.

    Args:
        model (str, optional): OpenAI chat model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The runnable chain and its memory.
    """
    # Set up LangSmith to trace the runnable
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )
    prompt = generate_prompt_template()

    # Conversation history window memory: only the last k interactions are kept
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up the runnable using LCEL
    setup = {
        "practitioners_db": itemgetter("user_query")
        | practitioners_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("user_query")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "user_query": itemgetter("user_query"),
    }
    runnable = setup | prompt | llm | StrOutputParser()

    return runnable, memory
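

# Illustrative usage sketch (assumptions: OPENAI_API_KEY plus any LangSmith and
# retriever credentials are set in the environment, and the prompt template consumes
# the "user_query", "history", "practitioners_db", "tall_tree_db" and "timestamp"
# keys populated above; the example query is made up). The chain is invoked with the
# user query, and the exchange is saved back to the window memory so it appears in
# "history" on the next turn.
if __name__ == "__main__":
    runnable, memory = get_runnable_and_memory()
    query = {"user_query": "Which practitioners offer physiotherapy?"}
    answer = runnable.invoke(query)
    # Persist the turn so the next invocation sees it in the conversation history
    memory.save_context(query, {"output": answer})
    print(answer)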