"""
/*************************************************************************
*
* CONFIDENTIAL
* __________________
*
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
* All Rights Reserved
*
* Author : Theekshana Samaradiwakara
* Description : Python Backend API to chat with private data
* CreatedDate : 19/03/2023
* LastModifiedDate : 19/03/2024
*************************************************************************/
"""
"""
Ensemble retriever that ensemble the results of
multiple retrievers by using weighted Reciprocal Rank Fusion
"""
import json
import logging
from typing import Iterable, List

from langchain.schema import Document
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.retrievers import BM25Retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter

from reggpt.llms.llm import get_model
from reggpt.retriever.ensemble_retriever import EnsembleRetriever
from reggpt.retriever.multi_query_retriever import MultiQueryRetriever
from reggpt.vectorstores.faissDb import load_FAISS_store

logger = logging.getLogger(__name__)
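
# Illustrative sketch of weighted Reciprocal Rank Fusion (RRF), the fusion scheme
# the module docstring refers to. The real fusion happens inside
# reggpt.retriever.ensemble_retriever.EnsembleRetriever, whose internals are not
# shown in this file; the constant c=60 and the exact scoring below follow the
# standard RRF formulation and are assumptions, not a copy of that implementation.
def _weighted_rrf_sketch(rankings, weights, c: int = 60):
    """rankings: one ranked list of doc ids per retriever.
    Returns doc ids re-ranked by score(d) = sum_i weights[i] / (c + rank_i(d))."""
    scores = {}
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):
            # A document absent from one retriever's ranking simply gets no
            # contribution from that retriever.
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (c + rank)
    return sorted(scores, key=scores.get, reverse=True)
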
def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    """Serialize Documents to a JSON Lines file, one document per line."""
    with open(file_path, "w") as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + "\n")


def load_docs_from_jsonl(file_path: str) -> List[Document]:
    """Deserialize Documents from a JSON Lines file written by save_docs_to_jsonl."""
    array = []
    with open(file_path, "r") as jsonl_file:
        for line in jsonl_file:
            data = json.loads(line)
            array.append(Document(**data))
    return array
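
# Minimal round-trip sketch for the JSONL helpers (illustrative only; the /tmp
# path and the sample document are hypothetical, not part of the pipeline):
#
#   docs = [Document(page_content="sample text", metadata={"year": 2023})]
#   save_docs_to_jsonl(docs, "/tmp/sample_docs.jsonl")
#   restored = load_docs_from_jsonl("/tmp/sample_docs.jsonl")
#   assert restored[0].metadata["year"] == 2023
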
def split_documents():
    """Load the CBSL PDF corpus year by year, tag each document with its year,
    split the text into overlapping chunks, and persist the chunks to JSONL."""
    chunk_size = 2000
    chunk_overlap = 100
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
    docs_list = []
    splits_list = []
    for year in years:
        data_path = f"data/CBSL/{year}"
        logger.info(f"Loading year : {data_path}")
        documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()
        for doc in documents:
            # Tag every page with its publication year so downstream consumers
            # can filter or display results by year.
            doc.metadata["year"] = year
            logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}")
            docs_list.append(doc)
        splits_list.extend(text_splitter.split_documents(documents))
    # NOTE: this path must resolve to the same file that load_ensemble_retriever
    # reads ('./reggpt/data/splitted_texts.jsonl'); the two differ in their
    # assumed working directory.
    splitted_texts_file = "../reggpt/data/splitted_texts.jsonl"
    save_docs_to_jsonl(splits_list, splitted_texts_file)
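
# Illustration of the chunking behaviour above (demo parameters only, not the
# pipeline's values): RecursiveCharacterTextSplitter prefers paragraph, then
# sentence, then word boundaries, and consecutive chunks share up to
# chunk_overlap characters so content spanning a cut stays retrievable.
#
#   demo_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
#   demo_chunks = demo_splitter.split_text(some_long_string)
#   # each chunk is <= 100 characters, overlapping its neighbour by up to 20
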
def load_faiss_retriever():
    """Build a plain similarity-search retriever over the FAISS vector store."""
    try:
        vectorstore = load_FAISS_store()
        retriever = vectorstore.as_retriever(
            # search_type="mmr",
            search_kwargs={"k": 5, "fetch_k": 10},
        )
        logger.info("FAISS Retriever loaded")
        return retriever
    except Exception as e:
        logger.exception(e)
        raise
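
# Usage sketch (illustrative only; the query string is a made-up example).
# Depending on the installed LangChain version, retrieval is invoked via
# retriever.get_relevant_documents(query) or, on newer releases,
# retriever.invoke(query):
#
#   retriever = load_faiss_retriever()
#   docs = retriever.get_relevant_documents("What is the current policy interest rate?")
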
def load_ensemble_retriever():
    """Combine a lexical BM25 retriever with the semantic FAISS retriever,
    fusing their rankings with equal weights."""
    try:
        # splitted_texts_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data/splitted_texts.jsonl')
        splitted_texts_file = "./reggpt/data/splitted_texts.jsonl"
        semantic_k = 4  # documents fetched by the dense (FAISS) retriever
        bm25_k = 2      # documents fetched by the sparse (BM25) retriever
        splits_list = load_docs_from_jsonl(splitted_texts_file)
        bm25_retriever = BM25Retriever.from_documents(splits_list)
        bm25_retriever.k = bm25_k
        faiss_vectorstore = load_FAISS_store()
        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": semantic_k})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        ensemble_retriever.top_k = 4
        logger.info("EnsembleRetriever loaded")
        return ensemble_retriever
    except Exception as e:
        logger.exception(e)
        raise
def load_multi_query_retriever(multi_query_model_type):
    """Wrap the ensemble retriever with an LLM that rewrites each user query
    into several variants, retrieving for each variant and returning the
    de-duplicated union of results."""
    try:
        llm = get_model(multi_query_model_type)
        ensemble_retriever = load_ensemble_retriever()
        retriever = MultiQueryRetriever.from_llm(retriever=ensemble_retriever, llm=llm)
        logger.info("MultiQueryRetriever loaded")
        return retriever
    except Exception as e:
        logger.exception(e)
        raise
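
# End-to-end usage sketch (illustrative; "openai" is a placeholder model type --
# the accepted values are whatever reggpt.llms.llm.get_model actually supports):
#
#   retriever = load_multi_query_retriever("openai")
#   docs = retriever.get_relevant_documents("capital adequacy requirements 2023")
#   # The LLM generates several variants of the query, the ensemble retriever
#   # runs each one, and the de-duplicated union of documents is returned.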