"""Load application settings from ``config.json`` and environment variables.

Everything here runs at import time: the config file is read, legacy
``api_key.txt`` / ``auth.json`` files are migrated into ``config.json``,
several values are exported to ``os.environ``, and the resulting settings are
exposed as module-level names (see ``__all__``).
"""
from collections import defaultdict
from contextlib import contextmanager
import os
import logging
import sys

import commentjson as json

from . import shared
from . import presets

__all__ = [
    "my_api_key",
    "sensitive_id",
    "authflag",
    "auth_list",
    "dockerflag",
    "retrieve_proxy",
    "advance_docs",
    "update_doc_config",
    "usage_limit",
    "multi_api_key",
    "server_name",
    "server_port",
    "share",
    "check_update",
    "latex_delimiters_set",
    "hide_history_when_not_logged_in",
    "default_chuanhu_assistant_model",
    "show_api_billing"
]

# Strings that mean "off" when a boolean flag is supplied as text
# (environment variables are always strings).
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})


def _parse_flag(value):
    """Interpret *value* as a boolean flag.

    Strings are compared case-insensitively against ``_FALSE_STRINGS`` so that
    ``"false"`` or ``"0"`` turn a flag off; any other value falls back to
    ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


# Use one unified config file so users are not confused by many small files
# (it has the lowest priority). It also gives later custom features a single
# place to read their settings from.
if os.path.exists("config.json"):
    with open("config.json", "r", encoding='utf-8') as f:
        config = json.load(f)
else:
    config = {}


def load_config_to_environ(key_list):
    """Export selected config entries as upper-cased environment variables.

    For every key in *key_list* that is present in ``config``, the variable
    ``key.upper()`` is set to the config value unless the environment already
    defines it (an existing environment value wins).
    """
    for key in key_list:
        if key in config:
            os.environ.setdefault(key.upper(), config[key])


lang_config = config.get("language", "auto")
language = os.environ.get("LANGUAGE", lang_config)

hide_history_when_not_logged_in = config.get(
    "hide_history_when_not_logged_in", False)
check_update = config.get("check_update", True)
# The environment variable overrides the config value. It is parsed as a flag
# because bool("false") would otherwise be True.
show_api_billing = _parse_flag(
    os.environ.get("SHOW_API_BILLING", config.get("show_api_billing", False)))

# Migrate a legacy api_key.txt into config.json, then rename the old file.
if os.path.exists("api_key.txt"):
    logging.info("检测到api_key.txt文件,正在进行迁移...")
    with open("api_key.txt", "r", encoding="utf-8") as f:
        config["openai_api_key"] = f.read().strip()
    os.rename("api_key.txt", "api_key(deprecated).txt")
    with open("config.json", "w", encoding='utf-8') as f:
        json.dump(config, f, indent=4, ensure_ascii=False)

# Migrate a legacy auth.json into config["users"], then rename the old file.
if os.path.exists("auth.json"):
    logging.info("检测到auth.json文件,正在进行迁移...")
    auth_list = []
    with open("auth.json", "r", encoding='utf-8') as f:
        auth = json.load(f)
    for entry in auth.values():
        # Every entry needs a non-empty username and password; otherwise abort.
        if entry["username"] and entry["password"]:
            auth_list.append((entry["username"], entry["password"]))
        else:
            logging.error("请检查auth.json文件中的用户名和密码!")
            sys.exit(1)
    config["users"] = auth_list
    os.rename("auth.json", "auth(deprecated).json")
    with open("config.json", "w", encoding='utf-8') as f:
        json.dump(config, f, indent=4, ensure_ascii=False)

# Docker handling: the flag comes from config, or from the "dockerrun"
# environment variable when it equals "yes".
dockerflag = config.get("dockerflag", False)
if os.environ.get("dockerrun") == "yes":
    dockerflag = True

# API key and the list of allowed users.
# There is deliberately no built-in fallback key: the credential must come from
# config.json or the OPENAI_API_KEY environment variable (the latter wins).
my_api_key = config.get("openai_api_key", "")
my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)

if config.get("legacy_api_usage", False):
    sensitive_id = my_api_key
else:
    sensitive_id = config.get("sensitive_id", "")
    sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)

# Model configuration.
if "extra_models" in config:
    presets.MODELS.extend(config["extra_models"])
    logging.info(f"已添加额外的模型:{config['extra_models']}")

google_palm_api_key = config.get("google_palm_api_key", "")
google_palm_api_key = os.environ.get(
    "GOOGLE_PALM_API_KEY", google_palm_api_key)
os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key

# The following values are always written to the environment from config
# (an empty string when the key is absent).
xmchat_api_key = config.get("xmchat_api_key", "")
os.environ["XMCHAT_API_KEY"] = xmchat_api_key

minimax_api_key = config.get("minimax_api_key", "")
os.environ["MINIMAX_API_KEY"] = minimax_api_key
minimax_group_id = config.get("minimax_group_id", "")
os.environ["MINIMAX_GROUP_ID"] = minimax_group_id

midjourney_proxy_api_base = config.get("midjourney_proxy_api_base", "")
os.environ["MIDJOURNEY_PROXY_API_BASE"] = midjourney_proxy_api_base
midjourney_proxy_api_secret = config.get("midjourney_proxy_api_secret", "")
os.environ["MIDJOURNEY_PROXY_API_SECRET"] = midjourney_proxy_api_secret
midjourney_discord_proxy_url = config.get("midjourney_discord_proxy_url", "")
os.environ["MIDJOURNEY_DISCORD_PROXY_URL"] = midjourney_discord_proxy_url
midjourney_temp_folder = config.get("midjourney_temp_folder", "")
os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder

load_config_to_environ(["openai_api_type", "azure_openai_api_key",
                        "azure_openai_api_base_url",
                        "azure_openai_api_version", "azure_deployment_name",
                        "azure_embedding_deployment_name",
                        "azure_embedding_model_name"])

# NOTE(review): this is a str when taken from the environment and whatever
# type config.json holds (default int 120) otherwise -- confirm consumers
# handle both.
usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))

# Multi-account mechanism.
multi_api_key = config.get("multi_api_key", False)  # whether it is enabled
if multi_api_key:
    api_key_list = config.get("api_key_list", [])
    if not api_key_list:
        logging.error("多账号模式已开启,但api_key_list为空,请检查config.json")
        sys.exit(1)
    shared.state.set_api_key_queue(api_key_list)

auth_list = config.get("users", [])  # actually the list of users
authflag = len(auth_list) > 0  # authentication is on when auth_list is non-empty

# Custom api_host: the environment variable takes precedence over config, and
# the host is applied automatically when one is present.
api_host = os.environ.get(
    "OPENAI_API_BASE", config.get("openai_api_base", None))
if api_host is not None:
    shared.state.set_api_host(api_host)
    os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
    logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")

default_chuanhu_assistant_model = config.get(
    "default_chuanhu_assistant_model", "gpt-3.5-turbo")
for env_name in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID",
                 "SERPAPI_API_KEY"]:
    if config.get(env_name, None) is not None:
        os.environ[env_name] = config[env_name]


@contextmanager
def retrieve_openai_api(api_key=None):
    """Temporarily set ``OPENAI_API_KEY`` in the environment.

    Args:
        api_key: Key to expose for the duration of the ``with`` block. When
            ``None``, the module-level ``my_api_key`` is used.

    Yields:
        The key that was placed in the environment.

    On exit -- including when the block raises -- ``OPENAI_API_KEY`` is set
    back to its previous value (an empty string if it was not set).
    """
    old_api_key = os.environ.get("OPENAI_API_KEY", "")
    active_key = my_api_key if api_key is None else api_key
    try:
        os.environ["OPENAI_API_KEY"] = active_key
        yield active_key
    finally:
        os.environ["OPENAI_API_KEY"] = old_api_key


# Proxy handling: start from the environment, let config.json override it.
http_proxy = os.environ.get("HTTP_PROXY", "")
https_proxy = os.environ.get("HTTPS_PROXY", "")
http_proxy = config.get("http_proxy", http_proxy)
https_proxy = config.get("https_proxy", https_proxy)

# Reset the system variables: do not keep proxy environment variables set when
# they are not needed, to avoid errors caused by a global proxy.
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""

local_embedding = config.get("local_embedding", False)  # use local embedding?


@contextmanager
def retrieve_proxy(proxy=None):
    """Apply the configured proxy, or replace the configured proxy.

    1. If ``proxy`` is ``None``: set the ``HTTP_PROXY`` / ``HTTPS_PROXY``
       environment variables from the module-level proxy settings for the
       duration of the ``with`` block, then restore their previous values
       (also when the block raises).
    2. If ``proxy`` is not ``None``: update the module-level proxy settings to
       ``proxy`` without touching the environment variables.

    Yields:
        A ``(http_proxy, https_proxy)`` tuple with the current settings.
    """
    global http_proxy, https_proxy
    if proxy is not None:
        http_proxy = proxy
        https_proxy = proxy
        yield http_proxy, https_proxy
    else:
        old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
        try:
            os.environ["HTTP_PROXY"] = http_proxy
            os.environ["HTTPS_PROXY"] = https_proxy
            yield http_proxy, https_proxy  # hand out the new proxy
        finally:
            # Put the old proxy back.
            os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var


# LaTeX options. "default" and any unrecognised value share the same set.
user_latex_option = config.get("latex_option", "default")
if user_latex_option == "strict":
    latex_delimiters_set = [
        {"left": "$$", "right": "$$", "display": True},
        {"left": "\\(", "right": "\\)", "display": False},
        {"left": "\\[", "right": "\\]", "display": True},
    ]
elif user_latex_option == "all":
    latex_delimiters_set = [
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
        {"left": "\\(", "right": "\\)", "display": False},
        {"left": "\\[", "right": "\\]", "display": True},
        {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
        {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
        {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
        {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
        {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    ]
elif user_latex_option == "disabled":
    latex_delimiters_set = []
else:
    latex_delimiters_set = [
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
        {"left": "\\(", "right": "\\)", "display": False},
        {"left": "\\[", "right": "\\]", "display": True},
    ]

# Advanced document options.
advance_docs = defaultdict(lambda: defaultdict(dict))
advance_docs.update(config.get("advance_docs", {}))


def update_doc_config(two_column_pdf):
    """Store the two-column PDF setting in ``advance_docs`` and log the result.

    Args:
        two_column_pdf: Value written to ``advance_docs["pdf"]["two_column"]``.
    """
    advance_docs["pdf"]["two_column"] = two_column_pdf

    logging.info(f"更新后的文件参数为:{advance_docs}")


# Arguments for gradio.launch.
server_name = config.get("server_name", None)
server_port = config.get("server_port", None)
if server_name is None:
    # Listen on all interfaces in Docker, on loopback otherwise.
    server_name = "0.0.0.0" if dockerflag else "127.0.0.1"
if server_port is None:
    if dockerflag:
        server_port = 7860

assert server_port is None or type(server_port) is int, "要求port设置为int类型"

# Default model: keep the preset default when the configured name is unknown.
default_model = config.get("default_model", "")
try:
    presets.DEFAULT_MODEL = presets.MODELS.index(default_model)
except ValueError:
    pass

share = config.get("share", False)

# Avatars.
bot_avatar = config.get("bot_avatar", "default")
user_avatar = config.get("user_avatar", "default")