ai-virtual-assistant / utils /update_vector_database.py
yrobel-lima's picture
Upload 2 files
93d3140 verified
raw
history blame
No virus
8.66 kB
import json
import os
import sys
from functools import cache
from pathlib import Path
import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from qdrant_client import QdrantClient, models
from transformers import AutoModelForMaskedLM, AutoTokenizer
from data_processing import excel_to_dataframe
class DataProcessor:
    """Load the assistant's source data from ``data_dir`` as LangChain ``Document`` objects.

    Two sources are read: a practitioners Excel workbook (via
    ``excel_to_dataframe``) and exactly one JSON file of Tall Tree data.
    """

    def __init__(self, data_dir: Path):
        # Directory that holds both the Excel workbook and the JSON file.
        self.data_dir = data_dir

    def load_practitioners_data(self):
        """Return one ``Document`` per row of the practitioners spreadsheet.

        Each row is flattened to ``"column: value. column: value..."`` text and
        the DataFrame index is stored as ``metadata["row"]``.

        Calls ``sys.exit`` (raising ``SystemExit``) when ``excel_to_dataframe``
        reports a missing directory or Excel file.
        """
        # Only the loader can raise FileNotFoundError; keep the try minimal.
        try:
            df = excel_to_dataframe(self.data_dir)
        except FileNotFoundError:
            sys.exit(
                "Directory or Excel file not found. Please check the path and try again."
            )
        # ". " separates the column/value pairs so each reads as its own
        # sentence for the text embedder.
        return [
            Document(
                page_content=". ".join(f"{key}: {value}" for key, value in row.items()),
                metadata={"row": idx},
            )
            for idx, row in df.iterrows()
        ]

    def load_tall_tree_data(self):
        """Return ``Document`` objects built from the single JSON file in ``data_dir``.

        Raises:
            FileNotFoundError: if ``data_dir`` contains no ``.json`` file.
            ValueError: if it contains more than one ``.json`` file, or the
                file is not valid UTF-8 JSON.
        """
        # Only regular files count: a sub-directory named "*.json" would
        # otherwise be selected and then fail inside open().
        json_files = [
            file
            for file in self.data_dir.iterdir()
            if file.suffix == ".json" and file.is_file()
        ]
        if not json_files:
            raise FileNotFoundError("No JSON files found in the specified directory.")
        if len(json_files) > 1:
            raise ValueError(
                "More than one JSON file found in the specified directory."
            )
        data = self.load_json_file(json_files[0])
        return self.process_json_data(data)

    def load_json_file(self, path):
        """Parse the file at ``path`` as UTF-8 JSON and return the decoded object.

        Raises:
            ValueError: if the content is not valid JSON or not valid UTF-8.
        """
        # JSON exchanged between systems must be UTF-8 (RFC 8259); naming the
        # encoding avoids depending on the platform's locale default.
        with open(path, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as err:
                raise ValueError(f"The file {path} is not a valid JSON file.") from err

    def process_json_data(self, data):
        """Turn each top-level ``key: value`` pair of mapping ``data`` into a ``Document``.

        The pair's 0-based position is stored as ``metadata["row"]``.
        """
        return [
            Document(page_content=f"{key}: {value}", metadata={"row": idx})
            for idx, (key, value) in enumerate(data.items())
        ]
class ValidateQdrantClient:
    """Base class for Qdrant-backed stores that checks credentials up front.

    Instantiating it (directly or through ``super().__init__()``) raises
    ``EnvironmentError`` unless both ``QDRANT_API_KEY`` and ``QDRANT_URL``
    are set to non-empty values.
    """

    def __init__(self):
        self.validate_environment_variables()

    def validate_environment_variables(self):
        """Raise ``EnvironmentError`` naming every unset or empty Qdrant variable."""
        missing_vars = []
        for var in ("QDRANT_API_KEY", "QDRANT_URL"):
            # An empty string is treated the same as an unset variable.
            if not os.getenv(var):
                missing_vars.append(var)
        if not missing_vars:
            return
        raise EnvironmentError(
            f"Missing environment variable(s): {', '.join(missing_vars)}"
        )
class DenseVectorStore(ValidateQdrantClient):
    """Store dense data in Qdrant vector database.

    Documents are embedded with an OpenAI embeddings model and uploaded to
    ``collection_name``. Nothing is sent to Qdrant until :attr:`qdrant_db`
    is first read.
    """

    # OpenAI embeddings models accepted by ``embeddings_model``.
    TEXT_EMBEDDING_MODELS = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ]

    def __init__(
        self,
        documents: list[Document],
        embeddings_model: str = "text-embedding-3-small",
        collection_name: str = "practitioners_db",
    ):
        # Checks QDRANT_API_KEY / QDRANT_URL before anything else.
        super().__init__()
        is_supported = embeddings_model in self.TEXT_EMBEDDING_MODELS
        if not is_supported:
            raise ValueError(
                f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
            )
        self.documents = documents
        self.embeddings_model = embeddings_model
        self.collection_name = collection_name
        # Filled on first access of ``qdrant_db``.
        self._qdrant_db = None

    @property
    def qdrant_db(self):
        """Return the Qdrant store, embedding and uploading ``documents`` on first access.

        ``force_recreate=True`` is passed, so the collection is recreated even
        if it already exists. Later accesses return the same store object.
        """
        if self._qdrant_db is not None:
            return self._qdrant_db
        embedder = OpenAIEmbeddings(model=self.embeddings_model)
        self._qdrant_db = Qdrant.from_documents(
            self.documents,
            embedder,
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            prefer_grpc=True,
            collection_name=self.collection_name,
            force_recreate=True,
        )
        return self._qdrant_db
