""" | |
/************************************************************************* | |
* | |
* CONFIDENTIAL | |
* __________________ | |
* | |
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC | |
* All Rights Reserved | |
* | |
* Author : Theekshana Samaradiwakara | |
* Description :Python Backend API to chat with private data | |
* CreatedDate : 19/03/2023 | |
* LastModifiedDate : 19/03/2024 | |
*************************************************************************/ | |
""" | |
""" | |
Ensemble retriever that ensemble the results of | |
multiple retrievers by using weighted Reciprocal Rank Fusion | |
""" | |

import json
import logging
from typing import Iterable, List

from langchain.schema import Document
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.retrievers import BM25Retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter

from reggpt.llms.llm import get_model
from reggpt.vectorstores.faissDb import load_FAISS_store

from ensemble_retriever import EnsembleRetriever
from multi_query_retriever import MultiQueryRetriever

logger = logging.getLogger(__name__)


def save_docs_to_jsonl(docs: Iterable[Document], file_path: str) -> None:
    """Serialize documents to a JSON Lines file, one document per line."""
    with open(file_path, "w") as jsonl_file:
        for doc in docs:
            jsonl_file.write(doc.json() + "\n")


def load_docs_from_jsonl(file_path: str) -> List[Document]:
    """Load documents from a JSON Lines file written by save_docs_to_jsonl."""
    docs = []
    with open(file_path, "r") as jsonl_file:
        for line in jsonl_file:
            docs.append(Document(**json.loads(line)))
    return docs
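
# Example round trip (a sketch; `docs` and the file name are placeholders):
#
#     save_docs_to_jsonl(docs, "chunks.jsonl")
#     restored = load_docs_from_jsonl("chunks.jsonl")
#     assert [d.page_content for d in restored] == [d.page_content for d in docs]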

def split_documents():
    """Load the yearly CBSL PDF folders, split them into overlapping chunks,
    and persist the chunks to a JSONL file for the BM25 retriever."""
    chunk_size = 2000
    chunk_overlap = 100
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
    splits_list = []
    for year in years:
        data_path = f"data/CBSL/{year}"
        logger.info(f"Loading year : {data_path}")
        documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()
        # Tag every page with its publication year before splitting.
        for doc in documents:
            doc.metadata["year"] = year
            logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}")
        texts = text_splitter.split_documents(documents)
        splits_list.extend(texts)
    splitted_texts_file = "data/splitted_texts.jsonl"
    save_docs_to_jsonl(splits_list, splitted_texts_file)
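
# Example: with PDFs laid out as data/CBSL/<year>/*.pdf, one call rebuilds the
# chunk file that load_ensemble_retriever below reads:
#
#     split_documents()  # writes data/splitted_texts.jsonl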

def load_faiss_retriever():
    """Build a dense retriever over the FAISS vector store."""
    try:
        vectorstore = load_FAISS_store()
        retriever = vectorstore.as_retriever(
            # search_type="mmr",
            search_kwargs={"k": 5, "fetch_k": 10},
        )
        logger.info("FAISS retriever loaded")
        return retriever
    except Exception as e:
        logger.exception(e)
        raise


def load_ensemble_retriever():
    """Combine a BM25 (lexical) retriever and a FAISS (semantic) retriever
    with equal weights using weighted Reciprocal Rank Fusion."""
    try:
        splitted_texts_file = "./data/splitted_texts.jsonl"
        semantic_k = 4
        bm25_k = 2
        splits_list = load_docs_from_jsonl(splitted_texts_file)
        bm25_retriever = BM25Retriever.from_documents(splits_list)
        bm25_retriever.k = bm25_k
        faiss_vectorstore = load_FAISS_store()
        faiss_retriever = faiss_vectorstore.as_retriever(
            search_kwargs={"k": semantic_k}
        )
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        ensemble_retriever.top_k = 4
        logger.info("EnsembleRetriever loaded")
        return ensemble_retriever
    except Exception as e:
        logger.exception(e)
        raise

def load_multi_query_retriever(multi_query_model_type):
    """Wrap the ensemble retriever with an LLM-driven multi-query retriever,
    which rewrites the user question into several variants before retrieval."""
    try:
        llm = get_model(multi_query_model_type)
        ensemble_retriever = load_ensemble_retriever()
        retriever = MultiQueryRetriever.from_llm(
            retriever=ensemble_retriever,
            llm=llm,
        )
        logger.info("MultiQueryRetriever loaded")
        return retriever
    except Exception as e:
        logger.exception(e)
        raise
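
# Example wiring (a sketch; "openai" is a placeholder model type, since the
# values accepted by reggpt.llms.llm.get_model are not shown here):
#
#     retriever = load_multi_query_retriever("openai")
#     docs = retriever.get_relevant_documents("What did CBSL change in 2023?")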