yrobel-lima committed on
Commit ea992cb
1 Parent(s): e35585c

Delete rag/runnable.py

Files changed (1)
  1. rag/runnable.py +0 -129
rag/runnable.py DELETED
@@ -1,129 +0,0 @@
- import os
- import random
- from datetime import datetime
- from operator import itemgetter
- from typing import Sequence
-
- import langsmith
- from langchain.memory import ConversationBufferWindowMemory
- from langchain_community.document_transformers import LongContextReorder
- from langchain_core.documents import Document
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import Runnable, RunnableLambda
- from langchain_openai import ChatOpenAI
- from zoneinfo import ZoneInfo
-
- from rag.retrievers import RetrieversConfig
-
- from .prompt_template import generate_prompt_template
-
- # Helpers
-
-
- def get_datetime() -> str:
-     """Get the current date and time."""
-     return datetime.now(ZoneInfo("America/Vancouver")).strftime("%A, %Y-%b-%d %H:%M:%S")
-
-
- def reorder_documents(docs: list[Document]) -> Sequence[Document]:
-     """Reorder documents to mitigate performance degradation with long contexts."""
-
-     return LongContextReorder().transform_documents(docs)
-
-
- def randomize_documents(documents: list[Document]) -> list[Document]:
-     """Randomize documents to vary model recommendations."""
-     random.shuffle(documents)
-     return documents
-
-
- class DocumentFormatter:
-     def __init__(self, prefix: str):
-         self.prefix = prefix
-
-     def __call__(self, docs: list[Document]) -> str:
-         """Format the documents as a markdown list.
-         Args:
-             docs (list[Document]): List of LangChain documents.
-         Returns:
-             str: Documents rendered as a numbered, separator-delimited string.
-         """
-         return "\n---\n".join(
-             [
-                 f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
-                 for i, d in enumerate(docs)
-             ]
-         )
-
-
- def create_langsmith_client():
-     """Create a Langsmith client."""
-     os.environ["LANGCHAIN_TRACING_V2"] = "true"
-     os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
-     os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
-     langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
-     if not langsmith_api_key:
-         raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
-     return langsmith.Client()
-
-
- # Set up Runnable and Memory
-
-
- def get_runnable(
-     model: str = "gpt-4o-mini", temperature: float = 0.1
- ) -> tuple[Runnable, ConversationBufferWindowMemory]:
-     """Set up the runnable chain and chat memory.
-
-     Args:
-         model (str, optional): LLM model. Defaults to "gpt-4o-mini".
-         temperature (float, optional): Model temperature. Defaults to 0.1.
-
-     Returns:
-         tuple[Runnable, ConversationBufferWindowMemory]: Chain and memory.
-     """
-
-     # Set up Langsmith to trace the chain
-     create_langsmith_client()
-
-     # LLM and prompt template
-     llm = ChatOpenAI(
-         model=model,
-         temperature=temperature,
-     )
-
-     prompt = generate_prompt_template()
-
-     # Set retrievers with Hybrid search
-
-     retrievers_config = RetrieversConfig()
-
-     # Practitioners data
-     practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)
-
-     # Tall Tree documents with contact information for locations and services
-     documents_retriever = retrievers_config.get_documents_retriever(k=10)
-
-     # Set conversation history window memory. It only uses the last k interactions
-     memory = ConversationBufferWindowMemory(
-         memory_key="history",
-         return_messages=True,
-         k=6,
-     )
-
-     # Set up runnable using LCEL
-     setup = {
-         "practitioners_db": itemgetter("message")
-         | practitioners_data_retriever
-         | DocumentFormatter("Practitioner #"),
-         "tall_tree_db": itemgetter("message")
-         | documents_retriever
-         | DocumentFormatter("No."),
-         "timestamp": lambda _: get_datetime(),
-         "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
-         "message": itemgetter("message"),
-     }
-
-     chain = setup | prompt | llm | StrOutputParser()
-
-     return chain, memory
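
For context, the deleted get_runnable was presumably consumed one chat turn at a time. The snippet below is a minimal, hypothetical sketch of such a call site, not part of this commit; the import path, the example message, and the save_context bookkeeping are assumptions.

    # Hypothetical call site: drive the chain one turn at a time and persist each
    # exchange so the "history" slot stays populated on the next invocation.
    from rag.runnable import get_runnable

    chain, memory = get_runnable(model="gpt-4o-mini", temperature=0.1)

    user_message = "Do you offer physiotherapy in Victoria?"  # example input
    answer = chain.invoke({"message": user_message})

    # Store the turn; the next invoke() sees it via memory.load_memory_variables.
    memory.save_context({"message": user_message}, {"output": answer})
    print(answer)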