Tuchuanhuhuhu commited on
Commit
8c04739
·
1 Parent(s): 4b9ef74

feat: Azure OpenAI API 支持 embedding

Browse files
config_example.json CHANGED
@@ -9,10 +9,13 @@
9
  "minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
10
 
11
  //== Azure ==
 
12
  "azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
13
- "azure_api_base_url": "", // 你的 Azure Base URL
14
  "azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
15
- "azure_deployment_name": "", // 你的 Azure DEPLOYMENT NAME
 
 
16
 
17
  //== 基础配置 ==
18
  "language": "auto", // 界面语言,可选"auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
 
9
  "minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
10
 
11
  //== Azure ==
12
+ "openai_api_type": "openai", // 可选项:azure, openai
13
  "azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
14
+ "azure_openai_api_base_url": "", // 你的 Azure Base URL
15
  "azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
16
+ "azure_deployment_name": "", // 你的 Azure OpenAI Chat 模型 Deployment 名称
17
+ "azure_embedding_deployment_name": "", // 你的 Azure OpenAI Embedding 模型 Deployment 名称
18
+ "azure_embedding_model_name": "text-embedding-ada-002", // 你的 Azure OpenAI Embedding 模型名称
19
 
20
  //== 基础配置 ==
21
  "language": "auto", // 界面语言,可选"auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
modules/config.py CHANGED
@@ -39,19 +39,22 @@ if os.path.exists("config.json"):
39
  else:
40
  config = {}
41
 
 
42
  def load_config_to_environ(key_list):
43
  global config
44
  for key in key_list:
45
  if key in config:
46
  os.environ[key.upper()] = os.environ.get(key.upper(), config[key])
47
 
 
48
  sensitive_id = config.get("sensitive_id", "")
49
  sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
50
 
51
  lang_config = config.get("language", "auto")
52
  language = os.environ.get("LANGUAGE", lang_config)
53
 
54
- hide_history_when_not_logged_in = config.get("hide_history_when_not_logged_in", False)
 
55
  check_update = config.get("check_update", True)
56
  show_api_billing = config.get("show_api_billing", False)
57
  show_api_billing = bool(os.environ.get("SHOW_API_BILLING", show_api_billing))
@@ -68,31 +71,32 @@ if os.path.exists("auth.json"):
68
  logging.info("检测到auth.json文件,正在进行迁移...")
69
  auth_list = []
70
  with open("auth.json", "r", encoding='utf-8') as f:
71
- auth = json.load(f)
72
- for _ in auth:
73
- if auth[_]["username"] and auth[_]["password"]:
74
- auth_list.append((auth[_]["username"], auth[_]["password"]))
75
- else:
76
- logging.error("请检查auth.json文件中的用户名和密码!")
77
- sys.exit(1)
78
  config["users"] = auth_list
79
  os.rename("auth.json", "auth(deprecated).json")
80
  with open("config.json", "w", encoding='utf-8') as f:
81
  json.dump(config, f, indent=4, ensure_ascii=False)
82
 
83
- ## 处理docker if we are running in Docker
84
  dockerflag = config.get("dockerflag", False)
85
  if os.environ.get("dockerrun") == "yes":
86
  dockerflag = True
87
 
88
- ## 处理 api-key 以及 允许的用户列表
89
  my_api_key = config.get("openai_api_key", "")
90
  my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
91
  os.environ["OPENAI_API_KEY"] = my_api_key
92
  os.environ["OPENAI_EMBEDDING_API_KEY"] = my_api_key
93
 
94
  google_palm_api_key = config.get("google_palm_api_key", "")
95
- google_palm_api_key = os.environ.get("GOOGLE_PALM_API_KEY", google_palm_api_key)
 
96
  os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key
97
 
98
  xmchat_api_key = config.get("xmchat_api_key", "")
@@ -103,13 +107,14 @@ os.environ["MINIMAX_API_KEY"] = minimax_api_key
103
  minimax_group_id = config.get("minimax_group_id", "")
104
  os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
105
 
106
- load_config_to_environ(["azure_openai_api_key", "azure_api_base_url", "azure_openai_api_version", "azure_deployment_name"])
 
107
 
108
 
109
  usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))
110
 
