"""
/*************************************************************************
 *
 * CONFIDENTIAL
 * __________________
 *
 * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
 * All Rights Reserved
 *
 * Author           : Theekshana Samaradiwakara
 * Description      : Python Backend API to chat with private data
 * CreatedDate      : 14/11/2023
 * LastModifiedDate : 18/03/2024
 *************************************************************************/
"""
import os | |
import time | |
import logging | |
logger = logging.getLogger(__name__) | |
from dotenv import load_dotenv | |
from fastapi import HTTPException | |
from llmChain import get_qa_chain, get_general_qa_chain, get_router_chain | |
from output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser | |
from config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE | |
from retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever | |
# Load variables from a .env file into the environment before reading them.
load_dotenv()
verbose = os.getenv("VERBOSE")

# Lowercase aliases for the model types declared in config.
qa_model_type = QA_MODEL_TYPE
general_qa_model_type = GENERAL_QA_MODEL_TYPE
# NOTE(review): the original comments on the next two lines mention
# "google/flan-t5-xxl" -- presumably an example value; confirm.
router_model_type = ROUTER_MODEL_TYPE
multi_query_model_type = Multi_Query_MODEL_TYPE
# An earlier, disabled setting here was model_type="tiiuae/falcon-7b-instruct".

# The ensemble retriever is the active one. The disabled alternatives were
# load_faiss_retriever() and load_multi_query_retriever(multi_query_model_type).
retriever = load_ensemble_retriever()
logger.info("retriever loaded:")

# Build every chain once at import time; the run_* functions below reuse them.
qa_chain = get_qa_chain(qa_model_type, retriever)
general_qa_chain = get_general_qa_chain(general_qa_model_type)
router_chain = get_router_chain(router_model_type)
def chain_selector(chain_type, query):
    """Dispatch a question to the chain runner matching the router's label.

    The label is lower-cased and stripped, then tested for keywords by
    substring containment in a fixed order: "greeting", then "other", then
    "relevant" / "not sure". The first match wins.

    Args:
        chain_type: Label text produced by the router chain.
        query: The user's question, forwarded to the chosen runner.

    Returns:
        The return value of the selected ``run_*`` function.

    Raises:
        ValueError: If the normalized label contains none of the keywords.
    """
    label = chain_type.lower().strip()
    logger.info(f"chain_selector : chain_type: {label} Question: {query}")

    # Order matters: earlier rows take precedence. Matching is by substring,
    # so any label merely containing a keyword (for example "irrelevant"
    # contains "relevant") selects that row.
    routes = (
        (("greeting",), run_general_qa_chain),
        (("other",), run_out_of_domain_chain),
        (("relevant", "not sure"), run_qa_chain),
    )
    for keywords, runner in routes:
        if any(keyword in label for keyword in keywords):
            return runner(query)

    raise ValueError(
        f"Received invalid type '{label}'"
    )
def run_agent(query):
    """Answer a question by routing it to the appropriate chain.

    The router chain classifies the question first; ``chain_selector`` then
    hands it to the matching chain runner. The question and the elapsed time
    are echoed to stdout as well as to the module logger.

    Args:
        query: The user's question, passed to ``run_router_chain`` and then
            to ``chain_selector``.

    Returns:
        Whatever ``chain_selector`` returns for the routed question.

    Raises:
        Exception: Any exception raised while routing or answering
            (including ``fastapi.HTTPException``) is logged with its
            traceback and re-raised unchanged.
    """
    try:
        logger.info(f"run_agent : Question: {query}")
        print(f"---------------- run_agent : Question: {query} ----------------")

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        started = time.perf_counter()
        chain_type = run_router_chain(query)
        res = chain_selector(chain_type, query)
        elapsed = round(time.perf_counter() - started, 2)

        # A successful answer is informational, not an error.
        logger.info(f"---------------- Answer (took {elapsed} s.) \n: {res}")
        print(f" \n ---------------- Answer (took {elapsed} s.): -------------- \n")
        return res
    except Exception as e:
        # HTTPException is an Exception subclass and was handled identically,
        # so a single handler covers both. The bare ``raise`` re-raises the
        # same exception object with its original traceback.
        logger.exception(e)
        raise
def run_router_chain(query):
    """Classify a question with the router chain.

    Args:
        query: Question handed directly to ``router_chain.invoke``.

    Returns:
        The ``'text'`` entry of the router chain's result.

    Raises:
        Exception: Anything raised inside the function is logged with its
            traceback and re-raised.
    """
    try:
        logger.info(f"run_router_chain : Question: {query}")

        # Time the chain invocation and the lookup of its 'text' field.
        began = time.time()
        route = router_chain.invoke(query)['text']
        elapsed = round(time.time() - began, 2)

        logger.info(f"Answer (took {elapsed} s.) chain_type: {route}")
        return route
    except Exception as err:
        logger.exception(err)
        raise err
def run_qa_chain(query):
    """Run the retrieval QA chain for a question and parse its output.

    The chain is invoked with the question and an empty ``chat_history``
    string.

    Args:
        query: The user's question.

    Returns:
        The result of ``qa_chain_output_parser`` applied to the raw chain
        output.

    Raises:
        Exception: Anything raised by the chain or the parser is logged with
            its traceback and re-raised.
    """
    try:
        logger.info(f"run_qa_chain : Question: {query}")

        payload = {"question": query, "chat_history": ""}

        # Only the chain invocation is timed; parsing happens afterwards.
        began = time.time()
        raw = qa_chain.invoke(payload)
        elapsed = round(time.time() - began, 2)

        logger.info(f"Answer (took {elapsed} s.) \n: {raw}")
        return qa_chain_output_parser(raw)
    except Exception as err:
        logger.exception(err)
        raise err
def run_general_qa_chain(query):
    """Run the general QA chain for a question and parse its output.

    Args:
        query: Question handed directly to ``general_qa_chain.invoke``.

    Returns:
        The result of ``general_qa_chain_output_parser`` applied to the raw
        chain output.

    Raises:
        Exception: Anything raised by the chain or the parser is logged with
            its traceback and re-raised.
    """
    try:
        logger.info(f"run_general_qa_chain : Question: {query}")

        # Only the chain invocation is timed; parsing happens afterwards.
        began = time.time()
        raw = general_qa_chain.invoke(query)
        elapsed = round(time.time() - began, 2)

        logger.info(f"Answer (took {elapsed} s.) \n: {raw}")
        return general_qa_chain_output_parser(raw)
    except Exception as err:
        logger.exception(err)
        raise err
def run_out_of_domain_chain(query):
    """Produce the reply for the "other" route of ``chain_selector``.

    No chain is invoked here; the question goes straight to
    ``out_of_domain_chain_parser``.

    Args:
        query: The user's question.

    Returns:
        Whatever ``out_of_domain_chain_parser`` returns for the question.
    """
    reply = out_of_domain_chain_parser(query)
    return reply