Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 8,658 Bytes
9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 9e95b48 93d3140 f2a1c22 93d3140 f2a1c22 93d3140 9e95b48 93d3140 9e95b48 fcae4fc 93d3140 9e95b48 93d3140 9e95b48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
import json
import os
import sys
from functools import cache
from pathlib import Path
import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from qdrant_client import QdrantClient, models
from transformers import AutoModelForMaskedLM, AutoTokenizer
from data_processing import excel_to_dataframe
class DataProcessor:
    """Load the raw source datasets from ``data_dir`` as LangChain ``Document`` objects."""

    def __init__(self, data_dir: Path):
        """Remember where the source files live.

        Args:
            data_dir: Directory handed to ``excel_to_dataframe`` for the
                practitioners data and scanned for the Tall Tree ``.json`` file.
        """
        self.data_dir = data_dir

    def load_practitioners_data(self):
        """Return one ``Document`` per row of the practitioners spreadsheet.

        Each row's ``page_content`` is its ``"column: value"`` pairs joined with
        ``". "``; ``metadata["row"]`` holds the row's DataFrame index label.

        Exits the process with a message if ``excel_to_dataframe`` raises
        ``FileNotFoundError``.
        """
        # Only the spreadsheet load can fail with a missing file; keep the
        # try body to that single call.
        try:
            df = excel_to_dataframe(self.data_dir)
        except FileNotFoundError:
            sys.exit(
                "Directory or Excel file not found. Please check the path and try again."
            )
        # ". " separates fields so the embedded text reads as short sentences.
        return [
            Document(
                page_content=". ".join(f"{key}: {value}" for key, value in row.items()),
                metadata={"row": idx},
            )
            for idx, row in df.iterrows()
        ]

    def load_tall_tree_data(self):
        """Return the single Tall Tree JSON file in ``data_dir`` as ``Document`` objects.

        Raises:
            FileNotFoundError: if ``data_dir`` contains no ``.json`` file.
            ValueError: if it contains more than one ``.json`` file, the file is
                not valid JSON, or its top-level value is not a JSON object.
        """
        # Regular files only: a directory named "*.json" must not be selected.
        json_files = [
            file
            for file in self.data_dir.iterdir()
            if file.suffix == ".json" and file.is_file()
        ]
        if not json_files:
            raise FileNotFoundError("No JSON files found in the specified directory.")
        if len(json_files) > 1:
            raise ValueError(
                "More than one JSON file found in the specified directory."
            )
        data = self.load_json_file(json_files[0])
        return self.process_json_data(data)

    def load_json_file(self, path):
        """Parse the UTF-8 JSON file at ``path`` and return the decoded value.

        Raises:
            ValueError: if the file is not valid JSON (the ``JSONDecodeError`` is
                chained as ``__cause__``) or is not valid UTF-8
                (``UnicodeDecodeError`` is itself a ``ValueError``).
        """
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError as err:
            raise ValueError(f"The file {path} is not a valid JSON file.") from err

    def process_json_data(self, data):
        """Turn each top-level ``key: value`` pair of ``data`` into a ``Document``.

        ``metadata["row"]`` is the pair's position in ``data``'s iteration order.

        Raises:
            ValueError: if ``data`` is not a dict (i.e. the JSON root was not an object).
        """
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a JSON object at the top level, got {type(data).__name__}."
            )
        return [
            Document(page_content=f"{key}: {value}", metadata={"row": idx})
            for idx, (key, value) in enumerate(data.items())
        ]
class ValidateQdrantClient:
    """Base class that refuses to construct unless Qdrant credentials are configured.

    Subclasses call ``super().__init__()`` so that a missing ``QDRANT_URL`` or
    ``QDRANT_API_KEY`` is reported as soon as the object is created.
    """

    def __init__(self):
        self.validate_environment_variables()

    def validate_environment_variables(self):
        """Ensure ``QDRANT_API_KEY`` and ``QDRANT_URL`` are set to non-empty values.

        Raises:
            EnvironmentError: listing every variable that is unset or empty.
        """
        absent = [
            name
            for name in ("QDRANT_API_KEY", "QDRANT_URL")
            if not os.getenv(name)
        ]
        if not absent:
            return
        joined = ", ".join(absent)
        raise EnvironmentError(f"Missing environment variable(s): {joined}")
