talltree committed on
Commit 9e95b48
1 Parent(s): e2bc689

Upload 3 files

utils/__init__.py ADDED
File without changes
utils/data_processing.py ADDED
@@ -0,0 +1,69 @@
+ from pathlib import Path
+
+ import pandas as pd
+
+
+ def format_docs(docs):
+     """Print the contents of a list of LangChain Documents.
+
+     Args:
+         docs (list[Document]): LangChain Documents to print.
+     """
+     print(
+         f"\n{'-' * 100}\n".join(
+             [f"Document {i+1}:\n\n" +
+              d.page_content for i, d in enumerate(docs)]
+         )
+     )
+
+
+ def excel_to_dataframe(data_directory: Path) -> pd.DataFrame:
+     """Load an Excel file, clean its contents, and generate a pd.DataFrame.
+
+     Args:
+         data_directory (Path): Path to the directory where the Excel file is located.
+
+     Raises:
+         FileNotFoundError: If no Excel files are found in the specified directory.
+         ValueError: If more than one Excel file is found in the specified directory.
+
+     Returns:
+         pd.DataFrame: The cleaned data.
+
+     """
+     # Get the .xlsx file name (a single Excel worksheet is expected)
+     excel_files = [file for file in data_directory.iterdir()
+                    if file.suffix == '.xlsx']
+
+     if not excel_files:
+         raise FileNotFoundError(
+             "No Excel files found in the specified directory.")
+     if len(excel_files) > 1:
+         raise ValueError(
+             "More than one Excel file found in the specified directory.")
+
+     path = excel_files[0]
+
+     # Load the Excel file
+     df = pd.read_excel(path, engine='openpyxl')
+
+     # Change column names to title case
+     df.columns = df.columns.str.title()
+
+     # Replace curly apostrophes with straight ones
+     def replace_apostrophes(text):
+         if isinstance(text, str):
+             return text.replace("\u2019", "'")
+         return text
+
+     # Clean the data: trim strings, standardize text (title case),
+     # and replace apostrophes
+     for col in df.columns:
+         # Only clean text-based columns, skipping the booking link
+         if col.lower() != 'booking link' and df[col].dtype == 'object':
+             df[col] = df[col].str.strip().str.title().apply(replace_apostrophes)
+
+     # Handle missing values
+     df.fillna('Information Not Available', inplace=True)
+
+     return df
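
For context, a minimal usage sketch of the helpers above, assuming the layout update_vector_database.py expects (a data/ directory one level up holding a single .xlsx file); the Document construction mirrors load_practitioners_data:

from pathlib import Path

from langchain_core.documents import Document

from data_processing import excel_to_dataframe, format_docs

data_dir = Path().resolve().parent / "data"  # assumed layout, as in update_vector_database.py
df = excel_to_dataframe(data_dir)            # raises unless exactly one .xlsx file is found

# Build one Document per cleaned row and print them with a separator rule
docs = [Document(page_content='. '.join(f"{k}: {v}" for k, v in row.items()))
        for _, row in df.iterrows()]
format_docs(docs)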
utils/update_vector_database.py ADDED
@@ -0,0 +1,224 @@
+ import json
+ import os
+ import sys
+ from functools import cache
+ from pathlib import Path
+
+ import torch
+ from langchain_community.retrievers import QdrantSparseVectorRetriever
+ from langchain_community.vectorstores import Qdrant
+ from langchain_core.documents import Document
+ from langchain_openai.embeddings import OpenAIEmbeddings
+ from qdrant_client import QdrantClient, models
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+ from data_processing import excel_to_dataframe
+
+
+ class DataProcessor:
+     def __init__(self, data_dir: Path):
+         self.data_dir = data_dir
+
+     @staticmethod
+     def categorize_location(location):
+         if any(place in location.lower() for place in ['cordova bay', 'james bay']):
+             return 'Victoria'
+         return location
+
+     def load_practitioners_data(self):
+         try:
+             df = excel_to_dataframe(self.data_dir)
+             df['City'] = df['Location'].apply(self.categorize_location)
+             practitioners_data = []
+             for idx, row in df.iterrows():
+                 # Use a period as the separator for text embeddings
+                 content = '. '.join(
+                     f"{key}: {value}" for key, value in row.items())
+                 doc = Document(page_content=content, metadata={'row': idx})
+                 practitioners_data.append(doc)
+             return practitioners_data
+         except FileNotFoundError:
+             sys.exit(
+                 "Directory or Excel file not found. Please check the path and try again.")
+
+     def load_tall_tree_data(self):
+         # Check that the directory contains exactly one .json file
+         json_files = [file for file in self.data_dir.iterdir()
+                       if file.suffix == '.json']
+
+         if not json_files:
+             raise FileNotFoundError(
+                 "No JSON files found in the specified directory.")
+         if len(json_files) > 1:
+             raise ValueError(
+                 "More than one JSON file found in the specified directory.")
+
+         path = json_files[0]
+         data = self.load_json_file(path)
+         tall_tree_data = self.process_json_data(data)
+
+         return tall_tree_data
+
+     def load_json_file(self, path):
+         try:
+             with open(path, 'r') as f:
+                 data = json.load(f)
+             return data
+         except json.JSONDecodeError:
+             raise ValueError(f"The file {path} is not a valid JSON file.")
+
+     def process_json_data(self, data):
+         tall_tree_data = []
+         for idx, (key, value) in enumerate(data.items()):
+             content = f"{key}: {value}"
+             doc = Document(page_content=content, metadata={'row': idx})
+             tall_tree_data.append(doc)
+         return tall_tree_data
+
+
+ class DenseVectorStore:
+     """Store dense embeddings in the Qdrant vector database."""
+
+     def __init__(self, documents: list[Document], embeddings: OpenAIEmbeddings, collection_name: str = 'practitioners_db'):
+         self.validate_environment_variables()
+         self.qdrant_db = Qdrant.from_documents(
+             documents,
+             embeddings,
+             url=os.getenv("QDRANT_URL"),
+             prefer_grpc=True,
+             api_key=os.getenv("QDRANT_API_KEY"),
+             collection_name=collection_name,
+             force_recreate=True)
+
+     def validate_environment_variables(self):
+         required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
+         for var in required_vars:
+             if not os.getenv(var):
+                 raise EnvironmentError(f"Missing environment variable: {var}")
+
+     def get_db(self):
+         return self.qdrant_db
+
+
+ class SparseVectorStore:
+     """Store sparse vectors in the Qdrant vector database using the SPLADE neural retrieval model."""
+
+     def __init__(self, documents: list[Document], collection_name: str, vector_name: str, k: int = 4, splade_model_id: str = "naver/splade-cocondenser-ensembledistil"):
+         self.validate_environment_variables()
+         self.client = QdrantClient(url=os.getenv("QDRANT_URL"),
+                                    api_key=os.getenv("QDRANT_API_KEY"))
+         self.model_id = splade_model_id
+         self.tokenizer, self.model = self.set_tokenizer_config()
+         self.collection_name = collection_name
+         self.vector_name = vector_name
+         self.k = k
+         self.sparse_retriever = self.create_sparse_retriever()
+         self.add_documents(documents)
+
+     def validate_environment_variables(self):
+         required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
+         for var in required_vars:
+             if not os.getenv(var):
+                 raise EnvironmentError(f"Missing environment variable: {var}")
+
+     @cache
+     def set_tokenizer_config(self):
+         """Initialize the tokenizer and the SPLADE neural retrieval model.
+         See https://huggingface.co/naver/splade-cocondenser-ensembledistil for more details.
+         """
+         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+         model = AutoModelForMaskedLM.from_pretrained(self.model_id)
+         return tokenizer, model
+
+     @cache
+     def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
+         """Encode the input text into a sparse vector, as required by the QdrantSparseVectorRetriever.
+         Adapted from the "Computing the Sparse Vector" example in the Qdrant documentation.
+
+         Args:
+             text (str): Text to encode.
+
+         Returns:
+             tuple[list[int], list[float]]: Indices and values of the sparse vector.
+         """
+         tokens = self.tokenizer(
+             text, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
+         with torch.no_grad():
+             output = self.model(**tokens)
+         logits, attention_mask = output.logits, tokens.attention_mask
+         # SPLADE weighting: log-saturated ReLU of the MLM logits, masked and max-pooled over tokens
+         relu_log = torch.log(1 + torch.relu(logits))
+         weighted_log = relu_log * attention_mask.unsqueeze(-1)
+         max_val, _ = torch.max(weighted_log, dim=1)
+         vec = max_val.squeeze()
+
+         indices = vec.nonzero().numpy().flatten()
+         values = vec.detach().numpy()[indices]
+
+         return indices.tolist(), values.tolist()
+
+     def create_sparse_retriever(self):
+         self.client.recreate_collection(
+             self.collection_name,
+             vectors_config={},
+             sparse_vectors_config={
+                 self.vector_name: models.SparseVectorParams(
+                     index=models.SparseIndexParams(
+                         on_disk=False,
+                     )
+                 )
+             },
+         )
+
+         return QdrantSparseVectorRetriever(
+             client=self.client,
+             collection_name=self.collection_name,
+             sparse_vector_name=self.vector_name,
+             sparse_encoder=self.sparse_encoder,
+             k=self.k,
+         )
+
+     def add_documents(self, documents):
+         self.sparse_retriever.add_documents(documents)
+
+
+ def main():
+     data_dir = Path().resolve().parent / "data"
+     if not data_dir.exists():
+         sys.exit(f"The directory {data_dir} does not exist.")
+
+     processor = DataProcessor(data_dir)
+
+     print("Loading and cleaning Practitioners data...")
+     practitioners_dataset = processor.load_practitioners_data()
+
+     print("Loading Tall Tree data from JSON file...")
+     tall_tree_dataset = processor.load_tall_tree_data()
+
+     # Set the OpenAI embeddings model
+     # TODO: Test the new embeddings model text-embedding-3-small
+     embeddings_model = "text-embedding-ada-002"
+     openai_embeddings = OpenAIEmbeddings(model=embeddings_model)
+
+     # Store both datasets in Qdrant
+     print(f"Storing dense vectors in Qdrant using {embeddings_model}...")
+     practitioners_db = DenseVectorStore(practitioners_dataset,
+                                         openai_embeddings,
+                                         collection_name="practitioners_db").get_db()
+
+     tall_tree_db = DenseVectorStore(tall_tree_dataset,
+                                     openai_embeddings,
+                                     collection_name="tall_tree_db").get_db()
+
+     print("Storing sparse vectors in Qdrant using the SPLADE neural retrieval model...")
+     practitioners_sparse_vector_db = SparseVectorStore(
+         documents=practitioners_dataset,
+         collection_name="practitioners_db_sparse_collection",
+         vector_name="sparse_vector",
+         k=15,
+         splade_model_id="naver/splade-cocondenser-ensembledistil",
+     )
+
+
+ if __name__ == "__main__":
+     main()
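
For reference, a minimal retrieval sketch against the collections this script creates. Collection names, the vector name, and the SPLADE weighting come from the code above; the query strings and the standalone splade_encode helper are illustrative assumptions, and QDRANT_URL, QDRANT_API_KEY, and OPENAI_API_KEY must be set:

import os

import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai.embeddings import OpenAIEmbeddings
from qdrant_client import QdrantClient
from transformers import AutoModelForMaskedLM, AutoTokenizer

client = QdrantClient(url=os.getenv("QDRANT_URL"),
                      api_key=os.getenv("QDRANT_API_KEY"))

# Dense similarity search over the practitioners collection written by main()
dense_db = Qdrant(client=client,
                  collection_name="practitioners_db",
                  embeddings=OpenAIEmbeddings(model="text-embedding-ada-002"))
print(dense_db.similarity_search("registered massage therapist in Victoria", k=4))

# Sparse search must reuse the same SPLADE encoding that was used at index time
model_id = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)


def splade_encode(text: str) -> tuple[list[int], list[float]]:
    # Same computation as SparseVectorStore.sparse_encoder
    tokens = tokenizer(text, return_tensors="pt",
                       max_length=512, padding="max_length", truncation=True)
    with torch.no_grad():
        logits = model(**tokens).logits
    weights = torch.log(1 + torch.relu(logits)) * tokens.attention_mask.unsqueeze(-1)
    vec = torch.max(weights, dim=1).values.squeeze()
    indices = vec.nonzero().numpy().flatten()
    return indices.tolist(), vec.numpy()[indices].tolist()


sparse_retriever = QdrantSparseVectorRetriever(
    client=client,
    collection_name="practitioners_db_sparse_collection",
    sparse_vector_name="sparse_vector",
    sparse_encoder=splade_encode,
    k=15,
)
print(sparse_retriever.get_relevant_documents("acupuncture for lower back pain"))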