"""
/*************************************************************************
 *
 * CONFIDENTIAL
 * __________________
 *
 * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 * All Rights Reserved
 *
 * Author : Theekshana Samaradiwakara
 * Description : Python Backend API to chat with private data
 * CreatedDate : 14/11/2023
 * LastModifiedDate : 18/03/2024
 *************************************************************************/
"""
import os
import logging

from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain, LLMChain

from reggpt.llms.llm import get_model
from reggpt.prompts.document_combine import document_combine_prompt
from reggpt.prompts.retrieval import retrieval_qa_chain_prompt
from reggpt.prompts.general import general_qa_chain_prompt
from reggpt.prompts.router import router_prompt

logger = logging.getLogger(__name__)

# Load .env first so VERBOSE (and any model credentials) are visible below.
load_dotenv()
verbose = os.environ.get('VERBOSE')
def get_qa_chain(model_type, retriever):
    """Build a ConversationalRetrievalChain for document-grounded QA.

    Args:
        model_type: Identifier forwarded to ``get_model`` to select the LLM.
        retriever: LangChain retriever that supplies the source documents.

    Returns:
        A ``ConversationalRetrievalChain`` using the "stuff" combine strategy
        that also returns its source documents.

    Raises:
        Exception: Re-raised (after logging) if chain construction fails.
    """
    logger.info("creating qa_chain")
    try:
        qa_llm = get_model(model_type)
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=qa_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            # History is assumed pre-formatted; pass it through unchanged.
            get_chat_history=lambda h: h,
            combine_docs_chain_kwargs={
                "prompt": retrieval_qa_chain_prompt,
                "document_prompt": document_combine_prompt,
            },
            verbose=True,
        )
        logger.info("qa_chain created")
        return qa_chain
    except Exception as e:
        msg = f"Error : {e}"
        logger.exception(msg)
        # Bare raise preserves the original traceback without an extra frame.
        raise
def get_general_qa_chain(model_type):
    """Build an LLMChain for general (non-retrieval) QA.

    Args:
        model_type: Identifier forwarded to ``get_model`` to select the LLM.

    Returns:
        An ``LLMChain`` wired to ``general_qa_chain_prompt``.

    Raises:
        Exception: Re-raised (after logging) if chain construction fails.
    """
    logger.info("creating general_qa_chain")
    try:
        general_qa_llm = get_model(model_type)
        general_qa_chain = LLMChain(llm=general_qa_llm, prompt=general_qa_chain_prompt)
        logger.info("general_qa_chain created")
        return general_qa_chain
    except Exception as e:
        msg = f"Error : {e}"
        logger.exception(msg)
        # Bare raise preserves the original traceback without an extra frame.
        raise
def get_router_chain(model_type):
    """Build an LLMChain that routes queries using ``router_prompt``.

    Args:
        model_type: Identifier forwarded to ``get_model`` to select the LLM.

    Returns:
        An ``LLMChain`` wired to ``router_prompt``.

    Raises:
        Exception: Re-raised (after logging) if chain construction fails.
    """
    logger.info("creating router_chain")
    try:
        router_llm = get_model(model_type)
        router_chain = LLMChain(llm=router_llm, prompt=router_prompt)
        logger.info("router_chain created")
        return router_chain
    except Exception as e:
        msg = f"Error : {e}"
        logger.exception(msg)
        # Bare raise preserves the original traceback without an extra frame.
        raise