import json
import logging
from typing import List, Optional

from dotenv import load_dotenv
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.bm25 import BM25Retriever
from langchain_core.documents.base import Document
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

# Load variables from a local .env file into the environment before any client
# is constructed (presumably the Pinecone/OpenAI API keys -- confirm).
load_dotenv()

logger = logging.getLogger(__name__)


class DB:
    """Document store backed by a Pinecone index plus an optional BM25 retriever.

    Each chunk of a file is stored in Pinecone under the id
    ``"<file_name_without_extension>_<n>"`` (``n`` starting at 1); ``remove``
    relies on that id scheme to find the chunks again. The BM25 keyword
    retriever lives in memory only and is rebuilt from the most recently
    inserted file's documents on every ``insert``.
    """

    def __init__(self, index_name: str, embeddings: Embeddings) -> None:
        self.embeddings = embeddings
        # Pinecone() takes no explicit credentials here, so it must find its
        # configuration in the environment (see load_dotenv above).
        self.pc_client = Pinecone()
        self.pc_index = self.pc_client.Index(name=index_name)
        self.vector_store = PineconeVectorStore(
            index_name=index_name, embedding=self.embeddings
        )
        self.index_name = index_name
        self.keyword_retriever: Optional[BM25Retriever] = None
        # Name of the file the in-memory BM25 retriever was built from, so that
        # removing that file can also drop its keyword index.
        self._keyword_source: Optional[str] = None

    @staticmethod
    def _is_chunk_id(vector_id: str, file_name_without_extension: str) -> bool:
        """Return True if *vector_id* is exactly ``"<file>_<digits>"`` as written by insert."""
        prefix = f"{file_name_without_extension}_"
        return vector_id.startswith(prefix) and vector_id[len(prefix):].isdigit()

    def insert(self, docs: List[Document], file_name_without_extension: str):
        """Store *docs* in Pinecone and rebuild the keyword retriever from them.

        Args:
            docs: The chunks of one file, in order.
            file_name_without_extension: Used to build the vector ids
                ``"<name>_1"`` ... ``"<name>_<len(docs)>"``.

        Returns:
            ``"Successfully inserted"`` on success, otherwise an error-message
            string. Exceptions are caught and reported, never raised.
        """
        ids = [f"{file_name_without_extension}_{n}" for n in range(1, len(docs) + 1)]
        try:
            # Add through the existing store rather than the from_documents
            # classmethod, which would construct a throwaway client/store.
            self.vector_store.add_documents(documents=docs, ids=ids)
            # The keyword retriever only ever covers the latest inserted file.
            self.keyword_retriever = BM25Retriever.from_documents(docs)
            self._keyword_source = file_name_without_extension
            return "Successfully inserted"
        except Exception as e:
            logger.exception("Unable to add %s to DB", file_name_without_extension)
            return f"Unable to add documents to DB: {e}"

    def remove(self, file_name_without_extension: str):
        """Delete every vector that ``insert`` stored for *file_name_without_extension*.

        Only ids of the exact form ``"<name>_<digits>"`` are deleted, so removing
        ``"report"`` does not also delete ``"report_v2_*"`` or ``"reports_*"``
        (and an empty name cannot wipe the whole index). If the in-memory BM25
        retriever was built from this file it is dropped too, so the deleted
        chunks can no longer be returned by the hybrid retriever.

        Returns:
            ``"Successfully deleted"`` on success (including when nothing
            matched), otherwise an error-message string. Exceptions are caught
            and reported, never raised.
        """
        try:
            logger.info("Removing %s from index %s", file_name_without_extension, self.index_name)
            # Index.list yields pages of ids. Gather (and filter) all pages
            # first instead of deleting while the listing is still paginating.
            pages = [
                [vid for vid in page if self._is_chunk_id(vid, file_name_without_extension)]
                for page in self.pc_index.list(prefix=f"{file_name_without_extension}_")
            ]
            for ids in pages:
                if not ids:
                    continue
                logger.debug("Deleting ids %s", ids)
                self.pc_index.delete(ids=ids)
            if self._keyword_source == file_name_without_extension:
                self.keyword_retriever = None
                self._keyword_source = None
            return "Successfully deleted"
        except Exception as e:
            logger.exception("Unable to delete %s from DB", file_name_without_extension)
            return f"Unable to delete {file_name_without_extension} from DB: {e}"

    def get_reriever(self, k=5):
        """Return a retriever over the index, hybrid with BM25 when available.

        Args:
            k: Number of documents each underlying retriever returns.

        Returns:
            An ``EnsembleRetriever`` weighting the Pinecone retriever 0.7 and the
            BM25 retriever 0.3 when a keyword retriever exists; otherwise the
            plain Pinecone retriever.
        """
        # self.vector_store already targets this index with these embeddings;
        # no need to rebuild it via from_existing_index on every call.
        vector_retriever = self.vector_store.as_retriever(search_kwargs={"k": k})
        if self.keyword_retriever is None:
            return vector_retriever
        # Apply k to the keyword side too; BM25Retriever otherwise keeps its own default.
        self.keyword_retriever.k = k
        return EnsembleRetriever(
            retrievers=[vector_retriever, self.keyword_retriever],
            weights=[0.7, 0.3],
        )

    # Correctly spelled alias; get_reriever is kept for existing callers.
    get_retriever = get_reriever