theekshana's picture
moved app.python to main directory
a624e23
"""
/*************************************************************************
*
* CONFIDENTIAL
* __________________
*
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
* All Rights Reserved
*
* Author : Theekshana Samaradiwakara
* Description :Python Backend API to chat with private data
* CreatedDate : 14/11/2023
* LastModifiedDate : 18/03/2024
*************************************************************************/
"""
import os
import logging
# Module-level logger named after this module; used by every factory below.
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
# Load .env values into the process environment before VERBOSE is read.
load_dotenv()
# Raw value of the VERBOSE environment variable (None when it is unset).
# NOTE(review): not referenced by the functions visible in this section.
verbose = os.environ.get('VERBOSE')
# NOTE(review): the project imports below run after load_dotenv(); presumably
# so they can see .env values at import time - confirm before reordering.
from reggpt.llms.llm import get_model
from langchain.chains import ConversationalRetrievalChain
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
# from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Prompt objects consumed by the chain factories defined in this module.
from reggpt.prompts.document_combine import document_combine_prompt
from reggpt.prompts.retrieval import retrieval_qa_chain_prompt
from reggpt.prompts.general import general_qa_chain_prompt
from reggpt.prompts.router import router_prompt
def get_qa_chain(model_type, retriever):
    """Create the retrieval-backed conversational QA chain.

    Args:
        model_type: Forwarded unchanged to ``get_model`` to obtain the LLM.
        retriever: Handed to the chain as its ``retriever`` argument.

    Returns:
        The object produced by ``ConversationalRetrievalChain.from_llm``,
        configured with the "stuff" chain type, source documents returned
        and verbose output switched on.

    Raises:
        Exception: Any error raised while building the chain is logged with
            its traceback and then re-raised.
    """
    logger.info("creating qa_chain")
    try:
        llm = get_model(model_type)
        # Prompts forwarded to the document-combining step of the chain.
        combine_kwargs = {
            "prompt": retrieval_qa_chain_prompt,
            "document_prompt": document_combine_prompt,
        }
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            # Chat history is passed through exactly as the caller supplied it.
            get_chat_history=lambda history: history,
            combine_docs_chain_kwargs=combine_kwargs,
            verbose=True,
        )
        logger.info("qa_chain created")
        return chain
    except Exception as err:
        logger.exception(f"Error : {err}")
        raise err
def get_general_qa_chain(model_type):
    """Create the LLM chain that uses the general QA prompt.

    Args:
        model_type: Forwarded unchanged to ``get_model`` to obtain the LLM.

    Returns:
        An ``LLMChain`` combining that LLM with ``general_qa_chain_prompt``.

    Raises:
        Exception: Any error raised while building the chain is logged with
            its traceback and then re-raised.
    """
    logger.info("creating general_qa_chain")
    try:
        chain = LLMChain(
            llm=get_model(model_type),
            prompt=general_qa_chain_prompt,
        )
        logger.info("general_qa_chain created")
        return chain
    except Exception as err:
        logger.exception(f"Error : {err}")
        raise err
def get_router_chain(model_type):
    """Create the LLM chain that uses the router prompt.

    Args:
        model_type: Forwarded unchanged to ``get_model`` to obtain the LLM.

    Returns:
        An ``LLMChain`` combining that LLM with ``router_prompt``.

    Raises:
        Exception: Any error raised while building the chain is logged with
            its traceback and then re-raised.
    """
    logger.info("creating router_chain")
    try:
        llm = get_model(model_type)
        chain_args = {"llm": llm, "prompt": router_prompt}
        chain = LLMChain(**chain_args)
        logger.info("router_chain created")
        return chain
    except Exception as err:
        logger.exception(f"Error : {err}")
        raise err