yrobel-lima committed on
Commit
7eb7015
1 Parent(s): 2ff50a6

Update rag/runnable.py

Browse files
Files changed (1) hide show
  1. rag/runnable.py +129 -129
rag/runnable.py CHANGED
@@ -1,129 +1,129 @@
1
- import os
2
- import random
3
- from datetime import datetime
4
- from operator import itemgetter
5
- from typing import Sequence
6
-
7
- import langsmith
8
- from langchain.memory import ConversationBufferWindowMemory
9
- from langchain_community.document_transformers import LongContextReorder
10
- from langchain_core.documents import Document
11
- from langchain_core.output_parsers import StrOutputParser
12
- from langchain_core.runnables import Runnable, RunnableLambda
13
- from langchain_openai import ChatOpenAI
14
- from zoneinfo import ZoneInfo
15
-
16
- from rag.retrievers import RetrieversConfig
17
-
18
- from .prompt_template import generate_prompt_template
19
-
20
- # Helpers
21
-
22
-
23
def get_datetime() -> str:
    """Return the current date and time in the America/Vancouver time zone.

    Returns:
        str: The timestamp rendered with the format
            ``"%A, %Y-%b-%d %H:%M:%S"`` (weekday name, year, abbreviated
            month, day, then 24-hour time).
    """
    # Use a timezone-aware "now" so the result does not depend on the host's
    # local time zone.
    now = datetime.now(ZoneInfo("America/Vancouver"))
    return now.strftime("%A, %Y-%b-%d %H:%M:%S")
26
-
27
-
28
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to mitigate performance degradation with long contexts.

    Args:
        docs: Documents to reorder.

    Returns:
        The sequence produced by ``LongContextReorder().transform_documents``.
    """
    reorderer = LongContextReorder()
    return reorderer.transform_documents(docs)
32
-
33
-
34
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Randomize documents to vary model recommendations.

    Args:
        documents: Documents to shuffle. The list is shuffled **in place**.

    Returns:
        The same list object that was passed in, now in random order.
    """
    # random.shuffle mutates its argument and returns None, so hand the same
    # list back to let the function be used as a pipeline step.
    random.shuffle(documents)
    return documents
