File size: 4,234 Bytes
93bc171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395275a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

"""
 /*************************************************************************
 * 
 * CONFIDENTIAL
 * __________________
 * 
 *  Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 *  All Rights Reserved
 *
 *  Author  : Theekshana Samaradiwakara
 *  Description : Python Backend API to chat with private data
 *  CreatedDate : 19/03/2023
 *  LastModifiedDate : 19/03/2024
 *************************************************************************/
 """

"""
Ensemble retriever that ensemble the results of 
multiple retrievers by using weighted  Reciprocal Rank Fusion
"""

import logging
logger = logging.getLogger(__name__)

from reggpt.vectorstores.faissDb import load_FAISS_store

from langchain_community.retrievers import BM25Retriever
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain.schema import Document
from typing import Iterable
import json

def save_docs_to_jsonl(array: "Iterable[Document]", file_path: str) -> None:
    """Serialize documents to a JSON-Lines file, one JSON object per line.

    Args:
        array: Documents to persist; each must expose a ``.json()`` serializer.
        file_path: Destination path. An existing file is overwritten.
    """
    # Explicit UTF-8 avoids platform-dependent default encodings on write.
    with open(file_path, 'w', encoding='utf-8') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')

def load_docs_from_jsonl(file_path: str) -> "Iterable[Document]":
    """Deserialize documents from a JSON-Lines file written by save_docs_to_jsonl.

    Args:
        file_path: Path to a ``.jsonl`` file, one JSON-encoded Document per line.

    Returns:
        A list of reconstructed ``Document`` objects, in file order.
    """
    documents = []
    # Explicit UTF-8 mirrors the encoding used by save_docs_to_jsonl.
    with open(file_path, 'r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of crashing on json.loads('').
                continue
            documents.append(Document(**json.loads(line)))
    return documents

def split_documents():
    """Load yearly CBSL PDF reports, tag each page with its year, split into
    chunks, and persist all chunks to ``data/splitted_texts.jsonl``.

    Reads PDFs from ``data/CBSL/<year>`` for years 2013-2024. Raises whatever
    the underlying loaders raise if a directory or file is unreadable.
    """
    chunk_size = 2000
    chunk_overlap = 100

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
    splits_list = []

    for year in years:
        data_path = f"data/CBSL/{year}"
        logger.info("Loading year : %s", data_path)

        documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()

        # Tag every page with its report year so downstream retrieval can
        # associate chunks with the year they came from.
        for doc in documents:
            doc.metadata['year'] = year
            logger.info("%s : %s", doc.metadata['year'], doc.metadata['source'])

        splits_list.extend(text_splitter.split_documents(documents))

    splitted_texts_file = 'data/splitted_texts.jsonl'
    save_docs_to_jsonl(splits_list, splitted_texts_file)

from ensemble_retriever import EnsembleRetriever
from multi_query_retriever import MultiQueryRetriever

def load_faiss_retriever():
    """Build a plain FAISS similarity retriever over the persisted vector store.

    Returns:
        A retriever yielding the top 5 documents from a candidate pool of 10.

    Raises:
        Exception: Propagates any failure from loading the FAISS store.
    """
    try:
        vectorstore = load_FAISS_store()
        retriever = vectorstore.as_retriever(
            # search_type="mmr",
            search_kwargs={'k': 5, 'fetch_k': 10}
            )
        logger.info("FAISS Retriever loaded:")
        return retriever

    except Exception:
        logger.exception("Failed to load FAISS retriever")
        # Bare raise preserves the original traceback and exception context.
        raise
    
def load_ensemble_retriever():
    """Build an ensemble retriever fusing BM25 (lexical) and FAISS (semantic)
    results with equal 0.5/0.5 Reciprocal Rank Fusion weights.

    The BM25 retriever is built from the chunk file produced by
    ``split_documents`` (``./data/splitted_texts.jsonl``).

    Returns:
        An EnsembleRetriever (top_k=4) over both retrievers.

    Raises:
        Exception: Propagates failures from loading the chunk file or the
            FAISS store.
    """
    try:
        splitted_texts_file = './data/splitted_texts.jsonl'
        semantic_k = 4
        bm25_k = 2
        splits_list = load_docs_from_jsonl(splitted_texts_file)

        bm25_retriever = BM25Retriever.from_documents(splits_list)
        bm25_retriever.k = bm25_k

        faiss_vectorstore = load_FAISS_store()
        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={'k': semantic_k})

        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
        ensemble_retriever.top_k = 4

        logger.info("EnsembleRetriever loaded:")
        return ensemble_retriever

    except Exception:
        logger.exception("Failed to load ensemble retriever")
        # Bare raise preserves the original traceback and exception context.
        raise

from reggpt.llms.llm import get_model

def load_multi_query_retriever(multi_query_model_type):
    """Wrap the ensemble retriever with an LLM-driven multi-query retriever.

    The LLM rewrites the user query into multiple variants; results from each
    variant are retrieved via the ensemble retriever and merged.

    Args:
        multi_query_model_type: Model identifier passed to ``get_model`` to
            obtain the query-rewriting LLM.

    Returns:
        A MultiQueryRetriever backed by the ensemble retriever.

    Raises:
        Exception: Propagates failures from model or retriever construction.
    """
    try:
        llm = get_model(multi_query_model_type)
        ensemble_retriever = load_ensemble_retriever()
        retriever = MultiQueryRetriever.from_llm(
            retriever=ensemble_retriever,
            llm=llm
        )
        logger.info("MultiQueryRetriever loaded:")
        return retriever

    except Exception:
        logger.exception("Failed to load multi-query retriever")
        # Bare raise preserves the original traceback and exception context.
        raise