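"""Chain setup for the Tall Tree RAG chatbot.

Builds the practitioners and Tall Tree documents retrievers (optionally
re-ranked) and assembles the LCEL runnable with windowed chat memory.
"""
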
import logging
from operator import itemgetter

from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai import ChatOpenAI

from rag.retrievers import RetrieversConfig

from .helpers import (
    DocumentFormatter,
    create_langsmith_client,
    get_datetime,
    get_reranker,
)
from .prompt_template import generate_prompt_template

logging.basicConfig(level=logging.ERROR)


def retrievers_setup(
    retrievers_config: RetrieversConfig, reranker: bool = False
) -> tuple:
    """Set up the practitioners and documents retrievers, optionally re-ranked.

    Args:
        retrievers_config (RetrieversConfig): Configuration used to build the retrievers.
        reranker (bool, optional): Wrap each retriever with a re-ranking
            compressor. Defaults to False.

    Returns:
        tuple: Practitioners retriever and documents retriever.
    """
    # Practitioners
    practitioners_retriever = retrievers_config.get_practitioners_retriever(k=10)
    # Tall Tree documents
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Optional re-ranking: improves retrieval quality and acts as a filter
    if reranker:
        practitioners_retriever_reranker = ContextualCompressionRetriever(
            base_compressor=get_reranker(top_n=10),
            base_retriever=practitioners_retriever,
        )
        documents_retriever_reranker = ContextualCompressionRetriever(
            base_compressor=get_reranker(top_n=8),
            base_retriever=documents_retriever,
        )

        return practitioners_retriever_reranker, documents_retriever_reranker

    return practitioners_retriever, documents_retriever


# Initialize retrievers at module level (observed better loading times from Streamlit this way)
practitioners_retriever, documents_retriever = retrievers_setup(
    retrievers_config=RetrieversConfig(), reranker=True
)


# Set up runnable and chat memory
def get_runnable_and_memory(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable and the chat memory.

    Args:
        model (str, optional): LLM model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: Runnable and memory.
    """

    # Set up LangSmith to trace the runnable
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )

    prompt = generate_prompt_template()

    # Conversation window memory: keeps only the last k interactions
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up runnable using LCEL
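    # (A dict literal in LCEL is coerced to a RunnableParallel, so the branches
    # below run concurrently, each filling one prompt variable.)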
    setup = {
        "practitioners_db": itemgetter("user_query")
        | practitioners_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("user_query")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "user_query": itemgetter("user_query"),
    }

    runnable = setup | prompt | llm | StrOutputParser()

    return runnable, memory
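

# Example usage: a minimal sketch, not part of the original module. Assumes
# OPENAI_API_KEY is set and the vector stores behind RetrieversConfig are
# reachable; the query below is purely illustrative.
if __name__ == "__main__":
    runnable, memory = get_runnable_and_memory()
    user_query = "Do you have any physiotherapists in Victoria?"
    response = runnable.invoke({"user_query": user_query})

    # The chain reads memory but never writes it; save the turn explicitly so
    # the next invocation sees it under "history".
    memory.save_context({"input": user_query}, {"output": response})
    print(response)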