"""
 /*************************************************************************
 * 
 * CONFIDENTIAL
 * __________________
 * 
 *  Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 *  All Rights Reserved
 *
 *  Author  : Theekshana Samaradiwakara
 *  Description :Python Backend API to chat with private data  
 *  CreatedDate : 14/11/2023
 *  LastModifiedDate : 18/03/2024
 *************************************************************************/
 """

import os
import logging

from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain, LLMChain
# from langchain.prompts import PromptTemplate
# from conversationBufferWindowMemory import ConversationBufferWindowMemory

from reggpt.llms.llm import get_model
from reggpt.prompts.document_combine import document_combine_prompt
from reggpt.prompts.retrieval import retrieval_qa_chain_prompt
from reggpt.prompts.general import general_qa_chain_prompt
from reggpt.prompts.router import router_prompt

logger = logging.getLogger(__name__)

load_dotenv()

# VERBOSE comes in from the environment as a string (or None), so normalise it
# to a real boolean before handing it to the chains below.
verbose = os.environ.get("VERBOSE", "false").lower() in ("1", "true", "yes")


def get_qa_chain(model_type, retriever):
    """Build a ConversationalRetrievalChain that answers questions over private documents."""
    logger.info("creating qa_chain")

    try:
        qa_llm = get_model(model_type)

        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=qa_llm,
            chain_type="stuff",  # stuff all retrieved chunks into a single prompt
            retriever=retriever,
            return_source_documents=True,
            get_chat_history=lambda h: h,  # pass the chat history through unchanged
            combine_docs_chain_kwargs={
                "prompt": retrieval_qa_chain_prompt,
                "document_prompt": document_combine_prompt,
            },
            verbose=verbose,  # controlled by the VERBOSE env var
            # memory=memory,
        )

        logger.info("qa_chain created")
        return qa_chain

    except Exception:
        logger.exception("Error creating qa_chain")
        raise

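# Usage sketch (illustrative only): the model_type string and the retriever
# below are assumptions -- get_model's accepted types and the vector store are
# defined elsewhere in reggpt, so adjust the names to your setup.
#
#   retriever = vector_store.as_retriever(search_kwargs={"k": 4})
#   qa_chain = get_qa_chain("openai", retriever)
#   result = qa_chain({"question": "What does section 4 require?", "chat_history": []})
#   answer, sources = result["answer"], result["source_documents"]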

def get_general_qa_chain(model_type):
    """Build a plain LLMChain for general questions that need no document retrieval."""
    logger.info("creating general_qa_chain")

    try:
        general_qa_llm = get_model(model_type)
        general_qa_chain = LLMChain(llm=general_qa_llm, prompt=general_qa_chain_prompt)
        # LCEL equivalent: general_qa_chain = general_qa_chain_prompt | general_qa_llm

        logger.info("general_qa_chain created")
        return general_qa_chain

    except Exception:
        logger.exception("Error creating general_qa_chain")
        raise


def get_router_chain(model_type):
    """Build an LLMChain that classifies a query so it can be routed to the right chain."""
    logger.info("creating router_chain")

    try:
        router_llm = get_model(model_type)
        router_chain = LLMChain(llm=router_llm, prompt=router_prompt)
        # LCEL equivalent: router_chain = router_prompt | router_llm

        logger.info("router_chain created")
        return router_chain

    except Exception:
        logger.exception("Error creating router_chain")
        raise
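

# Routing sketch, assuming the router prompt emits a label such as "GENERAL" or
# "RETRIEVAL" -- the actual labels live in reggpt.prompts.router, so treat these
# literals (and the "openai" model type) as placeholders:
#
#   router_chain = get_router_chain("openai")
#   route = router_chain.run(question).strip().upper()
#   if route == "GENERAL":
#       answer = get_general_qa_chain("openai").run(question)
#   else:
#       result = get_qa_chain("openai", retriever)({"question": question, "chat_history": []})
#       answer = result["answer"]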