yrobel-lima committed on
Commit
e921012
1 Parent(s): 1bbf691

Upload 4 files

Browse files
Files changed (4) hide show
  1. rag/__init__.py +0 -0
  2. rag/prompt_template.py +95 -0
  3. rag/retrievers.py +72 -0
  4. rag/runnable.py +129 -0
rag/__init__.py ADDED
File without changes
rag/prompt_template.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import (
2
+ ChatPromptTemplate,
3
+ MessagesPlaceholder,
4
+ SystemMessagePromptTemplate,
5
+ )
6
+
7
+
8
def generate_prompt_template():
    """Build the chat prompt for the Tall Tree virtual assistant ("Ella").

    The prompt consists of a system message (role, response guidelines and
    retrieved context), a placeholder for the conversation history, and the
    current patient message.

    Expected input variables: ``timestamp``, ``message``, ``practitioners_db``,
    ``tall_tree_db`` and ``history``.

    Returns:
        ChatPromptTemplate: the assembled prompt template.
    """
    # NOTE: this template text is runtime data consumed by the LLM — keep it verbatim.
    template_text = """`Current date and time: {timestamp}`
# Role

---

Your name is Ella (Empathetic, Logical, Liaison, Accessible). You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Canada. Based on the patient's symptoms/needs, connect them with the right practitioner or service offered by Tall Tree. Respond to `Patient Queries` using the `Practitioners Database` and `Tall Tree Health Centre Information` provided in the `Context`. Follow the `Response Guidelines` listed below:

---

# Response Guidelines

1. **Interaction**: Engage in a warm, empathetic, and professional manner. Keep responses brief and focused on the patient's query. Always conclude positively with a reassuring statement. Use markdown formatting and do not use headings.

2. **Symptoms/needs and Location Preference**: Only if not specified, ask for symptoms/needs and location preference (Cordova Bay, James Bay, and Vancouver) before recommending a practitioner or service.

3. **Avoid Making Assumptions**: Stick to the given `Context`. If you're unable to assist, offer the user the contact details for the closest `Tall Tree Health` clinic.

4. Do not give medical advice or act as a health professional. Avoid discussing healthcare costs.

5. **Symptoms/needs and Service Verification**: Match the patient's symptoms/needs with the `Focus Area` field in the `Practitioners Database`. If no match is found, advise the patient accordingly without recommending a practitioner, as Tall Tree is not a primary healthcare provider.

6. **Recommending Practitioners**: Based on the patient's symptoms/needs, location and preferred discipline, recommend only up to 3 practitioners who strictly match the given criteria. Provide the contact info for the corresponding `Tall Tree Health` location for additional assistance.

7. **Practitioner's Contact Information**: Provide contact information in the following structured format. Do not print their `Focus Areas`:

- `FirstName` and `LastName`:
- `Discipline`
- [Book an appointment](`BookingLink`) (print only if available)

## Tall Tree Health Service Routing Guidelines

8. **Mental Health Urgent Queries**: For urgent situations such as self-harm, suicidal thoughts, violence, hallucinations, or dissociation direct the patient to call the [9-8-8](tel:9-8-8) suicide crisis helpline, reach out to the Vancouver Island Crisis Line at [1-888-494-3888](tel:1-888-494-3888), or head to the nearest emergency room. Tall Tree isn't equipped for mental health emergencies.

9. **Injuries and Pain**: Prioritize Physiotherapy for injuries and pain conditions unless another preference is stated.

10. **Concussion Protocol**: Direct to the `Concussion Treatment Program` for the appropriate location for a comprehensive assessment with a physiotherapist. Do not recommend a practitioner.

11. **Psychologist in Vancouver**: If a Psychologist is requested in the Vancouver location, provide only the contact and booking link for our mental health team in Cordova Bay - Upstairs location. Do not recommend an alternative practitioner.

12. **Sleep issues**: Recommend only the Sleep Program intake and provide the phone number to book an appointment. Do not recommend a practitioner.

13. **Longevity Program**: For longevity queries, provide the Longevity Program phone number. Do not recommend a practitioner.

14. **DEXA Testing or body composition**: Inform that this service is exclusive to the Cordova Bay clinic and provide the clinic phone number and booking link. Do not recommend a practitioner.

15. **For VO2 Max Testing**: Determine the patient's location preference for Vancouver or Victoria and provide the booking link for the appropriate location. If Victoria, we only do it at our Cordova Bay location.

---

# Patient Query

```
{message}
```
---

# Context

---
1. **Practitioners Database**:

```
{practitioners_db}
```
---

2. **Tall Tree Health Centre Information**:

```
{tall_tree_db}
```
---

"""

    # System message first, then the rolling history, then the live user turn.
    chat_messages = [
        SystemMessagePromptTemplate.from_template(template_text),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{message}"),
    ]
    return ChatPromptTemplate.from_messages(chat_messages)
