Upload 4 files
- rag_chain/__init__.py +1 -0
- rag_chain/chain.py +173 -0
- rag_chain/prompt_template.py +102 -0
- rag_chain/retrievers_setup.py +172 -0
rag_chain/__init__.py
ADDED
@@ -0,0 +1 @@
rag_chain/chain.py
ADDED
@@ -0,0 +1,173 @@
import os
import random
from functools import cache
from operator import itemgetter

import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import EnsembleRetriever
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai.chat_models import ChatOpenAI

from .prompt_template import generate_prompt_template
from .retrievers_setup import (DenseRetrieverClient, SparseRetrieverClient,
                               compression_retriever_setup)

# Helpers


def reorder_documents(docs: list[Document]) -> list[Document]:
    """Long-Context Reorder: regardless of the model's architecture,
    performance degrades when 10+ retrieved documents are included in the
    context, so reorder them to keep the most relevant at the edges.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        list[Document]: Reordered list of LangChain documents.
    """
    reorder = LongContextReorder()
    return reorder.transform_documents(docs)


def randomize_documents(documents: list[Document]) -> list[Document]:
    """Shuffle the documents to vary the recommendations."""
    random.shuffle(documents)
    return documents


def format_practitioners_docs(docs: list[Document]) -> str:
    """Format the practitioners_db documents as markdown.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        str: Markdown-formatted list of practitioners.
    """
    return f"\n{'-' * 3}\n".join(
        [f"- Practitioner #{i+1}:\n\n\t" +
         d.page_content for i, d in enumerate(docs)]
    )


def format_tall_tree_docs(docs: list[Document]) -> str:
    """Format the tall_tree_db documents as markdown.

    Args:
        docs (list[Document]): List of LangChain documents.

    Returns:
        str: Markdown-formatted list of clinic information entries.
    """
    return f"\n{'-' * 3}\n".join(
        [f"- No. {i+1}:\n\n\t" +
         d.page_content for i, d in enumerate(docs)]
    )


def create_langsmith_client():
    """Create a LangSmith client for tracing."""
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
    if not langsmith_api_key:
        raise EnvironmentError(
            "Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()


# Set up Runnable and Memory


@cache
def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the runnable and the chat memory.

    Args:
        model_name (str, optional): LLM model. Defaults to "gpt-4" (as of 30-01-2024).
        temperature (float, optional): Model temperature. Defaults to 0.2.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: Chain and memory.
    """

    # Set up LangSmith to trace the chain
    langsmith_client = create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(model_name=model_name,
                     temperature=temperature)

    prompt = generate_prompt_template()

    # Set retrievers pointing to the practitioners' dataset
    embeddings_model = "text-embedding-ada-002"
    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
                                                  collection_name="practitioners_db")

    # Qdrant db as a retriever
    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
                                                                                  k=10)

    # Testing the sparse vector retriever using Qdrant
    collection_name = "practitioners_db_sparse_collection"
    vector_name = "sparse_vector"
    sparse_retriever_client = SparseRetrieverClient(
        collection_name=collection_name,
        vector_name=vector_name,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=15)
    practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()

    # TODO: Test the ensemble retriever for hybrid search (the dense retriever seems to work better).
    # Using only the filtered sparse retriever for now.
    practitioners_ensemble_retriever = EnsembleRetriever(
        retrievers=[practitioners_db_dense_retriever,
                    practitioners_db_sparse_retriever], weights=[0.1, 0.9]
    )

    # Compression retriever for practitioners db
    # TODO: Test the filtered ensemble retriever. *** Using only the sparse retriever ***
    practitioners_db_compression_retriever = compression_retriever_setup(
        practitioners_db_sparse_retriever,
        embeddings_model="text-embedding-ada-002",
        similarity_threshold=0.74
    )

    # Set retrievers pointing to the tall_tree_db
    dense_retriever_client = DenseRetrieverClient(embeddings_model=embeddings_model,
                                                  collection_name="tall_tree_db")
    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever(search_type="similarity",
                                                                              k=5)
    # Compression retriever for tall_tree_db
    tall_tree_db_compression_retriever = compression_retriever_setup(
        tall_tree_db_dense_retriever,
        embeddings_model="text-embedding-ada-002",
        similarity_threshold=0.5
    )

    # Set conversation history window memory. It only keeps the last k=5 interactions.
    memory = ConversationBufferWindowMemory(memory_key="history",
                                            return_messages=True,
                                            k=5)

    # Set up the runnable using LCEL
    setup_and_retrieval = {"practitioners_db": itemgetter("message")
                           | practitioners_db_compression_retriever
                           | randomize_documents
                           | format_practitioners_docs,
                           "tall_tree_db": itemgetter("message") | tall_tree_db_compression_retriever | format_tall_tree_docs,
                           "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
                           "message": itemgetter("message")
                           }

    chain = (
        setup_and_retrieval
        | prompt
        | llm
        | StrOutputParser()
    )

    return chain, memory
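For reference, a minimal usage sketch of this chain (not part of the commit; it assumes OPENAI_API_KEY, LANGCHAIN_API_KEY, QDRANT_URL, and QDRANT_API_KEY are set, and the `user_message` string is a hypothetical input):

# Usage sketch: build the chain once (get_rag_chain is cached) and invoke it.
chain, memory = get_rag_chain(model_name="gpt-4", temperature=0.2)

user_message = "I have lower back pain. I'm in Vancouver."  # hypothetical input
response = chain.invoke({"message": user_message})

# The chain reads history from memory but does not write it,
# so save the turn explicitly after each invocation.
memory.save_context({"input": user_message}, {"output": response})
print(response)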
rag_chain/prompt_template.py
ADDED
@@ -0,0 +1,102 @@
from langchain.prompts import (ChatPromptTemplate,
                               SystemMessagePromptTemplate,
                               MessagesPlaceholder,
                               )


def generate_prompt_template():

    # Prompt templates
    system_template = """# Role

---

You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Canada. Your role is to analyze the patient's symptoms or needs and connect them with the appropriate practitioner or service offered by Tall Tree. Respond to Patient Queries using the `Practitioners Database` and `Tall Tree Health Centre Information` provided in the `Context`. Follow the `Response Guidelines` listed below:

---

# Response Guidelines

1. **Interaction**: Engage in a warm, empathetic, and professional manner. Keep responses brief and focused on the patient's query. Use markdown formatting.

2. **Symptoms/Needs and Location Preference**: Ask for symptoms/needs and location preference (Cordova Bay, James Bay, or Vancouver) before recommending a practitioner or service.

3. **Avoid Making Assumptions**: Stick to the given `Context`. If you're unable to assist, offer the user the contact details for the closest `Tall Tree Health` clinic.

4. **No Medical Advice**: Refrain from giving any medical advice or acting as a healthcare professional. Avoid discussing healthcare costs.

5. **Symptoms/Needs and Service Verification**: Match the patient's symptoms/needs with the `Focus Area` field in the `Practitioners Database`. If no match is found, advise the patient accordingly without recommending a practitioner, as Tall Tree is not a primary healthcare provider.

6. **Recommending Practitioners**: Use the patient's symptoms/needs and location to recommend 3 practitioners from the `Practitioners Database`. Focus on `Discipline`, `Focus Areas`, `Location`, `Treatment Method`, and `Status` (only active). Also, provide the contact info for the corresponding `Tall Tree Health` location for additional assistance.

7. **Practitioner's Contact Information**: Provide `Name`, `Discipline`, and `Booking Link`. Do not print their `Focus Areas`. Provide contact information in the following structured format:

- `First Name` `Last Name`:
- `Discipline`:
- `Booking Link`: (print only if available)

8. **Online Booking Info**: Provide the appropriate clinic contact information from the `Tall Tree Integrated Health Centre Information` for online booking.

## Tall Tree Integrated Health Service Routing Guidelines

9. **Mental Health Queries**: Recommend a psychologist or clinical counsellor for mental health queries, including depression, stress, anxiety, trauma, suicidal thoughts, etc.

10. **Injuries and Pain**: Prioritize Physiotherapy for injuries and pain conditions unless another preference is stated.

11. **Randomness in Recommendations**: Introduce randomness in practitioner recommendations for general issues to avoid bias.

12. **Concussion Protocol**: Direct to the `Concussion Treatment Program` for the appropriate location for a comprehensive assessment with a physiotherapist. Do not recommend a practitioner.

13. **Psychologist in Vancouver**: If a Psychologist is requested in the Vancouver location, provide only the contact and booking link for our mental health team at the Cordova Bay - Upstairs location. Do not recommend an alternative practitioner.

14. **Sleep Issues**: Recommend only the Sleep Program intake and provide the phone number to book an appointment. Do not recommend a practitioner.

15. **Longevity Program**: For longevity queries, provide the Longevity Program phone number. Do not recommend a practitioner.

16. **DEXA Testing or Body Composition**: Inform that this service is exclusive to the Cordova Bay clinic and provide the clinic phone number and booking link. Do not recommend a practitioner.

17. **VO2 Max Testing**: Determine whether the patient prefers Vancouver or Victoria and provide the booking link for that location. In Victoria, this service is offered only at our Cordova Bay location.

18. **Assistance and Closure**: Offer further assistance and conclude positively with a reassuring statement without being repetitive. Example: "Take care! 😊", etc.

---

# Patient Query

```
{message}
```
---

# Context

---
1. **Practitioners Database**:

```
{practitioners_db}
```
---

2. **Tall Tree Health Centre Information**:

```
{tall_tree_db}
```
---

"""

    # Template for the system message with markdown formatting
    system_message = SystemMessagePromptTemplate.from_template(
        system_template)

    prompt = ChatPromptTemplate.from_messages(
        [
            system_message,
            MessagesPlaceholder(variable_name="history"),
            ("human", "{message}"),
        ]
    )

    return prompt
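A quick sanity check for the template (a sketch, not part of the commit; the query and the two database strings are placeholders):

prompt = generate_prompt_template()
# The template expects four variables: history, message, practitioners_db, tall_tree_db.
print(sorted(prompt.input_variables))

messages = prompt.format_messages(
    history=[],                                          # no prior turns
    message="Do you offer physiotherapy in James Bay?",  # hypothetical query
    practitioners_db="- Practitioner #1: ...",           # placeholder context
    tall_tree_db="- No. 1: ...",                         # placeholder context
)
print(messages[0].content[:200])  # start of the rendered system message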
rag_chain/retrievers_setup.py
ADDED
@@ -0,0 +1,172 @@
import os
from functools import cache

import qdrant_client
import torch
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer


class DenseRetrieverClient:
    """Dense retriever using OpenAI text embeddings and the Qdrant vector database.

    Attributes:
        embeddings_model (str): The embeddings model to use. Currently only OpenAI text embeddings are supported.
        collection_name (str): Qdrant collection name.
        client (QdrantClient): Qdrant client.
        qdrant_collection (Qdrant): Qdrant collection.
    """

    def __init__(self, embeddings_model: str = "text-embedding-ada-002", collection_name: str = "practitioners_db"):
        self.validate_environment_variables()
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        self.client = qdrant_client.QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        self.qdrant_collection = self.load_qdrant_collection()

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set."""
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    def set_qdrant_collection(self, embeddings):
        """Prepare the Qdrant collection for the embeddings model."""
        return Qdrant(client=self.client,
                      collection_name=self.collection_name,
                      embeddings=embeddings)

    @cache
    def load_qdrant_collection(self):
        """Load the Qdrant collection for a given embeddings model."""
        if self.embeddings_model == "text-embedding-ada-002":
            self.qdrant_collection = self.set_qdrant_collection(
                OpenAIEmbeddings(model=self.embeddings_model))
        else:
            raise ValueError(
                f"Invalid embeddings model: {self.embeddings_model}. Select 'text-embedding-ada-002' from OpenAI.")

        return self.qdrant_collection

    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
        """Set up the Qdrant vectorstore as a retriever.

        Args:
            search_type (str, optional): "similarity" or "mmr". Defaults to "similarity".
            k (int, optional): Number of documents retrieved. Defaults to 4.

        Returns:
            Retriever: Vectorstore as a retriever.
        """
        dense_retriever = self.qdrant_collection.as_retriever(search_type=search_type,
                                                              search_kwargs={
                                                                  "k": k}
                                                              )
        return dense_retriever


class SparseRetrieverClient:
    """Sparse retriever using the SPLADE neural retrieval model and the Qdrant vector database.

    Attributes:
        collection_name (str): Qdrant collection name.
        vector_name (str): Qdrant vector name.
        splade_model_id (str): The SPLADE neural retrieval model id.
        k (int): Number of documents retrieved.
        client (QdrantClient): Qdrant client.
    """

    def __init__(self, collection_name: str, vector_name: str, splade_model_id: str = "naver/splade-cocondenser-ensembledistil", k: int = 15):
        self.validate_environment_variables()
        self.client = qdrant_client.QdrantClient(url=os.getenv(
            "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        self.model_id = splade_model_id
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k

    def validate_environment_variables(self):
        """Check that the Qdrant environment variables are set."""
        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
        for var in required_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing environment variable: {var}")

    @cache
    def set_tokenizer_config(self):
        """Initialize the tokenizer and the SPLADE neural retrieval model.
        See https://huggingface.co/naver/splade-cocondenser-ensembledistil for more details.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return tokenizer, model

    @cache
    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector. The encoder is required by the QdrantSparseVectorRetriever.
        Adapted from the "Computing the Sparse Vector" code in the Qdrant documentation.

        Args:
            text (str): Text to encode.

        Returns:
            tuple[list[int], list[float]]: Indices and values of the sparse vector.
        """
        tokenizer, model = self.set_tokenizer_config()
        tokens = tokenizer(text, return_tensors="pt",
                           max_length=512, padding="max_length", truncation=True)
        output = model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()
        indices = vec.nonzero().numpy().flatten()
        values = vec.detach().numpy()[indices]
        return indices.tolist(), values.tolist()

    def get_sparse_retriever(self):

        sparse_retriever = QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )

        return sparse_retriever


def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
    """
    Create a ContextualCompressionRetriever from a base retriever and a similarity threshold.

    The ContextualCompressionRetriever uses an EmbeddingsFilter with OpenAIEmbeddings to filter out documents
    with a similarity score below the given threshold.

    Args:
        base_retriever: Retriever to be filtered.
        embeddings_model (str, optional): Embeddings model used by the EmbeddingsFilter. Defaults to "text-embedding-ada-002".
        similarity_threshold (float, optional): The similarity threshold for the EmbeddingsFilter.
            Documents with a similarity score below this threshold are filtered out. Defaults to 0.76
            (obtained by experimenting with text-embedding-ada-002).

    Returns:
        ContextualCompressionRetriever: The configured ContextualCompressionRetriever.
    """

    # Set up the compression retriever (filters out documents with low similarity)
    relevant_filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model),
                                       similarity_threshold=similarity_threshold)

    compression_retriever = ContextualCompressionRetriever(
        base_compressor=relevant_filter, base_retriever=base_retriever
    )

    return compression_retriever
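For reference, a minimal sketch of how these pieces compose (mirroring what chain.py does; not part of the commit). It assumes QDRANT_URL, QDRANT_API_KEY, and OPENAI_API_KEY are set and the sparse collection created elsewhere in this repo exists; the query string is hypothetical:

# Sketch: wrap the SPLADE sparse retriever in a similarity-filtering compression retriever.
sparse_client = SparseRetrieverClient(
    collection_name="practitioners_db_sparse_collection",
    vector_name="sparse_vector",
    k=15,
)
retriever = compression_retriever_setup(
    sparse_client.get_sparse_retriever(),
    embeddings_model="text-embedding-ada-002",
    similarity_threshold=0.74,
)

docs = retriever.get_relevant_documents("knee pain after running")  # hypothetical query
for doc in docs:
    print(doc.page_content[:80])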