111
- ## 多账户机制
112
- multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
113
  if multi_api_key:
114
  api_key_list = config.get("api_key_list", [])
115
  if len(api_key_list) == 0:
@@ -117,23 +122,26 @@ if multi_api_key:
117
  sys.exit(1)
118
  shared.state.set_api_key_queue(api_key_list)
119
 
120
- auth_list = config.get("users", []) # 实际上是使用者的列表
121
  authflag = len(auth_list) > 0 # 是否开启认证的状态值,改为判断auth_list长度
122
 
123
  # 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
124
- api_host = os.environ.get("OPENAI_API_BASE", config.get("openai_api_base", None))
 
125
  if api_host is not None:
126
  shared.state.set_api_host(api_host)
127
  os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
128
  logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")
129
 
130
- default_chuanhu_assistant_model = config.get("default_chuanhu_assistant_model", "gpt-3.5-turbo")
 
131
  for x in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID", "SERPAPI_API_KEY"]:
132
  if config.get(x, None) is not None:
133
  os.environ[x] = config[x]
134
 
 
135
  @contextmanager
136
- def retrieve_openai_api(api_key = None):
137
  old_api_key = os.environ.get("OPENAI_API_KEY", "")
138
  if api_key is None:
139
  os.environ["OPENAI_API_KEY"] = my_api_key
@@ -143,14 +151,15 @@ def retrieve_openai_api(api_key = None):
143
  yield api_key
144
  os.environ["OPENAI_API_KEY"] = old_api_key
145
 
146
- ## 处理log
 
147
  log_level = config.get("log_level", "INFO")
148
  logging.basicConfig(
149
  level=log_level,
150
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
151
  )
152
 
153
- ## 处理代理:
154
  http_proxy = os.environ.get("HTTP_PROXY", "")
155
  https_proxy = os.environ.get("HTTPS_PROXY", "")
156
  http_proxy = config.get("http_proxy", http_proxy)
@@ -160,7 +169,8 @@ https_proxy = config.get("https_proxy", https_proxy)
160
  os.environ["HTTP_PROXY"] = ""
161
  os.environ["HTTPS_PROXY"] = ""
162
 
163
- local_embedding = config.get("local_embedding", False) # 是否使用本地embedding
 
164
 
165
  @contextmanager
166
  def retrieve_proxy(proxy=None):
@@ -177,12 +187,13 @@ def retrieve_proxy(proxy=None):
177
  old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
178
  os.environ["HTTP_PROXY"] = http_proxy
179
  os.environ["HTTPS_PROXY"] = https_proxy
180
- yield http_proxy, https_proxy # return new proxy
181
 
182
  # return old proxy
183
  os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
184
 
185
- ## 处理latex options
 
186
  user_latex_option = config.get("latex_option", "default")
187
  if user_latex_option == "default":
188
  latex_delimiters_set = [
@@ -219,16 +230,19 @@ else:
219
  {"left": "\\[", "right": "\\]", "display": True},
220
  ]
221
 
222
- ## 处理advance docs
223
  advance_docs = defaultdict(lambda: defaultdict(dict))
224
  advance_docs.update(config.get("advance_docs", {}))
 
 
225
  def update_doc_config(two_column_pdf):
226
  global advance_docs
227
  advance_docs["pdf"]["two_column"] = two_column_pdf
228
 
229
  logging.info(f"更新后的文件参数为:{advance_docs}")
230
 
231
- ## 处理gradio.launch参数
 
232
  server_name = config.get("server_name", None)
233
  server_port = config.get("server_port", None)
234
  if server_name is None:
 
39
  else:
40
  config = {}
41
 
42
+
43
  def load_config_to_environ(key_list):
44
  global config
45
  for key in key_list:
46
  if key in config:
47
  os.environ[key.upper()] = os.environ.get(key.upper(), config[key])
48
 
49
+
50
  sensitive_id = config.get("sensitive_id", "")
51
  sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
52
 
53
  lang_config = config.get("language", "auto")
54
  language = os.environ.get("LANGUAGE", lang_config)
55
 
56
+ hide_history_when_not_logged_in = config.get(
57
+ "hide_history_when_not_logged_in", False)
58
  check_update = config.get("check_update", True)
