yrobel-lima committed
Commit e35585c
1 Parent(s): 9181031

Upload 4 files

rag/helpers.py ADDED
@@ -0,0 +1,57 @@
+import logging
+import os
+import random
+from datetime import datetime
+from functools import lru_cache
+from typing import Sequence
+from zoneinfo import ZoneInfo
+
+import langsmith
+from langchain_core.documents import Document
+from langchain_community.document_transformers import LongContextReorder
+from langchain.retrievers.document_compressors import FlashrankRerank
+
+logging.basicConfig(level=logging.ERROR)
+
+
+class DocumentFormatter:
+    def __init__(self, prefix: str):
+        self.prefix = prefix
+
+    def __call__(self, docs: list[Document]) -> str:
+        return "\n---\n".join(
+            [
+                f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
+                for i, d in enumerate(docs)
+            ]
+        )
+
+
+def get_datetime() -> str:
+    return datetime.now(ZoneInfo("America/Vancouver")).strftime("%A, %Y-%b-%d %H:%M:%S")
+
+
+def reorder_documents(docs: list[Document]) -> Sequence[Document]:
+    return LongContextReorder().transform_documents(docs)
+
+
+def randomize_documents(documents: list[Document]) -> list[Document]:
+    random.shuffle(documents)
+    return documents
+
+
+def create_langsmith_client():
+    os.environ["LANGCHAIN_TRACING_V2"] = "true"
+    os.environ["LANGCHAIN_PROJECT"] = "admin-ai-assistant"
+    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
+    langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
+    if not langsmith_api_key:
+        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
+    return langsmith.Client()
+
+
+@lru_cache(maxsize=1)
+def get_reranker(
+    top_n: int = 3, model: str = "ms-marco-MiniLM-L-12-v2"
+) -> FlashrankRerank:
+    return FlashrankRerank(top_n=top_n, model=model)
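For orientation, a minimal usage sketch of the DocumentFormatter added above; the sample documents are illustrative, not from the repo:

from langchain_core.documents import Document

from rag.helpers import DocumentFormatter

# Two stand-in documents; real ones come from the Qdrant retrievers.
docs = [
    Document(page_content="Physiotherapist based at the Victoria clinic."),
    Document(page_content="Registered massage therapist at the Vancouver clinic."),
]

formatter = DocumentFormatter(prefix="Practitioner #")
print(formatter(docs))
# - Practitioner # 1:
#
#     Physiotherapist based at the Victoria clinic.
# ---
# - Practitioner # 2:
#
#     Registered massage therapist at the Vancouver clinic.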
rag/prompt_template.py CHANGED
@@ -11,7 +11,7 @@ def generate_prompt_template():
 
 ---
 
-Your name is Ella (Empathetic, Logical, Liaison, Accessible). You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Canada. Based on the patient's symptoms/needs, connect them with the right practitioner or service offered by Tall Tree. Respond to `Patient Queries` using the `Practitioners Database` and `Tall Tree Health Centre Information` provided in the `Context`. Follow the `Response Guidelines` listed below:
+Your name is ELLA (Empathetic, Logical, Liaison, Accessible). You are a helpful AI Assistant at Tall Tree Health in British Columbia, Canada. Based on the patient's symptoms/needs, connect them with the right practitioner or service offered by Tall Tree. Respond to `Patient Queries` using the `Practitioners Database` and `Tall Tree Health Centre Information` provided in the `Context`. Follow the `Response Guidelines` listed below:
 
 ---
 
@@ -58,7 +58,7 @@ Your name is Ella (Empathetic, Logical, Liaison, Accessible). You are a helpful
 # Patient Query
 
 ```
-{message}
+{user_query}
 ```
 ---
 
@@ -81,14 +81,14 @@ Your name is Ella (Empathetic, Logical, Liaison, Accessible). You are a helpful
 
 """
 
-    # Template for system message with markdown formatting
+    # Template for the system message with markdown formatting
     system_message = SystemMessagePromptTemplate.from_template(system_template)
 
     prompt = ChatPromptTemplate.from_messages(
        [
            system_message,
            MessagesPlaceholder(variable_name="history"),
-           ("human", "{message}"),
+           ("human", "{user_query}"),
        ]
    )
 
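A hypothetical invocation of the updated template, showing the {message} → {user_query} rename; the other variable names are taken from the chain built in rag/runnable_and_memory.py below:

from rag.prompt_template import generate_prompt_template

prompt = generate_prompt_template()

# The template now expects "user_query" instead of "message".
messages = prompt.format_messages(
    practitioners_db="...",
    tall_tree_db="...",
    timestamp="Monday, 2024-Jun-03 10:00:00",
    history=[],
    user_query="I have lower back pain. Who should I see?",
)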
rag/retrievers.py CHANGED
@@ -1,5 +1,6 @@
 import os
-from typing import List, Literal
+from functools import lru_cache
+from typing import Literal
 
 from langchain_core.vectorstores import VectorStoreRetriever
 from langchain_openai import OpenAIEmbeddings
@@ -9,6 +10,8 @@ os.environ["GRPC_VERBOSITY"] = "NONE"
 
 
 class RetrieversConfig:
+    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]
+
     def __init__(
         self,
         dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
@@ -16,24 +19,35 @@ class RetrieversConfig:
             "prithivida/Splade_PP_en_v1"
         ] = "prithivida/Splade_PP_en_v1",
     ):
-        self.required_env_vars = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]
-        self._validate_environment(self.required_env_vars)
+        self._validate_environment()
         self.qdrant_url = os.getenv("QDRANT_URL")
         self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
-        self.dense_embeddings = OpenAIEmbeddings(model=dense_model_name)
-        self.sparse_embeddings = FastEmbedSparse(
-            model_name=sparse_model_name,
-        )
+        self.dense_model_name = dense_model_name
+        self.sparse_model_name = sparse_model_name
 
-    def _validate_environment(self, required_env_vars: List[str]):
+    @staticmethod
+    def _validate_environment():
         missing_vars = [
-            var for var in required_env_vars if not os.getenv(var, "").strip()
+            var
+            for var in RetrieversConfig.REQUIRED_ENV_VARS
+            if not os.getenv(var, "").strip()
         ]
         if missing_vars:
             raise EnvironmentError(
                 f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
             )
 
+    @property
+    @lru_cache(maxsize=2)
+    def dense_embeddings(self):
+        return OpenAIEmbeddings(model=self.dense_model_name)
+
+    @property
+    @lru_cache(maxsize=2)
+    def sparse_embeddings(self):
+        return FastEmbedSparse(model_name=self.sparse_model_name)
+
+    @lru_cache(maxsize=8)
     def get_qdrant_retriever(
         self,
         collection_name: str,
@@ -55,14 +69,6 @@ class RetrieversConfig:
 
         return qdrantdb.as_retriever(search_kwargs={"k": k})
 
-    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
-        return self.get_qdrant_retriever(
-            collection_name="docs_hybrid_db",
-            dense_vector_name="docs_dense_vectors",
-            sparse_vector_name="docs_sparse_vectors",
-            k=k,
-        )
-
     def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
         return self.get_qdrant_retriever(
             collection_name="practitioners_hybrid_db",
@@ -70,3 +76,11 @@ class RetrieversConfig:
             sparse_vector_name="practitioners_sparse_vectors",
             k=k,
         )
+
+    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
+        return self.get_qdrant_retriever(
+            collection_name="docs_hybrid_db",
+            dense_vector_name="docs_dense_vectors",
+            sparse_vector_name="docs_sparse_vectors",
+            k=k,
+        )
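The new dense_embeddings/sparse_embeddings properties build the embedding clients lazily on first access instead of eagerly in __init__. A standalone sketch of that idea with a made-up class, not the repo's code; note that functools.cached_property gives the same one-build-per-instance behavior, whereas stacking @property over @lru_cache keeps a strong reference to self in the cache:

from functools import cached_property


class LazyClientHolder:
    """Hypothetical stand-in for RetrieversConfig's cached properties."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    @cached_property
    def client(self) -> object:
        # Runs once per instance; the result is stored on the instance.
        print(f"building client for {self.model_name}")
        return object()  # stand-in for OpenAIEmbeddings / FastEmbedSparse


holder = LazyClientHolder("text-embedding-3-small")
holder.client  # prints "building client ..." and caches
holder.client  # served from the cached attribute, no rebuild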
rag/runnable_and_memory.py ADDED
@@ -0,0 +1,107 @@
+import logging
+from operator import itemgetter
+
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import Runnable, RunnableLambda
+from langchain_openai import ChatOpenAI
+
+from rag.retrievers import RetrieversConfig
+
+from .helpers import (
+    DocumentFormatter,
+    create_langsmith_client,
+    get_datetime,
+    get_reranker,
+)
+from .prompt_template import generate_prompt_template
+
+logging.basicConfig(level=logging.ERROR)
+
+
+def retrievers_setup(retrievers_config, reranker: bool = False) -> tuple:
+    """Set up retrievers, optionally with re-ranking.
+    Args:
+        retrievers_config (RetrieversConfig): Configured retriever factory.
+        reranker (bool, optional): Apply re-ranking. Defaults to False.
+
+    Returns:
+        tuple: Retrievers
+    """
+    # Practitioners
+    practitioners_retriever = retrievers_config.get_practitioners_retriever(k=10)
+    # Tall Tree documents
+    documents_retriever = retrievers_config.get_documents_retriever(k=10)
+
+    # Re-ranking (optional): improves quality and serves as a filter
+    if reranker:
+        practitioners_retriever_reranker = ContextualCompressionRetriever(
+            base_compressor=get_reranker(top_n=10),
+            base_retriever=practitioners_retriever,
+        )
+        documents_retriever_reranker = ContextualCompressionRetriever(
+            base_compressor=get_reranker(top_n=8),
+            base_retriever=documents_retriever,
+        )
+
+        return practitioners_retriever_reranker, documents_retriever_reranker
+
+    else:
+        return practitioners_retriever, documents_retriever
+
+
+# Set retrievers as global variables (I see better loading time from Streamlit this way)
+practitioners_retriever, documents_retriever = retrievers_setup(
+    retrievers_config=RetrieversConfig(), reranker=True
+)
+
+
+# Set up runnable and chat memory
+def get_runnable_and_memory(
+    model: str = "gpt-4o-mini", temperature: float = 0.1
+) -> tuple[Runnable, ConversationBufferWindowMemory]:
+    """Set up runnable and chat memory.
+
+    Args:
+        model (str, optional): LLM model. Defaults to "gpt-4o-mini".
+        temperature (float, optional): Model temperature. Defaults to 0.1.
+
+    Returns:
+        Runnable, Memory: Runnable and Memory
+    """
+
+    # Set up LangSmith to trace the runnable
+    create_langsmith_client()
+
+    # LLM and prompt template
+    llm = ChatOpenAI(
+        model=model,
+        temperature=temperature,
+    )
+
+    prompt = generate_prompt_template()
+
+    # Set conversation history window memory. It only uses the last k interactions
+    memory = ConversationBufferWindowMemory(
+        memory_key="history",
+        return_messages=True,
+        k=6,
+    )
+
+    # Set up runnable using LCEL
+    setup = {
+        "practitioners_db": itemgetter("user_query")
+        | practitioners_retriever
+        | DocumentFormatter("Practitioner #"),
+        "tall_tree_db": itemgetter("user_query")
+        | documents_retriever
+        | DocumentFormatter("No."),
+        "timestamp": lambda _: get_datetime(),
+        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
+        "user_query": itemgetter("user_query"),
+    }
+
+    runnable = setup | prompt | llm | StrOutputParser()
+
+    return runnable, memory
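A hypothetical driver for the new module, assuming the QDRANT_*, OPENAI_API_KEY, and LANGCHAIN_API_KEY environment variables are set. The chain only reads memory via load_memory_variables, so the caller persists each turn explicitly:

from rag.runnable_and_memory import get_runnable_and_memory

runnable, memory = get_runnable_and_memory(model="gpt-4o-mini", temperature=0.1)

user_query = "Do you offer physiotherapy in Victoria?"
answer = runnable.invoke({"user_query": user_query})

# The runnable never writes history; save the turn back into memory.
memory.save_context({"input": user_query}, {"output": answer})
print(answer)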