"""Module-level wiring for the RegGPT chains.

Importing this module has side effects: it loads environment variables via
``load_dotenv()``, builds the ensemble retriever, and constructs the QA,
general-QA and router chains from the model types in
``reggpt.configs.config``. The results are exposed as module-level
attributes (``retriever``, ``qa_chain``, ``general_qa_chain``,
``router_chain``).
"""
import os
import time
import logging

from dotenv import load_dotenv
from fastapi import HTTPException

# NOTE: the order of the reggpt imports below is kept as in the original in
# case any of them has import-time side effects.
from reggpt.chains.llmChain import (
    get_qa_chain,
    get_general_qa_chain,
    get_router_chain,
)
from reggpt.output_parsers.output_parser import (
    general_qa_chain_output_parser,
    qa_chain_output_parser,
    out_of_domain_chain_parser,
)
from reggpt.configs.config import (
    QA_MODEL_TYPE,
    GENERAL_QA_MODEL_TYPE,
    ROUTER_MODEL_TYPE,
    Multi_Query_MODEL_TYPE,
)
from reggpt.utils.retriever import (
    load_faiss_retriever,
    load_ensemble_retriever,
    load_multi_query_retriever,
)

logger = logging.getLogger(__name__)

# NOTE(review): load_dotenv() runs *after* reggpt.configs.config has been
# imported. If that module reads os.environ at import time, values that only
# exist in .env will not be visible to it — confirm.
load_dotenv()

# Raw value of the VERBOSE environment variable (None when unset). This is a
# string, so any non-empty value — including "False" — is truthy. Not
# referenced elsewhere in the code shown here.
verbose = os.environ.get('VERBOSE')

qa_model_type = QA_MODEL_TYPE
general_qa_model_type = GENERAL_QA_MODEL_TYPE
router_model_type = ROUTER_MODEL_TYPE  # "google/flan-t5-xxl"
multi_query_model_type = Multi_Query_MODEL_TYPE  # "google/flan-t5-xxl"
# model_type = "tiiuae/falcon-7b-instruct"

# Retriever selection: the ensemble retriever is the active one.
# Alternatives kept for reference:
#   retriever = load_faiss_retriever()
#   retriever = load_multi_query_retriever(multi_query_model_type)
retriever = load_ensemble_retriever()
logger.info("retriever loaded:")

# Chains are built once at import time and shared by everything that imports
# them from this module.
qa_chain = get_qa_chain(qa_model_type, retriever)
general_qa_chain = get_general_qa_chain(general_qa_model_type)
router_chain = get_router_chain(router_model_type)


def run_out_of_domain_chain(query):
    """Return ``out_of_domain_chain_parser(query)``.

    No LLM chain is invoked here; the query is handed straight to the
    out-of-domain parser and its result is returned unchanged. Presumably
    used for queries the router classifies as out of domain — confirm
    against callers.
    """
    return out_of_domain_chain_parser(query)