59
  show_api_billing = config.get("show_api_billing", False)
60
  show_api_billing = bool(os.environ.get("SHOW_API_BILLING", show_api_billing))
 
71
  logging.info("检测到auth.json文件,正在进行迁移...")
72
  auth_list = []
73
  with open("auth.json", "r", encoding='utf-8') as f:
74
+ auth = json.load(f)
75
+ for _ in auth:
76
+ if auth[_]["username"] and auth[_]["password"]:
77
+ auth_list.append((auth[_]["username"], auth[_]["password"]))
78
+ else:
79
+ logging.error("请检查auth.json文件中的用户名和密码!")
80
+ sys.exit(1)
81
  config["users"] = auth_list
82
  os.rename("auth.json", "auth(deprecated).json")
83
  with open("config.json", "w", encoding='utf-8') as f:
84
  json.dump(config, f, indent=4, ensure_ascii=False)
85
 
86
+ # 处理docker if we are running in Docker
87
  dockerflag = config.get("dockerflag", False)
88
  if os.environ.get("dockerrun") == "yes":
89
  dockerflag = True
90
 
91
+ # 处理 api-key 以及 允许的用户列表
92
  my_api_key = config.get("openai_api_key", "")
93
  my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
94
  os.environ["OPENAI_API_KEY"] = my_api_key
95
  os.environ["OPENAI_EMBEDDING_API_KEY"] = my_api_key
96
 
97
  google_palm_api_key = config.get("google_palm_api_key", "")
98
+ google_palm_api_key = os.environ.get(
99
+ "GOOGLE_PALM_API_KEY", google_palm_api_key)
100
  os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key
101
 
102
  xmchat_api_key = config.get("xmchat_api_key", "")
 
107
  minimax_group_id = config.get("minimax_group_id", "")
108
  os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
109
 
110
+ load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
111
+ "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
112
 
113
 
114
  usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))
115
 
116
+ # 多账户机制
117
+ multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
118
  if multi_api_key:
119
  api_key_list = config.get("api_key_list", [])
120
  if len(api_key_list) == 0:
 
122
  sys.exit(1)
123
  shared.state.set_api_key_queue(api_key_list)
124
 
125
+ auth_list = config.get("users", []) # 实际上是使用者的列表
126
  authflag = len(auth_list) > 0 # 是否开启认证的状态值,改为判断auth_list长度
127
 
128
  # 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
129
+ api_host = os.environ.get(
130
+ "OPENAI_API_BASE", config.get("openai_api_base", None))
131
  if api_host is not None:
132
  shared.state.set_api_host(api_host)
133
  os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
134
  logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")
135
 
136
+ default_chuanhu_assistant_model = config.get(
137
+ "default_chuanhu_assistant_model", "gpt-3.5-turbo")
138
  for x in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID", "SERPAPI_API_KEY"]:
139
  if config.get(x, None) is not None:
140
  os.environ[x] = config[x]
141
 
142
+
143
  @contextmanager
144
+ def retrieve_openai_api(api_key=None):
145
  old_api_key = os.environ.get("OPENAI_API_KEY", "")
146
  if api_key is None:
147
  os.environ["OPENAI_API_KEY"] = my_api_key
 
151
  yield api_key
152
  os.environ["OPENAI_API_KEY"] = old_api_key
153
 
154
+
155
+ # 处理log
156
  log_level = config.get("log_level", "INFO")
157
  logging.basicConfig(
158
  level=log_level,
159
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
160
  )
161
 
162
+ # 处理代理:
163
  http_proxy = os.environ.get("HTTP_PROXY", "")
164
  https_proxy = os.environ.get("HTTPS_PROXY", "")
165
  http_proxy = config.get("http_proxy", http_proxy)
 
169
  os.environ["HTTP_PROXY"] = ""
170
  os.environ["HTTPS_PROXY"] = ""
171
 
172
+ local_embedding = config.get("local_embedding", False) # 是否使用本地embedding
173
+
174
 
175
  @contextmanager
176
  def retrieve_proxy(proxy=None):
 
187
  old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
188
  os.environ["HTTP_PROXY"] = http_proxy
189
  os.environ["HTTPS_PROXY"] = https_proxy
190
+ yield http_proxy, https_proxy # return new proxy
191
 
