# ai-virtual-assistant/utils/update_vector_database.py
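"""Build the Qdrant collections used by the AI virtual assistant.

Loads practitioner records from an Excel file and Tall Tree data from a JSON file,
then stores them in Qdrant as dense vectors (OpenAI embeddings) and, for the
practitioners data, as sparse vectors (SPLADE neural retrieval model).
"""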
import json
import os
import sys
from functools import cache
from pathlib import Path
import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from qdrant_client import QdrantClient, models
from transformers import AutoModelForMaskedLM, AutoTokenizer
from data_processing import excel_to_dataframe
class DataProcessor:
def __init__(self, data_dir: Path):
self.data_dir = data_dir
def load_practitioners_data(self):
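        """Wrap each row of the practitioners Excel data in a Document."""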
try:
df = excel_to_dataframe(self.data_dir)
practitioners_data = []
for idx, row in df.iterrows():
            # Join fields with '. ' so each row becomes a single passage for text embeddings
content = '. '.join(
f"{key}: {value}" for key, value in row.items())
doc = Document(page_content=content, metadata={'row': idx})
practitioners_data.append(doc)
return practitioners_data
except FileNotFoundError:
sys.exit(
"Directory or Excel file not found. Please check the path and try again.")
def load_tall_tree_data(self):
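        """Load the Tall Tree data from the single JSON file in the data directory."""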
        # Collect the JSON files in the data directory (exactly one is expected)
json_files = [file for file in self.data_dir.iterdir()
if file.suffix == '.json']
if not json_files:
raise FileNotFoundError(
"No JSON files found in the specified directory.")
if len(json_files) > 1:
raise ValueError(
"More than one JSON file found in the specified directory.")
path = json_files[0]
data = self.load_json_file(path)
tall_tree_data = self.process_json_data(data)
return tall_tree_data
def load_json_file(self, path):
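        """Read a JSON file, raising ValueError if it cannot be parsed."""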
try:
with open(path, 'r') as f:
data = json.load(f)
return data
except json.JSONDecodeError:
raise ValueError(f"The file {path} is not a valid JSON file.")
def process_json_data(self, data):
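        """Turn each key/value pair of the JSON data into a Document."""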
tall_tree_data = []
for idx, (key, value) in enumerate(data.items()):
content = f"{key}: {value}"
doc = Document(page_content=content, metadata={'row': idx})
tall_tree_data.append(doc)
return tall_tree_data
class DenseVectorStore:
"""Store dense data in Qdrant vector database."""
    def __init__(self,
                 documents: list[Document],
                 embeddings: OpenAIEmbeddings,
                 collection_name: str = 'practitioners_db'):
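        """Embed the documents and (re)create the dense collection in Qdrant."""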
self.validate_environment_variables()
self.qdrant_db = Qdrant.from_documents(
documents,
embeddings,
url=os.getenv("QDRANT_URL"),
prefer_grpc=True,
api_key=os.getenv(
"QDRANT_API_KEY"),
collection_name=collection_name,
force_recreate=True)
def validate_environment_variables(self):
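        """Raise EnvironmentError if the required Qdrant credentials are not set."""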
required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
for var in required_vars:
if not os.getenv(var):
raise EnvironmentError(f"Missing environment variable: {var}")
def get_db(self):
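        """Return the underlying Qdrant vector store."""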
return self.qdrant_db
class SparseVectorStore:
"""Store sparse vectors in Qdrant vector database using SPLADE neural retrieval model."""
    def __init__(self,
                 documents: list[Document],
                 collection_name: str,
                 vector_name: str,
                 k: int = 4,
                 splade_model_id: str = "naver/splade-cocondenser-ensembledistil"):
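        """Build the SPLADE-backed sparse retriever and index the documents."""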
self.validate_environment_variables()
self.client = QdrantClient(url=os.getenv(
"QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
self.model_id = splade_model_id
self.tokenizer, self.model = self.set_tokenizer_config()
self.collection_name = collection_name
self.vector_name = vector_name
self.k = k
self.sparse_retriever = self.create_sparse_retriever()
self.add_documents(documents)
def validate_environment_variables(self):
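        """Raise EnvironmentError if the required Qdrant credentials are not set."""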
required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
for var in required_vars:
if not os.getenv(var):
raise EnvironmentError(f"Missing environment variable: {var}")
@cache
def set_tokenizer_config(self):
"""Initialize the tokenizer and the SPLADE neural retrieval model.
See to https://huggingface.co./naver/splade-cocondenser-ensembledistil for more details.
"""
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
model = AutoModelForMaskedLM.from_pretrained(self.model_id)
return tokenizer, model
@cache
def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
"""This function encodes the input text into a sparse vector. The sparse_encoder is required for the QdrantSparseVectorRetriever.
Adapted from the Qdrant documentation: Computing the Sparse Vector code.
Args:
text (str): Text to encode
Returns:
tuple[list[int], list[float]]: Indices and values of the sparse vector
"""
tokens = self.tokenizer(
text, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
output = self.model(**tokens)
logits, attention_mask = output.logits, tokens.attention_mask
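        # SPLADE term weighting: log-saturated ReLU of the MLM logits, masked by the
        # attention mask and max-pooled over the sequence dimension.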
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vec = max_val.squeeze()
indices = vec.nonzero().numpy().flatten()
values = vec.detach().numpy()[indices]
return indices.tolist(), values.tolist()
def create_sparse_retriever(self):
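        """Recreate the sparse collection and return a QdrantSparseVectorRetriever."""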
self.client.recreate_collection(
self.collection_name,
vectors_config={},
sparse_vectors_config={
self.vector_name: models.SparseVectorParams(
index=models.SparseIndexParams(
on_disk=False,
)
)
},
)
return QdrantSparseVectorRetriever(
client=self.client,
collection_name=self.collection_name,
sparse_vector_name=self.vector_name,
sparse_encoder=self.sparse_encoder,
k=self.k,
)
def add_documents(self, documents):
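        """Encode the documents with SPLADE and add them to the sparse collection."""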
self.sparse_retriever.add_documents(documents)
def main():
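    """Load the datasets and rebuild the dense and sparse collections in Qdrant."""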
data_dir = Path().resolve().parent / "data"
if not data_dir.exists():
sys.exit(f"The directory {data_dir} does not exist.")
processor = DataProcessor(data_dir)
print("Loading and cleaning Practitioners data...")
practitioners_dataset = processor.load_practitioners_data()
print("Loading Tall Tree data from json file...")
tall_tree_dataset = processor.load_tall_tree_data()
# Set OpenAI embeddings model
# TODO: Test new OpenAI text embeddings models
embeddings_model = "text-embedding-ada-002"
openai_embeddings = OpenAIEmbeddings(model=embeddings_model)
# Store both datasets in Qdrant
print(f"Storing dense vectors in Qdrant using {embeddings_model}...")
practitioners_db = DenseVectorStore(practitioners_dataset,
openai_embeddings,
collection_name="practitioners_db").get_db()
tall_tree_db = DenseVectorStore(tall_tree_dataset,
openai_embeddings,
collection_name="tall_tree_db").get_db()
print(f"Storing sparse vectors in Qdrant using SPLADE neural retrieval model...")
practitioners_sparse_vector_db = SparseVectorStore(
documents=practitioners_dataset,
collection_name="practitioners_db_sparse_collection",
vector_name="sparse_vector",
k=15,
splade_model_id="naver/splade-cocondenser-ensembledistil",
)
if __name__ == "__main__":
main()
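
# A minimal usage sketch (not executed): building the stores as in main() and then
# querying them. It reuses the classes defined above; `similarity_search` and
# `get_relevant_documents` are standard LangChain calls, and the query strings are
# only illustrations.
#
#     processor = DataProcessor(Path().resolve().parent / "data")
#     docs = processor.load_practitioners_data()
#     dense_db = DenseVectorStore(docs, OpenAIEmbeddings(model="text-embedding-ada-002")).get_db()
#     dense_db.similarity_search("massage therapy for chronic back pain", k=4)
#
#     sparse_store = SparseVectorStore(docs, "practitioners_db_sparse_collection",
#                                      "sparse_vector", k=15)
#     sparse_store.sparse_retriever.get_relevant_documents("massage therapy for chronic back pain")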