RegGPT-Back-End / qaPipeline.py
theekshana's picture
Upload 30 files
395275a verified
raw
history blame
4.9 kB
"""
/*************************************************************************
*
* CONFIDENTIAL
* __________________
*
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
* All Rights Reserved
*
* Author : Theekshana Samaradiwakara
* Description :Python Backend API to chat with private data
* CreatedDate : 14/11/2023
* LastModifiedDate : 18/03/2024
*************************************************************************/
"""
import os
import time
import logging
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
from fastapi import HTTPException
from llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
from output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
from config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
from retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
# Populate os.environ from a .env file, if one is present.
# NOTE(review): this call runs *after* `config` and `retriever` were imported
# above; if those modules read environment variables at import time, values
# defined only in .env will not be visible to them — confirm.
load_dotenv()
# Raw VERBOSE environment value (a str, or None when unset); not used within
# this part of the file.
verbose = os.environ.get('VERBOSE')
# Model identifiers for each chain; the actual values come from `config`.
qa_model_type=QA_MODEL_TYPE
general_qa_model_type=GENERAL_QA_MODEL_TYPE
router_model_type=ROUTER_MODEL_TYPE # e.g. "google/flan-t5-xxl" (actual value set in config)
multi_query_model_type=Multi_Query_MODEL_TYPE # e.g. "google/flan-t5-xxl"; only referenced by the commented-out multi-query retriever below
# model_type="tiiuae/falcon-7b-instruct"
# Build the document retriever once, at import time. The ensemble retriever is
# the active choice; the commented-out lines are alternative retrievers.
# retriever=load_faiss_retriever()
retriever=load_ensemble_retriever()
# retriever=load_multi_query_retriever(multi_query_model_type)
logger.info("retriever loaded:")
# Instantiate the three chains once, at import time, as module globals that the
# run_* functions share across calls. The QA chain is the only one given the
# retriever; the general chain serves greetings (see chain_selector); the router
# chain classifies each incoming query.
qa_chain= get_qa_chain(qa_model_type,retriever)
general_qa_chain= get_general_qa_chain(general_qa_model_type)
router_chain= get_router_chain(router_model_type)
def chain_selector(chain_type, query):
    """Dispatch ``query`` to a handler based on the router's label.

    The label is lower-cased and stripped, then matched by *substring*. The
    checks are tried in this priority order:

    1. contains ``"greeting"``                 -> ``run_general_qa_chain``
    2. contains ``"other"``                    -> ``run_out_of_domain_chain``
    3. contains ``"relevant"`` or ``"not sure"`` -> ``run_qa_chain``

    NOTE(review): because matching is by substring, a label such as
    "not relevant" or "irrelevant" also contains "relevant" and is routed to
    the QA chain — confirm this matches the router prompt's label set.

    Args:
        chain_type: Raw label text produced by the router chain.
        query: The user's question, forwarded unchanged to the handler.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If the label matches none of the known routes.
    """
    label = chain_type.lower().strip()
    logger.info(f"chain_selector : chain_type: {label} Question: {query}")

    def _label_mentions(*keywords):
        # True when any keyword occurs anywhere inside the normalised label.
        return any(keyword in label for keyword in keywords)

    if _label_mentions("greeting"):
        return run_general_qa_chain(query)
    if _label_mentions("other"):
        return run_out_of_domain_chain(query)
    if _label_mentions("relevant", "not sure"):
        return run_qa_chain(query)

    raise ValueError(
        f"Received invalid type '{label}'"
    )
def run_agent(query):
    """Route ``query`` to the appropriate chain and return its parsed answer.

    The router chain first classifies the query (``run_router_chain``), and
    ``chain_selector`` then dispatches it to the matching handler. The total
    wall-clock time is logged and printed.

    Args:
        query: The user's question.

    Returns:
        Whatever the selected handler returns.

    Raises:
        Any exception from routing or from the selected chain is logged with
        its traceback and then re-raised unchanged. This includes FastAPI's
        ``HTTPException`` and the ``ValueError`` that ``chain_selector``
        raises for an unrecognised route.
    """
    try:
        logger.info(f"run_agent : Question: {query}")
        print(f"---------------- run_agent : Question: {query} ----------------")
        start = time.time()
        chain_type = run_router_chain(query)
        res = chain_selector(chain_type, query)
        end = time.time()
        # A successful answer is informational; ERROR level is reserved for
        # failures so it does not trigger error-level alerting/monitoring.
        logger.info(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
        print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")
        return res
    except Exception as e:
        # HTTPException subclasses Exception, so this single handler covers
        # both cases (they were previously handled identically).
        # logger.exception records the message and the full traceback.
        logger.exception(e)
        # A bare raise re-raises the active exception with its original traceback.
        raise
def run_router_chain(query):
    """Classify ``query`` with the router chain and return the label text.

    Args:
        query: The user's question.

    Returns:
        The ``'text'`` entry of the router chain's output. ``run_agent``
        passes this label on to ``chain_selector``.

    Raises:
        Re-raises any exception from the router chain after logging it.
    """
    try:
        logger.info(f"run_router_chain : Question: {query}")
        started_at = time.time()
        router_output = router_chain.invoke(query)
        label = router_output['text']
        elapsed = round(time.time() - started_at, 2)
        logger.info(f"Answer (took {elapsed} s.) chain_type: {label}")
        return label
    except Exception as err:
        logger.exception(err)
        raise err
def run_qa_chain(query):
    """Answer ``query`` with the retriever-backed QA chain.

    The chain is invoked with an empty ``chat_history``, so no earlier
    conversation turns are supplied to it. The raw chain result is logged
    together with the elapsed time, then parsed before it is returned.

    Args:
        query: The user's question.

    Returns:
        The result of ``qa_chain_output_parser`` applied to the chain output.

    Raises:
        Re-raises any exception from the chain or the parser after logging it.
    """
    try:
        logger.info(f"run_qa_chain : Question: {query}")
        started_at = time.time()
        chain_input = {"question": query, "chat_history": ""}
        raw_answer = qa_chain.invoke(chain_input)
        elapsed = round(time.time() - started_at, 2)
        logger.info(f"Answer (took {elapsed} s.) \n: {raw_answer}")
        return qa_chain_output_parser(raw_answer)
    except Exception as err:
        logger.exception(err)
        raise err
def run_general_qa_chain(query):
    """Answer ``query`` with the general QA chain.

    This chain is the one built without the document retriever. The raw chain
    result is logged together with the elapsed time, then parsed before it is
    returned.

    Args:
        query: The user's message.

    Returns:
        The result of ``general_qa_chain_output_parser`` applied to the chain
        output.

    Raises:
        Re-raises any exception from the chain or the parser after logging it.
    """
    try:
        logger.info(f"run_general_qa_chain : Question: {query}")
        started_at = time.time()
        raw_answer = general_qa_chain.invoke(query)
        elapsed = round(time.time() - started_at, 2)
        logger.info(f"Answer (took {elapsed} s.) \n: {raw_answer}")
        return general_qa_chain_output_parser(raw_answer)
    except Exception as err:
        logger.exception(err)
        raise err
def run_out_of_domain_chain(query):
    """Build the reply for a query classified as out of domain.

    Unlike the other ``run_*`` helpers, no chain is invoked here. The reply is
    produced entirely by ``out_of_domain_chain_parser``.

    Args:
        query: The user's question.

    Returns:
        Whatever ``out_of_domain_chain_parser`` returns for ``query``.
    """
    reply = out_of_domain_chain_parser(query)
    return reply