38
-
39
-
40
class DocumentFormatter:
    """Callable that renders a list of documents as a markdown-style string."""

    def __init__(self, prefix: str):
        """Store the label placed before each document's number.

        Args:
            prefix: Text inserted between the leading ``"- "`` and the
                1-based document number.
        """
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents to markdown.

        Args:
            docs: List of LangChain documents; only ``page_content`` is read.

        Returns:
            One ``"- <prefix> <n>:"`` entry per document, followed by a blank
            line and the tab-indented page content, with entries separated by
            ``"\\n---\\n"``. An empty list yields an empty string.
        """
        return "\n---\n".join(
            [
                f"- {self.prefix} {number}:\n\n\t{doc.page_content}"
                for number, doc in enumerate(docs, start=1)
            ]
        )
57
-
58
-
59
def create_langsmith_client() -> "langsmith.Client":
    """Create a Langsmith client.

    Side effects:
        Sets ``LANGCHAIN_TRACING_V2``, ``LANGCHAIN_PROJECT`` and
        ``LANGCHAIN_ENDPOINT`` in ``os.environ``. These are set before the API
        key is checked, so they remain set even when the function raises.

    Returns:
        A ``langsmith.Client`` constructed with no arguments.

    Raises:
        EnvironmentError: If ``LANGCHAIN_API_KEY`` is unset or empty.
    """
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "admin-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"

    # Fail early with a clear message instead of letting the client be
    # created without credentials.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()
68
-
69
-
70
- # Set up Runnable and Memory
71
-
72
-
73
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable chain and its chat memory.

    Args:
        model (str, optional): LLM model. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        Runnable, Memory: The LCEL chain and the window memory it reads its
        conversation history from. The chain expects an input mapping with a
        ``"message"`` key. This function only wires the memory in for reading;
        nothing here writes new turns to it.

    Raises:
        EnvironmentError: Propagated from ``create_langsmith_client`` when
            ``LANGCHAIN_API_KEY`` is missing.
    """

    # Set up Langsmith to trace the chain
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )

    prompt = generate_prompt_template()

    # Set retrievers with Hybrid search
    retrievers_config = RetrieversConfig()

    # Practitioners data
    practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)

    # Tall Tree documents with contact information for locations and services
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Set conversation history window memory. It only uses the last k interactions
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up runnable using LCEL. Each key becomes a variable passed to the
    # prompt; every branch receives the same input mapping.
    setup = {
        "practitioners_db": itemgetter("message")
        | practitioners_data_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup | prompt | llm | StrOutputParser()

    return chain, memory
 
1
+ import os
2
+ import random
3
+ from datetime import datetime
4
+ from operator import itemgetter
5
+ from typing import Sequence
6
+
7
+ import langsmith
8
+ from langchain.memory import ConversationBufferWindowMemory
9
+ from langchain_community.document_transformers import LongContextReorder
10
+ from langchain_core.documents import Document
11
+ from langchain_core.output_parsers import StrOutputParser
12
+ from langchain_core.runnables import Runnable, RunnableLambda
13
+ from langchain_openai import ChatOpenAI
14
+ from zoneinfo import ZoneInfo
15
+
16
+ from rag.retrievers import RetrieversConfig
17
+
18
+ from .prompt_template import generate_prompt_template
19
+
20
+ # Helpers
21
+
22
+
23
def get_datetime() -> str:
    """Return the current date and time in the America/Vancouver time zone.

    Returns:
        str: The timestamp rendered with the format
            ``"%A, %Y-%b-%d %H:%M:%S"`` (weekday name, year, abbreviated
            month, day, then 24-hour time).
    """
    # Use a timezone-aware "now" so the result does not depend on the host's
    # local time zone.
    now = datetime.now(ZoneInfo("America/Vancouver"))
    return now.strftime("%A, %Y-%b-%d %H:%M:%S")
26
+
27
+
28
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to mitigate performance degradation with long contexts.

    Args:
        docs: Documents to reorder.

    Returns:
        The sequence produced by ``LongContextReorder().transform_documents``.
    """
    reorderer = LongContextReorder()
    return reorderer.transform_documents(docs)
32
+
33
+
34
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Randomize documents to vary model recommendations.

    Args:
        documents: Documents to shuffle. The list is shuffled **in place**.

    Returns:
        The same list object that was passed in, now in random order.
    """
    # random.shuffle mutates its argument and returns None, so hand the same
    # list back to let the function be used as a pipeline step.
    random.shuffle(documents)
    return documents
38
+
39
+
40
class DocumentFormatter:
    """Callable that renders a list of documents as a markdown-style string."""

    def __init__(self, prefix: str):
        """Store the label placed before each document's number.

        Args:
            prefix: Text inserted between the leading ``"- "`` and the
                1-based document number.
        """
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents to markdown.

        Args:
            docs: List of LangChain documents; only ``page_content`` is read.

        Returns:
            One ``"- <prefix> <n>:"`` entry per document, followed by a blank
            line and the tab-indented page content, with entries separated by
            ``"\\n---\\n"``. An empty list yields an empty string.
        """
        return "\n---\n".join(
            [
                f"- {self.prefix} {number}:\n\n\t{doc.page_content}"
                for number, doc in enumerate(docs, start=1)
            ]
        )
57
+
58
+
59
def create_langsmith_client() -> "langsmith.Client":
    """Create a Langsmith client.

    Side effects:
        Sets ``LANGCHAIN_TRACING_V2``, ``LANGCHAIN_PROJECT`` and
        ``LANGCHAIN_ENDPOINT`` in ``os.environ``. These are set before the API
        key is checked, so they remain set even when the function raises.

    Returns:
        A ``langsmith.Client`` constructed with no arguments.

    Raises:
        EnvironmentError: If ``LANGCHAIN_API_KEY`` is unset or empty.
    """
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"

    # Fail early with a clear message instead of letting the client be
    # created without credentials.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()
68
+
69
+
70
+ # Set up Runnable and Memory
71
+
72
+
73
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable chain and its chat memory.

    Args:
        model (str, optional): LLM model. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        Runnable, Memory: The LCEL chain and the window memory it reads its
        conversation history from. The chain expects an input mapping with a
        ``"message"`` key. This function only wires the memory in for reading;
        nothing here writes new turns to it.

    Raises:
        EnvironmentError: Propagated from ``create_langsmith_client`` when
            ``LANGCHAIN_API_KEY`` is missing.
    """

    # Set up Langsmith to trace the chain
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )

    prompt = generate_prompt_template()

    # Set retrievers with Hybrid search
    retrievers_config = RetrieversConfig()

    # Practitioners data
    practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)

    # Tall Tree documents with contact information for locations and services
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Set conversation history window memory. It only uses the last k interactions
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up runnable using LCEL. Each key becomes a variable passed to the
    # prompt; every branch receives the same input mapping.
    setup = {
        "practitioners_db": itemgetter("message")
        | practitioners_data_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup | prompt | llm | StrOutputParser()

    return chain, memory