192
  # return old proxy
193
  os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
194
 
195
+
196
+ # 处理latex options
197
  user_latex_option = config.get("latex_option", "default")
198
  if user_latex_option == "default":
199
  latex_delimiters_set = [
 
230
  {"left": "\\[", "right": "\\]", "display": True},
231
  ]
232
 
233
+ # 处理advance docs
234
  advance_docs = defaultdict(lambda: defaultdict(dict))
235
  advance_docs.update(config.get("advance_docs", {}))
236
+
237
+
238
  def update_doc_config(two_column_pdf):
239
  global advance_docs
240
  advance_docs["pdf"]["two_column"] = two_column_pdf
241
 
242
  logging.info(f"更新后的文件参数为:{advance_docs}")
243
 
244
+
245
+ # 处理gradio.launch参数
246
  server_name = config.get("server_name", None)
247
  server_port = config.get("server_port", None)
248
  if server_name is None:
modules/index_func.py CHANGED
@@ -51,7 +51,8 @@ def get_documents(file_src):
51
  pdfReader = PyPDF2.PdfReader(pdfFileObj)
52
  for page in tqdm(pdfReader.pages):
53
  pdftext += page.extract_text()
54
- texts = [Document(page_content=pdftext, metadata={"source": filepath})]
 
55
  elif file_type == ".docx":
56
  logging.debug("Loading Word...")
57
  from langchain.document_loaders import UnstructuredWordDocumentLoader
@@ -72,7 +73,8 @@ def get_documents(file_src):
72
  text_list = excel_to_string(filepath)
73
  texts = []
74
  for elem in text_list:
75
- texts.append(Document(page_content=elem, metadata={"source": filepath}))
 
76
  else:
77
  logging.debug("Loading text file...")
78
  from langchain.document_loaders import TextLoader
@@ -115,10 +117,16 @@ def construct_index(
115
  index_path = f"./index/{index_name}"
116
  if local_embedding:
117
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
118
- embeddings = HuggingFaceEmbeddings(model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2")
 
119
  else:
120
  from langchain.embeddings import OpenAIEmbeddings
121
- embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get("OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
 
 
 
 
 
122
  if os.path.exists(index_path):
123
  logging.info("找到了缓存的索引文件,加载中……")
124
  return FAISS.load_local(index_path, embeddings)
 
51
  pdfReader = PyPDF2.PdfReader(pdfFileObj)
52
  for page in tqdm(pdfReader.pages):
53
  pdftext += page.extract_text()
54
+ texts = [Document(page_content=pdftext,
55
+ metadata={"source": filepath})]
56
  elif file_type == ".docx":
57
  logging.debug("Loading Word...")
58
  from langchain.document_loaders import UnstructuredWordDocumentLoader
 
73
  text_list = excel_to_string(filepath)
74
  texts = []
75
  for elem in text_list:
76
+ texts.append(Document(page_content=elem,
77
+ metadata={"source": filepath}))
78
  else:
79
  logging.debug("Loading text file...")
80
  from langchain.document_loaders import TextLoader
 
117
  index_path = f"./index/{index_name}"
118
  if local_embedding:
119
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
120
+ embeddings = HuggingFaceEmbeddings(
121
+ model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
122
  else:
123
  from langchain.embeddings import OpenAIEmbeddings
124
+ if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
125
+ embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
126
+ "OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
127
+ else:
128
+ embeddings = OpenAIEmbeddings(deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"], openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
129
+ model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
130
  if os.path.exists(index_path):
131
  logging.info("找到了缓存的索引文件,加载中……")
132
  return FAISS.load_local(index_path, embeddings)
modules/models/azure.py CHANGED
@@ -9,7 +9,7 @@ class Azure_OpenAI_Client(Base_Chat_Langchain_Client):
9
  def setup_model(self):
10
  # inplement this to setup the model then return it
11
  return AzureChatOpenAI(
12
- openai_api_base=os.environ["AZURE_API_BASE_URL"],
13
  openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
14
  deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
15
  openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
 
9
  def setup_model(self):
10
  # inplement this to setup the model then return it
11
  return AzureChatOpenAI(
12
+ openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
13
  openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
14
  deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
15
  openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],