class SparseVectorStore(ValidateQdrantClient):
    """Store sparse vectors in Qdrant vector database using SPLADE neural retrieval model.

    Constructing an instance does three things:

    - checks the Qdrant environment variables;
    - recreates ``collection_name`` with a single sparse-vector field named
      ``vector_name``;
    - uploads ``documents`` through a ``QdrantSparseVectorRetriever`` that
      encodes text with :meth:`sparse_encoder`.

    The retriever is kept in ``sparse_retriever``.
    """

    def __init__(
        self,
        documents: list[Document],
        collection_name: str,
        vector_name: str,
        k: int = 4,
        splade_model_id: str = "naver/splade-cocondenser-ensembledistil",
    ):
        """
        Args:
            documents: Documents to encode and upload immediately.
            collection_name: Qdrant collection to recreate.
            vector_name: Name of the sparse-vector field in that collection.
            k: ``k`` passed to the retriever (results returned per query).
            splade_model_id: Hugging Face model id for the tokenizer and model.
        """
        # Validate Qdrant client environment variables before connecting.
        super().__init__()
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )  # TODO: prefer_grpc=True is not working
        self.model_id = splade_model_id
        # Loaded lazily by the ``tokenizer`` / ``model`` properties.
        self._tokenizer = None
        self._model = None
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        self.sparse_retriever = self.create_sparse_retriever()
        self.add_documents(documents)

    # The ``_tokenizer`` / ``_model`` attributes already memoize per instance.
    # A functools.cache stacked under @property would additionally keep every
    # instance (and its loaded model) alive in a module-global cache.
    @property
    def tokenizer(self):
        """Return the SPLADE tokenizer, loading it on first use."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        return self._tokenizer

    @property
    def model(self):
        """Return the SPLADE masked-LM model, loading it on first use."""
        if self._model is None:
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode ``text`` into SPLADE sparse-vector ``(indices, values)`` lists.

        Each vocabulary weight is ``log(1 + relu(logit))``, max-pooled over the
        token positions with padding masked out. Only non-zero entries are
        returned; both lists are empty if nothing activates.
        """
        # NOTE(review): padding a single string to 512 tokens is not needed for
        # correctness (the attention mask below neutralises padding). Consider
        # padding=False for speed, after confirming the outputs match.
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        with torch.no_grad():
            logits = self.model(**tokens).logits
        relu_log = torch.log1p(torch.relu(logits))
        # Zero out padded positions so they cannot win the max-pool.
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
        # One input string -> batch axis of size 1; drop only that axis.
        max_val = torch.max(weighted_log, dim=1).values.squeeze(0)
        # flatten() always yields 1-D. A bare squeeze() turned a single
        # non-zero hit into a 0-d tensor, so tolist() returned a scalar.
        indices = torch.nonzero(max_val, as_tuple=False).flatten()
        values = max_val[indices]
        return indices.cpu().tolist(), values.cpu().tolist()

    def create_sparse_retriever(self):
        """Recreate the sparse-only collection and return a retriever bound to it.

        The collection has no dense vectors (``vectors_config={}``). It has one
        sparse index, named ``vector_name``, kept in memory (``on_disk=False``).
        """
        self.client.recreate_collection(
            self.collection_name,
            vectors_config={},
            sparse_vectors_config={
                self.vector_name: models.SparseVectorParams(
                    index=models.SparseIndexParams(
                        on_disk=False,
                    )
                )
            },
        )
        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )

    def add_documents(self, documents):
        """Upload ``documents`` via ``sparse_retriever``, which encodes them with :meth:`sparse_encoder`."""
        self.sparse_retriever.add_documents(documents)
def main():
    """Rebuild the Qdrant collections from the project's ``data`` directory.

    Steps:
        1. Load the practitioners spreadsheet and the Tall Tree JSON file.
        2. Store both as dense OpenAI embeddings (``practitioners_db`` and
           ``tall_tree_db``).
        3. Store the practitioners data again as SPLADE sparse vectors
           (``practitioners_db_sparse_collection``).

    Exits via ``sys.exit`` if the ``data`` directory does not exist.
    """
    # Resolve data/ relative to this file (utils/ -> project root -> data/),
    # not the current working directory. The old cwd-based path only worked
    # when the script was launched from inside utils/.
    data_dir = Path(__file__).resolve().parents[1] / "data"
    if not data_dir.exists():
        sys.exit(f"The directory {data_dir} does not exist.")

    processor = DataProcessor(data_dir)
    print("Loading and cleaning Practitioners data...")
    practitioners_dataset = processor.load_practitioners_data()
    print("Loading Tall Tree data from json file...")
    tall_tree_dataset = processor.load_tall_tree_data()

    # Set OpenAI embeddings model
    # TODO: Test new OpenAI text embeddings models
    # text-embedding-3-large
    # text-embedding-3-small
    EMBEDDINGS_MODEL = "text-embedding-3-small"

    # Store both datasets in Qdrant. Reading ``.qdrant_db`` is what triggers
    # the embedding and upload.
    print(f"Storing dense vectors in Qdrant using {EMBEDDINGS_MODEL}...")
    practitioners_db = DenseVectorStore(
        practitioners_dataset, EMBEDDINGS_MODEL, collection_name="practitioners_db"
    ).qdrant_db
    tall_tree_db = DenseVectorStore(
        tall_tree_dataset, EMBEDDINGS_MODEL, collection_name="tall_tree_db"
    ).qdrant_db

    # The SparseVectorStore constructor itself creates the collection and
    # uploads the documents.
    print("Storing sparse vectors in Qdrant using SPLADE neural retrieval model...")
    practitioners_sparse_vector_db = SparseVectorStore(
        documents=practitioners_dataset,
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        k=15,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
    )
if __name__ == "__main__":
    # Run the rebuild only when executed as a script, never on import.
    main()