class DenseVectorStore(ValidateQdrantClient):
    """Embed documents with an OpenAI model and store them in a Qdrant collection.

    Nothing is embedded or uploaded until ``qdrant_db`` is first accessed.
    """

    TEXT_EMBEDDING_MODELS = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ]

    def __init__(
        self,
        documents: list[Document],
        embeddings_model: str = "text-embedding-3-small",
        collection_name: str = "practitioners_db",
    ):
        """Validate configuration and remember what to upload.

        Args:
            documents: Documents to embed and store.
            embeddings_model: One of ``TEXT_EMBEDDING_MODELS``.
            collection_name: Target Qdrant collection.

        Raises:
            EnvironmentError: if the Qdrant environment variables are missing.
            ValueError: if ``embeddings_model`` is not in ``TEXT_EMBEDDING_MODELS``.
        """
        super().__init__()
        if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
            allowed = ", ".join(self.TEXT_EMBEDDING_MODELS)
            raise ValueError(
                f"Invalid embeddings model: {embeddings_model}. Valid options are {allowed}."
            )
        self.documents = documents
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        # Populated on first access of `qdrant_db`.
        self._qdrant_db = None

    @property
    def qdrant_db(self):
        """Return the LangChain ``Qdrant`` store, building it on first access.

        The first call embeds ``self.documents`` with ``self.embeddings_model``
        and writes them to ``self.collection_name`` with ``force_recreate=True``.
        Later calls return the same object.
        """
        if self._qdrant_db is not None:
            return self._qdrant_db
        embedder = OpenAIEmbeddings(model=self.embeddings_model)
        self._qdrant_db = Qdrant.from_documents(
            self.documents,
            embedder,
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            prefer_grpc=True,
            collection_name=self.collection_name,
            force_recreate=True,
        )
        return self._qdrant_db
