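"""Factory helpers for chat models and clients.

This module builds a LangChain chat model (get_model) or a raw OpenAI client
(get_client) from the settings and API keys saved via utils.db. The wording of
this docstring is a summary added for clarity; the behavior it describes comes
from the functions below.
"""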
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
try:
    from .utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key
    from .custom_callback import customcallback
except ImportError:
    from utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key
    from custom_callback import customcallback

the_callback = customcallback(strip_tokens=False, answer_prefix_tokens=["Answer"])


def get_model(high_context=False):
    """Build a LangChain chat model instance from the saved model settings."""
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    def open_ai_base(high_context):
        # With the default endpoint, high-context requests are routed to gpt-4-turbo.
        if the_openai_url == "default":
            true_model = the_model
            if high_context:
                true_model = "gpt-4-turbo"
            return {"model": true_model, "api_key": the_api_key, "max_retries": 15, "streaming": True, "callbacks": [the_callback]}
        else:
            return {"model": the_model, "api_key": the_api_key, "max_retries": 15, "streaming": True, "callbacks": [the_callback], "base_url": the_openai_url}
    args_mapping = {
        ChatOpenAI: open_ai_base(high_context=high_context),
        ChatOllama: {"model": the_model},
        ChatGroq: {"temperature": 0, "model_name": the_model.replace("-groq", ""), "groq_api_key": the_groq_api_key},
        ChatGoogleGenerativeAI: {"model": the_model, "google_api_key": the_google_api_key},
    }
    model_mapping = {
        # OpenAI
        "gpt-4o": (ChatOpenAI, args_mapping[ChatOpenAI]),
        "gpt-4-turbo": (ChatOpenAI, args_mapping[ChatOpenAI]),
        "gpt-3.5": (ChatOpenAI, args_mapping[ChatOpenAI]),
        "gpt-3.5-turbo": (ChatOpenAI, args_mapping[ChatOpenAI]),
        # Ollama (local models)
        "llava": (ChatOllama, args_mapping[ChatOllama]),
        "llama3": (ChatOllama, args_mapping[ChatOllama]),
        "bakllava": (ChatOllama, args_mapping[ChatOllama]),
        # Google Generative AI - Gemini
        "gemini-pro": (ChatGoogleGenerativeAI, args_mapping[ChatGoogleGenerativeAI]),
        # Groq
        "mixtral-8x7b-groq": (ChatGroq, args_mapping[ChatGroq]),
    }
    # Unknown model names fall back to (None, None) so the function returns None
    # instead of raising a KeyError.
    model_class, args = model_mapping.get(the_model, (None, None))
    return model_class(**args) if model_class else None


def get_client():
    """Return a raw OpenAI client configured with the saved API key and base URL."""
    the_api_key = load_api_key()
    the_openai_url = load_openai_url()
    if the_openai_url == "default":
        return OpenAI(api_key=the_api_key)
    else:
        return OpenAI(api_key=the_api_key, base_url=the_openai_url)
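

# Minimal usage sketch (added for illustration, not part of the original module):
# both factories read their configuration from utils.db, so a model name and the
# matching API keys must already have been saved there before this will work.
if __name__ == "__main__":
    llm = get_model(high_context=False)   # LangChain chat model chosen from saved settings
    client = get_client()                 # raw OpenAI client using the same key/base URL
    print(type(llm).__name__ if llm else "no model configured", type(client).__name__)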