umairahmad89
initial commit
67a91b0
import json
import logging
from typing import List

from dotenv import load_dotenv
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.bm25 import BM25Retriever
from langchain_core.documents.base import Document
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
# Load variables from a local .env file into the process environment at import
# time. Presumably this supplies PINECONE_API_KEY / OPENAI_API_KEY, since the
# clients below are constructed without explicit keys — confirm.
load_dotenv()
class DB:
    """Pinecone-backed document store with an optional in-memory BM25 index.

    ``insert`` embeds a file's documents into the Pinecone index under the ids
    ``"<file>_1" .. "<file>_N"`` and builds a BM25 keyword retriever over that
    same batch (replacing any previous one). ``get_reriever`` returns a hybrid
    vector + keyword retriever when a BM25 retriever exists, otherwise a plain
    vector retriever.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, index_name: str, embeddings: Embeddings) -> None:
        """Connect to an existing Pinecone index.

        Args:
            index_name: Name of the Pinecone index to read from and write to.
            embeddings: Embedding model used for both inserts and queries.
        """
        self.embeddings = embeddings
        # No api_key is passed, so the client relies on environment
        # configuration (presumably PINECONE_API_KEY) — confirm.
        self.pc_client = Pinecone()
        # Raw index handle, used for id listing / deletion in ``remove``.
        self.pc_index = self.pc_client.Index(name=index_name)
        # LangChain wrapper over the same index, used for inserts and search.
        self.vector_store = PineconeVectorStore(
            index_name=index_name, embedding=self.embeddings
        )
        self.index_name = index_name
        # BM25 retriever over the most recently inserted file (None until the
        # first successful insert) and the file name it was built from, so
        # ``remove`` can drop it when that file is deleted.
        self.keyword_retriever = None
        self._keyword_source = None

    def insert(self, docs: List[Document], file_name_without_extension: str):
        """Embed ``docs`` into the index and rebuild the keyword retriever.

        Vector ids are ``f"{file_name_without_extension}_{i}"`` for
        ``i = 1..len(docs)``; ``remove`` relies on exactly this scheme.

        Returns:
            ``"Successfully inserted"`` on success, otherwise
            ``"Unable to add documents to DB: <error>"``.
        """
        try:
            # Build BM25 first: it is purely in-memory and can fail (e.g. on
            # an empty list), and failing here leaves Pinecone untouched.
            keyword_retriever = BM25Retriever.from_documents(docs)
            ids = [
                f"{file_name_without_extension}_{idx}"
                for idx in range(1, len(docs) + 1)
            ]
            # Write through the existing store rather than calling the
            # ``from_documents`` classmethod, which builds a throwaway store.
            self.vector_store.add_documents(docs, ids=ids)
        except Exception as e:
            self._logger.exception(
                "Failed to insert documents for %s", file_name_without_extension
            )
            return f"Unable to add documents to DB: {e}"
        # Swap in the new keyword retriever only after the vectors are stored,
        # so it never covers documents that are missing from Pinecone.
        self.keyword_retriever = keyword_retriever
        self._keyword_source = file_name_without_extension
        return "Successfully inserted"

    def remove(self, file_name_without_extension: str):
        """Delete every vector that ``insert`` stored for this file.

        Only ids of the exact form ``f"{file_name_without_extension}_<n>"``
        are deleted, so removing ``"report"`` leaves ``"report2_1"`` and
        ``"report_v2_1"`` untouched (a bare prefix match would delete them).

        Returns:
            ``"Successfully deleted"`` on success (including when no ids
            matched), otherwise ``"Unable to delete <file> from DB: <error>"``.
        """
        prefix = f"{file_name_without_extension}_"
        try:
            # ``Index.list`` yields pages (lists) of ids starting with prefix;
            # keep only those whose remainder is the numeric chunk counter.
            for page in self.pc_index.list(prefix=prefix):
                ids = [vid for vid in page if vid[len(prefix):].isdigit()]
                if ids:
                    self.pc_index.delete(ids=ids)
        except Exception as e:
            self._logger.exception(
                "Failed to delete documents for %s", file_name_without_extension
            )
            return f"Unable to delete {file_name_without_extension} from DB: {e}"
        # Drop the keyword retriever if it was built from the removed file;
        # otherwise hybrid search would keep returning deleted documents.
        if self._keyword_source == file_name_without_extension:
            self.keyword_retriever = None
            self._keyword_source = None
        return "Successfully deleted"

    def get_reriever(self, k=5):
        """Return a retriever over the index, hybrid when BM25 is available.

        Args:
            k: Number of documents the vector retriever returns per query.

        Returns:
            An ``EnsembleRetriever`` weighting vector results 0.7 and BM25
            results 0.3 when a keyword retriever exists, otherwise the plain
            vector-store retriever.
        """
        # Reuse the store built in __init__ instead of rebuilding an
        # equivalent wrapper with ``from_existing_index`` on every call.
        vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": k})
        if self.keyword_retriever is None:
            return vector_retriever
        # NOTE(review): ``k`` is not applied to the BM25 retriever, which keeps
        # its own result count — confirm whether that asymmetry is intended.
        return EnsembleRetriever(
            retrievers=[vector_retriever, self.keyword_retriever],
            weights=[0.7, 0.3],
        )

    # Correctly spelled alias; ``get_reriever`` is kept for existing callers.
    get_retriever = get_reriever