class SparseVectorStore(ValidateQdrantClient):
    """Store SPLADE sparse vectors for a set of documents in a Qdrant collection.

    Construction (re)creates ``collection_name`` with one sparse vector named
    ``vector_name`` and immediately uploads ``documents``. The resulting LangChain
    retriever is kept on ``self.sparse_retriever``.
    """

    def __init__(
        self,
        documents: list[Document],
        collection_name: str,
        vector_name: str,
        k: int = 4,
        splade_model_id: str = "naver/splade-cocondenser-ensembledistil",
    ):
        """Create the collection and upload ``documents``.

        Args:
            documents: Documents to encode and upload.
            collection_name: Qdrant collection to (re)create.
            vector_name: Name of the sparse vector inside the collection.
            k: Number of documents the retriever returns per query.
            splade_model_id: Hugging Face id for the SPLADE tokenizer and model.

        Raises:
            EnvironmentError: if ``QDRANT_API_KEY`` or ``QDRANT_URL`` is missing.
        """
        # Fail fast on missing credentials before building the client.
        super().__init__()
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )  # TODO: prefer_grpc=True is not working
        self.model_id = splade_model_id
        # Memo slots filled lazily by the `tokenizer` / `model` properties.
        self._tokenizer = None
        self._model = None
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        self.sparse_retriever = self.create_sparse_retriever()
        self.add_documents(documents)

    # No functools.cache on these properties: caching a method keys on `self`
    # and keeps every instance alive for the process lifetime, and the
    # `_tokenizer` / `_model` slots already memoize per instance.
    @property
    def tokenizer(self):
        """Tokenizer for ``self.model_id``, loaded on first access."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        return self._tokenizer

    @property
    def model(self):
        """SPLADE masked-LM model for ``self.model_id``, loaded on first access."""
        if self._model is None:
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode ``text`` as a SPLADE sparse vector.

        Returns:
            ``(indices, values)``: vocabulary ids with a non-zero weight and their
            weights, always as plain Python lists (even when exactly one term is
            active).
        """
        # NOTE(review): padding a single text to 512 tokens is not needed for
        # correctness (padded positions are masked out below) and costs compute.
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        with torch.no_grad():
            logits = self.model(**tokens).logits
        # Term weight = log(1 + ReLU(logit)), zeroed on padding positions, then
        # max-pooled over token positions (dim 1). Logits are assumed to be
        # (1, seq_len, vocab_size); squeeze(0) drops the batch dimension.
        relu_log = torch.log1p(torch.relu(logits))
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
        max_val = torch.max(weighted_log, dim=1).values.squeeze(0)
        # flatten() rather than squeeze(): with exactly one non-zero term,
        # squeeze() produced a 0-d array whose tolist() is a bare int, not a list.
        indices = torch.nonzero(max_val, as_tuple=False).flatten()
        values = max_val[indices]
        return indices.cpu().tolist(), values.cpu().tolist()

    def create_sparse_retriever(self):
        """(Re)create the sparse-only collection and return a retriever over it.

        The collection has no dense vectors and a single in-memory
        (``on_disk=False``) sparse index named ``self.vector_name``. The retriever
        encodes with ``self.sparse_encoder`` and returns ``self.k`` documents.
        """
        # NOTE(review): `recreate_collection` is deprecated in recent
        # qdrant-client releases in favour of `collection_exists` /
        # `delete_collection` / `create_collection`; confirm against the
        # pinned version before upgrading.
        self.client.recreate_collection(
            self.collection_name,
            vectors_config={},
            sparse_vectors_config={
                self.vector_name: models.SparseVectorParams(
                    index=models.SparseIndexParams(
                        on_disk=False,
                    )
                )
            },
        )
        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )

    def add_documents(self, documents):
        """Upload ``documents`` through the sparse retriever built in ``__init__``."""
        self.sparse_retriever.add_documents(documents)
def main():
    """Load both datasets and upload dense and sparse vectors to Qdrant.

    Reads the practitioners spreadsheet and the single Tall Tree ``.json`` file
    from ``data/``, stores dense OpenAI embeddings for both datasets, and stores
    SPLADE sparse vectors for the practitioners dataset.

    Exits with a message if the data directory does not exist.
    """
    # Resolve from this file's location rather than the current working
    # directory, so launching from another folder still finds the data.
    # Assumes `data/` is a sibling of the directory holding this script —
    # TODO confirm against the repository layout.
    data_dir = Path(__file__).resolve().parent.parent / "data"
    if not data_dir.exists():
        sys.exit(f"The directory {data_dir} does not exist.")

    processor = DataProcessor(data_dir)
    print("Loading and cleaning Practitioners data...")
    practitioners_dataset = processor.load_practitioners_data()
    print("Loading Tall Tree data from json file...")
    tall_tree_dataset = processor.load_tall_tree_data()

    # OpenAI embeddings model for the dense collections; any entry of
    # DenseVectorStore.TEXT_EMBEDDING_MODELS is accepted.
    EMBEDDINGS_MODEL = "text-embedding-3-small"

    # Accessing `.qdrant_db` is what triggers embedding and upload.
    print(f"Storing dense vectors in Qdrant using {EMBEDDINGS_MODEL}...")
    practitioners_db = DenseVectorStore(
        practitioners_dataset, EMBEDDINGS_MODEL, collection_name="practitioners_db"
    ).qdrant_db
    tall_tree_db = DenseVectorStore(
        tall_tree_dataset, EMBEDDINGS_MODEL, collection_name="tall_tree_db"
    ).qdrant_db

    # SparseVectorStore uploads its documents during construction.
    print("Storing sparse vectors in Qdrant using SPLADE neural retrieval model...")
    practitioners_sparse_vector_db = SparseVectorStore(
        documents=practitioners_dataset,
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        k=15,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
    )
if __name__ == "__main__":
    # Run the ingestion pipeline only when executed as a script, not on import.
    main()
|