hr_chatbot / retriever.py
Syed Junaid Iqbal
Upload 5 files
030d46c
raw
history blame
1.18 kB
import pickle
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoModel
def retriever() -> EnsembleRetriever:
    """Build the hybrid retriever used by the chatbot.

    Combines two retrievers with equal weight:

    * a keyword retriever unpickled from ``./bm25`` (its ``k`` is set to 2);
    * a dense retriever backed by the FAISS index stored in ``./Vector_DB/``
      (``k`` = 1), queried with ``jinaai/jina-embeddings-v2-base-en``
      embeddings computed on the CPU.

    Returns:
        An ``EnsembleRetriever`` over ``[bm25_retriever, faiss_retriever]``
        with weights ``[0.5, 0.5]``.

    Raises:
        FileNotFoundError: If ``./bm25`` does not exist.
    """
    # Define the embedding model. HuggingFaceEmbeddings loads the model
    # itself from `model_name`, so no separate AutoModel.from_pretrained()
    # call is needed here -- loading it a second time only costs start-up
    # time and memory, because the resulting object was never used.
    model_name = "jinaai/jina-embeddings-v2-base-en"
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    # Load the pre-built keyword retriever.
    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # Only ship a './bm25' file that was produced by this project; never
    # point this at user-supplied or downloaded data.
    with open('./bm25', 'rb') as file:
        bm25_retriever = pickle.load(file)
    bm25_retriever.k = 2

    # Load the FAISS index and expose it as a retriever.
    # NOTE(review): newer LangChain releases are reported to require an
    # explicit opt-in flag on load_local (it unpickles the docstore) --
    # confirm against the pinned LangChain version before upgrading.
    faiss_vectorstore = FAISS.load_local("./Vector_DB/", embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 1})

    # Blend keyword and dense results with equal weight.
    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )