File size: 1,282 Bytes
93bc171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import time
import logging
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
from fastapi import HTTPException
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser

from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
# Load variables from a local .env file into os.environ.
# NOTE(review): this call runs *after* the reggpt.* imports above. If
# reggpt.configs.config or the chain/retriever modules read environment
# variables at import time, they will not see values defined only in .env.
# Confirm, and move load_dotenv() above those imports if that matters.
load_dotenv()

# Raw value of the VERBOSE environment variable: a str, or None when unset.
# NOTE(review): any non-empty string (including "false" or "0") is truthy;
# confirm how callers interpret this value.
verbose = os.environ.get('VERBOSE')

# Model identifiers for each chain, taken from reggpt.configs.config.
qa_model_type = QA_MODEL_TYPE
general_qa_model_type = GENERAL_QA_MODEL_TYPE
router_model_type = ROUTER_MODEL_TYPE
# Only needed by the multi-query retriever alternative listed below; kept as
# a module-level name in case other modules import it.
multi_query_model_type = Multi_Query_MODEL_TYPE

# The retriever is built once, at import time, and passed to the QA chain.
# Other loaders available from reggpt.utils.retriever:
#   load_faiss_retriever()
#   load_multi_query_retriever(multi_query_model_type)
retriever = load_ensemble_retriever()
# Include the retriever's concrete type so the log line says what was loaded
# (lazy %-args: formatting only happens if INFO is enabled).
logger.info("retriever loaded: %s", type(retriever).__name__)

# All chains are constructed eagerly at import time.
qa_chain = get_qa_chain(qa_model_type, retriever)
general_qa_chain = get_general_qa_chain(general_qa_model_type)
router_chain = get_router_chain(router_model_type)


def run_out_of_domain_chain(query):
    """Return the out-of-domain parser's output for *query*.

    No LLM chain is invoked here: the query is handed straight to
    ``out_of_domain_chain_parser`` and its result is returned as-is.

    Args:
        query: The user query, passed through unchanged.

    Returns:
        Whatever ``out_of_domain_chain_parser(query)`` returns.
    """
    parsed_response = out_of_domain_chain_parser(query)
    return parsed_response