rag/retrievers.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Literal
3
+
4
+ from langchain_core.vectorstores import VectorStoreRetriever
5
+ from langchain_openai import OpenAIEmbeddings
6
+ from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
7
+
8
+ os.environ["GRPC_VERBOSITY"] = "NONE"
9
+
10
+
11
+ class RetrieversConfig:
12
+ def __init__(
13
+ self,
14
+ dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
15
+ sparse_model_name: Literal[
16
+ "prithivida/Splade_PP_en_v1"
17
+ ] = "prithivida/Splade_PP_en_v1",
18
+ ):
19
+ self.required_env_vars = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]
20
+ self._validate_environment(self.required_env_vars)
21
+ self.qdrant_url = os.getenv("QDRANT_URL")
22
+ self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
23
+ self.dense_embeddings = OpenAIEmbeddings(model=dense_model_name)
24
+ self.sparse_embeddings = FastEmbedSparse(
25
+ model_name=sparse_model_name,
26
+ )
27
+
28
+ def _validate_environment(self, required_env_vars: List[str]):
29
+ missing_vars = [
30
+ var for var in required_env_vars if not os.getenv(var, "").strip()
31
+ ]
32
+ if missing_vars:
33
+ raise EnvironmentError(
34
+ f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
35
+ )
36
+
37
+ def get_qdrant_retriever(
38
+ self,
39
+ collection_name: str,
40
+ dense_vector_name: str,
41
+ sparse_vector_name: str,
42
+ k: int = 5,
43
+ ) -> VectorStoreRetriever:
44
+ qdrantdb = QdrantVectorStore.from_existing_collection(
45
+ embedding=self.dense_embeddings,
46
+ sparse_embedding=self.sparse_embeddings,
47
+ url=self.qdrant_url,
48
+ api_key=self.qdrant_api_key,
49
+ prefer_grpc=True,
50
+ collection_name=collection_name,
51
+ retrieval_mode=RetrievalMode.HYBRID,
52
+ vector_name=dense_vector_name,
53
+ sparse_vector_name=sparse_vector_name,
54
+ )
55
+
56
+ return qdrantdb.as_retriever(search_kwargs={"k": k})
57
+
58
+ def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
59
+ return self.get_qdrant_retriever(
60
+ collection_name="docs_hybrid_db",
61
+ dense_vector_name="docs_dense_vectors",
62
+ sparse_vector_name="docs_sparse_vectors",
63
+ k=k,
64
+ )
65
+
66
+ def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
67
+ return self.get_qdrant_retriever(
68
+ collection_name="practitioners_hybrid_db",
69
+ dense_vector_name="practitioners_dense_vectors",
70
+ sparse_vector_name="practitioners_sparse_vectors",
71
+ k=k,
72
+ )
rag/runnable.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from datetime import datetime
4
+ from operator import itemgetter
5
+ from typing import Sequence
6
+
7
+ import langsmith
8
+ from langchain.memory import ConversationBufferWindowMemory
9
+ from langchain_community.document_transformers import LongContextReorder
10
+ from langchain_core.documents import Document
11
+ from langchain_core.output_parsers import StrOutputParser
12
+ from langchain_core.runnables import Runnable, RunnableLambda
13
+ from langchain_openai import ChatOpenAI
14
+ from zoneinfo import ZoneInfo
15
+
16
+ from rag.retrievers import RetrieversConfig
17
+
18
+ from .prompt_template import generate_prompt_template
19
+
20
+ # Helpers
21
+
22
+
23
def get_datetime() -> str:
    """Return the current Vancouver-local date and time as a display string."""
    now = datetime.now(ZoneInfo("America/Vancouver"))
    return now.strftime("%A, %Y-%b-%d %H:%M:%S")
26
+
27
+
28
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Rearrange retrieved documents to counter lost-in-the-middle degradation.

    Delegates to LongContextReorder, which places the most relevant documents
    at the start and end of the context window.
    """
    reorderer = LongContextReorder()
    return reorderer.transform_documents(docs)
32
+
33
+
34
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Shuffle *documents* in place and return the same list.

    Used to vary the order of practitioners the model sees, so its
    recommendations do not always favor the first results.
    """
    random.shuffle(documents)  # in-place; the caller's list object is returned
    return documents
38
+
39
+
40
class DocumentFormatter:
    """Callable that renders a list of documents as a numbered markdown listing."""

    def __init__(self, prefix: str):
        # Label printed before each document's 1-based ordinal (e.g. "Practitioner #").
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format each document's page_content under a "<prefix> <n>:" bullet.

        Args:
            docs (list[Document]): LangChain documents to render.
        Returns:
            str: entries separated by markdown horizontal rules.
        """
        entries = (
            f"- {self.prefix} {ordinal}:\n\n\t" + doc.page_content
            for ordinal, doc in enumerate(docs, start=1)
        )
        return "\n---\n".join(entries)
57
+
58
+
59
def create_langsmith_client():
    """Configure LangSmith tracing via environment variables and return a client.

    Raises:
        EnvironmentError: if LANGCHAIN_API_KEY is unset or empty.
    """
    tracing_settings = {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "admin-ai-assistant",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    }
    os.environ.update(tracing_settings)
    # The client itself reads LANGCHAIN_API_KEY; check it up front for a clear error.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()
68
+
69
+
70
+ # Set up Runnable and Memory
71
+
72
+
73
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Build the RAG chain and the conversation memory backing it.

    Args:
        model (str, optional): OpenAI chat model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Sampling temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The LCEL chain and the
        window memory that supplies its "history" prompt variable. The caller
        is responsible for saving turns into the memory after each invocation.
    """

    # Enable LangSmith tracing (raises EnvironmentError if LANGCHAIN_API_KEY is unset).
    create_langsmith_client()

    # LLM and prompt template.
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )

    prompt = generate_prompt_template()

    # Retrievers using Qdrant hybrid (dense + sparse) search.
    retrievers_config = RetrieversConfig()

    # Practitioners data.
    practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)

    # Tall Tree documents with contact information for locations and services.
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Window memory: only the last k=6 interactions are replayed into the prompt.
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Map each prompt variable to its source (LCEL). Both retrievers are fed
    # the raw user message; their documents are rendered to markdown strings.
    setup = {
        "practitioners_db": itemgetter("message")
        | practitioners_data_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup | prompt | llm | StrOutputParser()

    return chain, memory