Spaces:
Runtime error
Runtime error
from openai import OpenAI | |
from langchain_openai import ChatOpenAI | |
from langchain_community.chat_models import ChatOllama | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_groq import ChatGroq | |
try: | |
from .utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key | |
from .custom_callback import customcallback | |
except ImportError: | |
from utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key | |
from custom_callback import customcallback | |
the_callback = customcallback(strip_tokens=False, answer_prefix_tokens=["Answer"]) | |
def get_model(high_context=False): | |
the_model = load_model_settings() | |
the_api_key = load_api_key() | |
the_groq_api_key = load_groq_api_key() | |
the_google_api_key = load_google_api_key() | |
the_openai_url = load_openai_url() | |
def open_ai_base(high_context): | |
if the_openai_url == "default": | |
true_model = the_model | |
if high_context: | |
true_model = "gpt-4-turbo" | |
return {"model": true_model, "api_key": the_api_key, "max_retries":15, "streaming":True, "callbacks":[the_callback]} | |
else: | |
return {"model": the_model, "api_key": the_api_key, "max_retries":15, "streaming":True, "callbacks":[the_callback], "base_url": the_openai_url} | |
args_mapping = { | |
ChatOpenAI: open_ai_base(high_context=high_context), | |
ChatOllama: {"model": the_model}, | |
ChatGroq: {"temperature": 0, "model_name": the_model.replace("-groq", ""), "groq_api_key": the_openai_url}, | |
ChatGoogleGenerativeAI:{"model": the_model, "google_api_key": the_google_api_key} | |
} | |
model_mapping = { | |
# OpenAI | |
"gpt-4o": (ChatOpenAI, args_mapping[ChatOpenAI]), | |
"gpt-4-turbo": (ChatOpenAI, args_mapping[ChatOpenAI]), | |
"gpt-3.5": (ChatOpenAI, args_mapping[ChatOpenAI]), | |
"gpt-3.5-turbo": (ChatOpenAI, args_mapping[ChatOpenAI]), | |
# Google Generative AI - Llama | |
"llava": (ChatOllama, args_mapping[ChatOllama]), | |
"llama3": (ChatOllama, args_mapping[ChatOllama]), | |
"bakllava": (ChatOllama, args_mapping[ChatOllama]), | |
# Google Generative AI - Gemini | |
"gemini-pro": (ChatGoogleGenerativeAI, args_mapping[ChatGoogleGenerativeAI]), | |
# Groq | |
"mixtral-8x7b-groq": (ChatGroq, args_mapping[ChatGroq]) | |
} | |
model_class, args = model_mapping[the_model] | |
return model_class(**args) if model_class else None | |
def get_client(): | |
the_api_key = load_api_key() | |
the_openai_url = load_openai_url() | |
if the_openai_url == "default": | |
return OpenAI(api_key=the_api_key) | |
else: | |
return OpenAI(api_key=the_api_key, base_url=the_openai_url) | |