import os
import random
from functools import cache
from operator import itemgetter

import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import EnsembleRetriever
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai.chat_models import ChatOpenAI

from .prompt_template import generate_prompt_template
from .retrievers_setup import (
    DenseRetrieverClient,
    SparseRetrieverClient,
    compression_retriever_setup,
    multi_query_retriever_setup,
)

# Retrieval hyper-parameters used by the retriever builders below.
_RETRIEVER_PARAMETERS = {
    "embeddings_model": "text-embedding-3-small",
    "k_dense_practitioners_db": 8,
    "k_sparse_practitioners_db": 15,
    # Weights follow the order of EnsembleRetriever's `retrievers` list:
    # [dense multi-query, sparse].
    "weights_ensemble_practitioners_db": [0.2, 0.8],
    "k_compression_practitioners_db": 18,
    "k_dense_talltree": 6,
    "k_compression_talltree": 6,
}


# Helpers
def reorder_documents(docs: list[Document]) -> list[Document]:
    """Reorder documents to mitigate performance degradation with long contexts.

    Delegates to LangChain's ``LongContextReorder`` transformer.
    """
    return LongContextReorder().transform_documents(docs)


def randomize_documents(documents: list[Document]) -> list[Document]:
    """Shuffle documents to vary model recommendations.

    Note:
        The list is shuffled *in place* with ``random.shuffle`` and the same
        list object is returned, so the caller's list is mutated.
    """
    random.shuffle(documents)
    return documents


class DocumentFormatter:
    """Callable that renders a list of Documents as a numbered markdown list."""

    def __init__(self, prefix: str):
        # Label written before each item number, e.g. "Practitioner #"
        # renders entries as "- Practitioner # 1:".
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents as markdown.

        Args:
            docs (list[Document]): LangChain documents to format.

        Returns:
            str: One "- {prefix} {n}:" entry per document, numbered from 1,
            each followed by a blank line, a tab and the document's
            ``page_content``. Entries are separated by a "---" line.
            An empty list yields an empty string.
        """
        return "\n---\n".join(
            f"- {self.prefix} {number}:\n\n\t{doc.page_content}"
            for number, doc in enumerate(docs, start=1)
        )


@cache
def create_langsmith_client():
    """Enable LangSmith tracing for this process and return a LangSmith client.

    Sets LANGCHAIN_TRACING_V2, LANGCHAIN_PROJECT ("talltree-ai-assistant") and
    LANGCHAIN_ENDPOINT ("https://api.smith.langchain.com") in ``os.environ``,
    then returns a ``langsmith.Client``. A successful result is cached, so the
    client is built once per process.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is not set. The check happens
            before any environment variable is written, so a failed call
            leaves the process environment untouched.
    """
    # Validate first: otherwise tracing would be switched on for the whole
    # process even though no client can be created.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")

    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    return langsmith.Client()


def _build_practitioners_retriever():
    """Build the hybrid (dense multi-query + sparse) compression retriever over practitioners_db."""
    params = _RETRIEVER_PARAMETERS

    # Dense retriever over the "practitioners_db" collection (presumably
    # Qdrant, per the original comments). It is then wrapped in a
    # multi-query retriever; the original note says this yields ~k x 4
    # results — TODO confirm.
    dense_client = DenseRetrieverClient(
        embeddings_model=params["embeddings_model"],
        collection_name="practitioners_db",
        search_type="similarity",
        k=params["k_dense_practitioners_db"],
    )
    dense_multiquery_retriever = multi_query_retriever_setup(
        dense_client.get_dense_retriever()
    )

    # Sparse (SPLADE) retriever over the dedicated sparse collection.
    sparse_client = SparseRetrieverClient(
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=params["k_sparse_practitioners_db"],
    )
    sparse_retriever = sparse_client.get_sparse_retriever()

    # Hybrid search combining both retrievers. The configured weights give
    # the sparse retriever the larger share (0.8 vs 0.2).
    # NOTE(review): the original comment here was self-contradictory (it
    # said "dense" twice). Sparse/lexical matching is the usual reason for
    # handling acronyms like RMT well — confirm the intended rationale.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[dense_multiquery_retriever, sparse_retriever],
        weights=params["weights_ensemble_practitioners_db"],
    )

    return compression_retriever_setup(
        ensemble_retriever,
        embeddings_model=params["embeddings_model"],
        k=params["k_compression_practitioners_db"],
    )


def _build_tall_tree_retriever():
    """Build the dense compression retriever over tall_tree_db."""
    params = _RETRIEVER_PARAMETERS

    dense_client = DenseRetrieverClient(
        embeddings_model=params["embeddings_model"],
        collection_name="tall_tree_db",
        search_type="similarity",
        k=params["k_dense_talltree"],
    )
    return compression_retriever_setup(
        dense_client.get_dense_retriever(),
        embeddings_model=params["embeddings_model"],
        k=params["k_compression_talltree"],
    )


# Set up Runnable and Memory
@cache
def get_rag_chain(
    model_name: str = "gpt-4", temperature: float = 0.2
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the RAG runnable and its conversation memory.

    The returned chain expects an input mapping with a "message" key. From
    that input it builds four prompt variables:
    - "practitioners_db": formatted documents retrieved from practitioners_db;
    - "tall_tree_db": formatted documents retrieved from tall_tree_db;
    - "history": conversation history loaded from the returned memory;
    - "message": the original "message" value.
    These feed the prompt, then the chat model, then a string output parser.

    Args:
        model_name (str, optional): OpenAI chat model name. Defaults to "gpt-4".
        temperature (float, optional): Model temperature. Defaults to 0.2.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The LCEL chain, and
        the window memory that the chain reads "history" from.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is not set.

    Note:
        The result is cached per call arguments, so every caller using the
        same arguments gets the same chain AND the same memory object.
        NOTE(review): if this serves several users or sessions at once,
        their histories are shared — confirm this is intended.
    """
    # Called only to enable LangSmith tracing; the client is not used here.
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
    )
    prompt = generate_prompt_template()

    practitioners_db_retriever = _build_practitioners_retriever()
    tall_tree_db_retriever = _build_tall_tree_retriever()

    # Conversation history window memory: only the last k interactions are kept.
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up runnable using LCEL; the dict is coerced into a parallel runnable.
    setup_and_retrieval = {
        "practitioners_db": itemgetter("message")
        | practitioners_db_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | tall_tree_db_retriever
        | DocumentFormatter("No."),
        "history": RunnableLambda(memory.load_memory_variables)
        | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup_and_retrieval | prompt | llm | StrOutputParser()
    return chain, memory