Tuchuanhuhuhu committed
Commit e888600 · 1 Parent(s): 921af92

feat: Added support for custom models

modules/config.py CHANGED
@@ -5,9 +5,11 @@ import logging
 import sys
 import commentjson as json
 import colorama
+from collections import defaultdict

 from . import shared
 from . import presets
+from .presets import i18n


 __all__ = [
@@ -100,14 +102,25 @@ else:
 sensitive_id = config.get("sensitive_id", "")
 sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)

+if "extra_model_metadata" in config:
+    presets.MODEL_METADATA.update(config["extra_model_metadata"])
+    logging.info(i18n("已添加 {extra_model_quantity} 个额外的模型元数据").format(extra_model_quantity=len(config["extra_model_metadata"])))
+
+_model_metadata = {}
+for k, v in presets.MODEL_METADATA.items():
+    temp_dict = presets.DEFAULT_METADATA.copy()
+    temp_dict.update(v)
+    _model_metadata[k] = temp_dict
+presets.MODEL_METADATA = _model_metadata
+
 if "available_models" in config:
     presets.MODELS = config["available_models"]
-    logging.info(f"已设置可用模型:{config['available_models']}")
+    logging.info(i18n("已设置可用模型:{available_models}").format(available_models=config["available_models"]))

 # 模型配置
 if "extra_models" in config:
     presets.MODELS.extend(config["extra_models"])
-    logging.info(f"已添加额外的模型:{config['extra_models']}")
+    logging.info(i18n("已添加额外的模型:{extra_models}").format(extra_models=config["extra_models"]))

 HIDE_MY_KEY = config.get("hide_my_key", False)
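For illustration, here is a hypothetical shape of the `extra_model_metadata` entry this new block consumes, together with the default-filling pass it runs. The model key, host, and values are made-up placeholders, and the stand-in defaults dict is abridged; the recognized fields are those in `DEFAULT_METADATA` in modules/presets.py.

# Hypothetical config.json contents after commentjson parses them:
config = {
    "extra_model_metadata": {
        "my-proxy-gpt4": {                        # made-up custom model key
            "model_name": "gpt-4-turbo",          # name sent to the upstream API
            "model_type": "OpenAIVision",         # bypasses name-based type inference
            "api_host": "https://proxy.example.com",
            "token_limit": 128000,
        }
    }
}

# Abridged stand-in for presets.DEFAULT_METADATA:
DEFAULT_METADATA = {"model_name": None, "model_type": None, "api_host": None,
                    "api_key": None, "token_limit": 4096, "multimodal": False}

# The same merge as above: every entry is completed with the defaults.
entry = DEFAULT_METADATA.copy()
entry.update(config["extra_model_metadata"]["my-proxy-gpt4"])
assert entry["multimodal"] is False     # filled in by the defaults
assert entry["token_limit"] == 128000   # explicit user value is kept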
 
modules/models/Claude.py CHANGED
@@ -11,7 +11,7 @@ class Claude_Client(BaseLLMModel):
         self.api_secret = api_secret
         if None in [self.api_secret]:
             raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
-        self.claude_client = Anthropic(api_key=self.api_secret)
+        self.claude_client = Anthropic(api_key=self.api_secret, base_url=self.api_host)

     def _get_claude_style_history(self):
         history = []
modules/models/DALLE3.py CHANGED
@@ -7,8 +7,11 @@ from ..config import retrieve_proxy, sensitive_id

 class OpenAI_DALLE3_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
+        if self.api_host is not None:
+            self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
+        else:
+            self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
         self._refresh_header()

     def _get_dalle3_prompt(self):
@@ -24,7 +27,7 @@ class OpenAI_DALLE3_Client(BaseLLMModel):
             "Authorization": f"Bearer {self.api_key}"
         }
         payload = {
-            "model": "dall-e-3",
+            "model": self.model_name,
             "prompt": prompt,
             "n": 1,
             "size": "1024x1024",
@@ -35,13 +38,13 @@ class OpenAI_DALLE3_Client(BaseLLMModel):
         else:
             timeout = TIMEOUT_ALL

-        if shared.state.images_completion_url != IMAGES_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.images_completion_url}")
+        if self.images_completion_url != IMAGES_COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {self.images_completion_url}")

         with retrieve_proxy():
             try:
                 response = requests.post(
-                    shared.state.images_completion_url,
+                    self.images_completion_url,
                     headers=headers,
                     json=payload,
                     stream=stream,
modules/models/GoogleGemini.py CHANGED
@@ -17,8 +17,7 @@ from .base_model import BaseLLMModel

 class GoogleGeminiClient(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
         if "vision" in model_name.lower():
             self.multimodal = True
         else:
modules/models/GooglePaLM.py CHANGED
@@ -4,8 +4,7 @@ import google.generativeai as palm

 class Google_PaLM_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})

     def _get_palm_style_input(self):
         new_history = []
@@ -20,7 +19,7 @@ class Google_PaLM_Client(BaseLLMModel):
         palm.configure(api_key=self.api_key)
         messages = self._get_palm_style_input()
         response = palm.chat(context=self.system_prompt, messages=messages,
-                             temperature=self.temperature, top_p=self.top_p)
+                             temperature=self.temperature, top_p=self.top_p, model=self.model_name)
         if response.last is not None:
             return response.last, len(response.last)
         else:
modules/models/Groq.py CHANGED
@@ -18,10 +18,10 @@ from .base_model import BaseLLMModel

 class Groq_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
         self.client = Groq(
             api_key=os.environ.get("GROQ_API_KEY"),
+            base_url=self.api_host,
         )

     def _get_groq_style_input(self):
modules/models/OpenAI.py DELETED
@@ -1,280 +0,0 @@
-from __future__ import annotations
-
-import json
-import logging
-import traceback
-
-import colorama
-import requests
-
-from .. import shared
-from ..config import retrieve_proxy, sensitive_id, usage_limit
-from ..index_func import *
-from ..presets import *
-from ..utils import *
-from .base_model import BaseLLMModel
-
-
-class OpenAIClient(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        api_key,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
-        user_name=""
-    ) -> None:
-        super().__init__(
-            model_name=model_name,
-            temperature=temperature,
-            top_p=top_p,
-            system_prompt=system_prompt,
-            user=user_name
-        )
-        self.api_key = api_key
-        self.need_api_key = True
-        self._refresh_header()
-
-    def get_answer_stream_iter(self):
-        if not self.api_key:
-            raise Exception(NO_APIKEY_MSG)
-        response = self._get_response(stream=True)
-        if response is not None:
-            iter = self._decode_chat_response(response)
-            partial_text = ""
-            for i in iter:
-                partial_text += i
-                yield partial_text
-        else:
-            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
-
-    def get_answer_at_once(self):
-        if not self.api_key:
-            raise Exception(NO_APIKEY_MSG)
-        response = self._get_response()
-        response = json.loads(response.text)
-        content = response["choices"][0]["message"]["content"]
-        total_token_count = response["usage"]["total_tokens"]
-        return content, total_token_count
-
-    def count_token(self, user_input):
-        input_token_count = count_token(construct_user(user_input))
-        if self.system_prompt is not None and len(self.all_token_counts) == 0:
-            system_prompt_token_count = count_token(
-                construct_system(self.system_prompt)
-            )
-            return input_token_count + system_prompt_token_count
-        return input_token_count
-
-    def billing_info(self):
-        try:
-            curr_time = datetime.datetime.now()
-            last_day_of_month = get_last_day_of_month(
-                curr_time).strftime("%Y-%m-%d")
-            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
-            try:
-                usage_data = self._get_billing_data(usage_url)
-            except Exception as e:
-                # logging.error(f"获取API使用情况失败: " + str(e))
-                if "Invalid authorization header" in str(e):
-                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
-                elif "Incorrect API key provided: sess" in str(e):
-                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
-                return i18n("**获取API使用情况失败**")
-            # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
-            rounded_usage = round(usage_data["total_usage"] / 100, 5)
-            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
-            from ..webui import get_html
-
-            # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
-            return get_html("billing_info.html").format(
-                label = i18n("本月使用金额"),
-                usage_percent = usage_percent,
-                rounded_usage = rounded_usage,
-                usage_limit = usage_limit
-            )
-        except requests.exceptions.ConnectTimeout:
-            status_text = (
-                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            )
-            return status_text
-        except requests.exceptions.ReadTimeout:
-            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            return status_text
-        except Exception as e:
-            import traceback
-            traceback.print_exc()
-            logging.error(i18n("获取API使用情况失败:") + str(e))
-            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
-
-    @shared.state.switching_api_key  # 在不开启多账号模式的时候,这个装饰器不会起作用
-    def _get_response(self, stream=False):
-        openai_api_key = self.api_key
-        system_prompt = self.system_prompt
-        history = self.history
-        logging.debug(colorama.Fore.YELLOW +
-                      f"{history}" + colorama.Fore.RESET)
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {openai_api_key}",
-        }
-
-        if system_prompt is not None:
-            history = [construct_system(system_prompt), *history]
-
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "n": self.n_choices,
-            "stream": stream,
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-        }
-
-        if self.max_generation_token is not None:
-            payload["max_tokens"] = self.max_generation_token
-        if self.stop_sequence is not None:
-            payload["stop"] = self.stop_sequence
-        if self.logit_bias is not None:
-            payload["logit_bias"] = self.encoded_logit_bias()
-        if self.user_identifier:
-            payload["user"] = self.user_identifier
-
-        if stream:
-            timeout = TIMEOUT_STREAMING
-        else:
-            timeout = TIMEOUT_ALL
-
-        # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
-        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
-
-        with retrieve_proxy():
-            try:
-                response = requests.post(
-                    shared.state.chat_completion_url,
-                    headers=headers,
-                    json=payload,
-                    stream=stream,
-                    timeout=timeout,
-                )
-            except:
-                traceback.print_exc()
-                return None
-        return response
-
-    def _refresh_header(self):
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {sensitive_id}",
-        }
-
-
-    def _get_billing_data(self, billing_url):
-        with retrieve_proxy():
-            response = requests.get(
-                billing_url,
-                headers=self.headers,
-                timeout=TIMEOUT_ALL,
-            )
-
-        if response.status_code == 200:
-            data = response.json()
-            return data
-        else:
-            raise Exception(
-                f"API request failed with status code {response.status_code}: {response.text}"
-            )
-
-    def _decode_chat_response(self, response):
-        error_msg = ""
-        for chunk in response.iter_lines():
-            if chunk:
-                chunk = chunk.decode()
-                chunk_length = len(chunk)
-                try:
-                    chunk = json.loads(chunk[6:])
-                except:
-                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
-                    error_msg += chunk
-                    continue
-                try:
-                    if chunk_length > 6 and "delta" in chunk["choices"][0]:
-                        if "finish_reason" in chunk["choices"][0]:
-                            finish_reason = chunk["choices"][0]["finish_reason"]
-                        else:
-                            finish_reason = chunk["finish_reason"]
-                        if finish_reason == "stop":
-                            break
-                        try:
-                            yield chunk["choices"][0]["delta"]["content"]
-                        except Exception as e:
-                            # logging.error(f"Error: {e}")
-                            continue
-                except:
-                    print(f"ERROR: {chunk}")
-                    continue
-        if error_msg and not error_msg == "data: [DONE]":
-            raise Exception(error_msg)
-
-    def set_key(self, new_access_key):
-        ret = super().set_key(new_access_key)
-        self._refresh_header()
-        return ret
-
-    def _single_query_at_once(self, history, temperature=1.0):
-        timeout = TIMEOUT_ALL
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "temperature": f"{temperature}",
-        }
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-        }
-        # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
-        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
-
-        with retrieve_proxy():
-            response = requests.post(
-                shared.state.chat_completion_url,
-                headers=headers,
-                json=payload,
-                stream=False,
-                timeout=timeout,
-            )
-
-        return response
-
-
-    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, single_turn_checkbox):
-        if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
-            user_question = self.history[0]["content"]
-            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
-                ai_answer = self.history[1]["content"]
-                try:
-                    history = [
-                        { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
-                        { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
-                    ]
-                    response = self._single_query_at_once(history, temperature=0.0)
-                    response = json.loads(response.text)
-                    content = response["choices"][0]["message"]["content"]
-                    filename = replace_special_symbols(content) + ".json"
-                except Exception as e:
-                    logging.info(f"自动命名失败。{e}")
-                    filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot)
-            elif name_chat_method == i18n("第一条提问"):
-                filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot)
-            else:
-                return gr.update()
-        else:
-            return gr.update()
 
modules/models/OpenAIInstruct.py CHANGED
@@ -8,8 +8,7 @@ from ..config import retrieve_proxy

 class OpenAI_Instruct_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})

     def _get_instruct_style_input(self):
         return "\n\n".join([item["content"] for item in self.history])
modules/models/OpenAIVision.py CHANGED
@@ -27,22 +27,19 @@ class OpenAIVisionClient(BaseLLMModel):
         self,
         model_name,
         api_key,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
         user_name=""
     ) -> None:
         super().__init__(
             model_name=model_name,
-            temperature=temperature,
-            top_p=top_p,
-            system_prompt=system_prompt,
-            user=user_name
+            user=user_name,
+            config={
+                "api_key": api_key
+            }
         )
-        self.image_token = 0
-        self.api_key = api_key
-        self.need_api_key = True
-        self.max_generation_token = 4096
+        if self.api_host is not None:
+            self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
+        else:
+            self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
         self._refresh_header()

     def get_answer_stream_iter(self):
@@ -176,7 +173,7 @@ class OpenAIVisionClient(BaseLLMModel):
             "stream": stream,
             "presence_penalty": self.presence_penalty,
             "frequency_penalty": self.frequency_penalty,
-            "max_tokens": 4096
+            "max_tokens": self.max_generation_token
         }

         if self.stop_sequence:
@@ -296,3 +293,29 @@ class OpenAIVisionClient(BaseLLMModel):
             )

         return response
+
+    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, single_turn_checkbox):
+        if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
+            user_question = self.history[0]["content"]
+            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
+                ai_answer = self.history[1]["content"]
+                try:
+                    history = [
+                        { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
+                        { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
+                    ]
+                    response = self._single_query_at_once(history, temperature=0.0)
+                    response = json.loads(response.text)
+                    content = response["choices"][0]["message"]["content"]
+                    filename = replace_special_symbols(content) + ".json"
+                except Exception as e:
+                    logging.info(f"自动命名失败。{e}")
+                    filename = replace_special_symbols(user_question)[:16] + ".json"
+                return self.rename_chat_history(filename, chatbot)
+            elif name_chat_method == i18n("第一条提问"):
+                filename = replace_special_symbols(user_question)[:16] + ".json"
+                return self.rename_chat_history(filename, chatbot)
+            else:
+                return gr.update()
+        else:
+            return gr.update()
modules/models/XMChat.py CHANGED
@@ -26,6 +26,8 @@ class XMChat(BaseLLMModel):
         self.image_path = None
         self.xm_history = []
         self.url = "https://xmbot.net/web"
+        if self.api_host is not None:
+            self.url = self.api_host
         self.last_conv_id = None

     def reset(self, remain_system_prompt=False):
modules/models/base_model.py CHANGED
@@ -159,6 +159,14 @@ class ModelType(Enum):

     @classmethod
     def get_type(cls, model_name: str):
+        # 1. get model type from model metadata (if exists)
+        model_type = MODEL_METADATA[model_name]["model_type"]
+        if model_type is not None:
+            for member in cls:
+                if member.name == model_type:
+                    return member
+
+        # 2. infer model type from model name
         model_type = None
         model_name_lower = model_name.lower()
         if "gpt" in model_name_lower:
@@ -249,66 +257,56 @@ class BaseLLMModel:
     def __init__(
         self,
         model_name,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
-        n_choices=1,
-        stop=[],
-        max_generation_token=None,
-        presence_penalty=0,
-        frequency_penalty=0,
-        logit_bias=None,
         user="",
-        single_turn=False,
+        config=None,
     ) -> None:
+
+        if config is not None:
+            temp = MODEL_METADATA[model_name].copy()
+            keys_with_diff_values = {key: temp[key] for key in temp if key in DEFAULT_METADATA and temp[key] != DEFAULT_METADATA[key]}
+            config.update(keys_with_diff_values)
+            temp.update(config)
+            config = temp
+        else:
+            config = MODEL_METADATA[model_name]
+
+        self.model_name = config["model_name"]
+        self.multimodal = config["multimodal"]
+        self.description = config["description"]
+        self.token_upper_limit = config["token_limit"]
+        self.system_prompt = config["system"]
+        self.api_key = config["api_key"]
+        self.api_host = config["api_host"]
+
+        self.interrupted = False
+        self.need_api_key = self.api_key is not None
         self.history = []
         self.all_token_counts = []
         self.model_type = ModelType.get_type(model_name)
-        try:
-            self.model_name = MODEL_METADATA[model_name]["model_name"]
-        except:
-            self.model_name = model_name
-        try:
-            self.multimodal = MODEL_METADATA[model_name]["multimodal"]
-        except:
-            self.multimodal = False
-        if max_generation_token is None:
-            try:
-                max_generation_token = MODEL_METADATA[model_name]["max_generation"]
-            except:
-                pass
-        try:
-            self.token_upper_limit = MODEL_METADATA[model_name]["token_limit"]
-        except KeyError:
-            self.token_upper_limit = DEFAULT_TOKEN_LIMIT
-        self.interrupted = False
-        self.system_prompt = system_prompt
-        self.api_key = None
-        self.need_api_key = False
         self.history_file_path = get_first_history_name(user)
         self.user_name = user
         self.chatbot = []

-        self.default_single_turn = single_turn
-        self.default_temperature = temperature
-        self.default_top_p = top_p
-        self.default_n_choices = n_choices
-        self.default_stop_sequence = stop
-        self.default_max_generation_token = max_generation_token
-        self.default_presence_penalty = presence_penalty
-        self.default_frequency_penalty = frequency_penalty
-        self.default_logit_bias = logit_bias
+        self.default_single_turn = config["single_turn"]
+        self.default_temperature = config["temperature"]
+        self.default_top_p = config["top_p"]
+        self.default_n_choices = config["n_choices"]
+        self.default_stop_sequence = config["stop"]
+        self.default_max_generation_token = config["max_generation"]
+        self.default_presence_penalty = config["presence_penalty"]
+        self.default_frequency_penalty = config["frequency_penalty"]
+        self.default_logit_bias = config["logit_bias"]
         self.default_user_identifier = user

-        self.single_turn = single_turn
-        self.temperature = temperature
-        self.top_p = top_p
-        self.n_choices = n_choices
-        self.stop_sequence = stop
-        self.max_generation_token = max_generation_token
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.logit_bias = logit_bias
+        self.single_turn = self.default_single_turn
+        self.temperature = self.default_temperature
+        self.top_p = self.default_top_p
+        self.n_choices = self.default_n_choices
+        self.stop_sequence = self.default_stop_sequence
+        self.max_generation_token = self.default_max_generation_token
+        self.presence_penalty = self.default_presence_penalty
+        self.frequency_penalty = self.default_frequency_penalty
+        self.logit_bias = self.default_logit_bias
         self.user_identifier = user

         self.metadata = {}
@@ -1073,7 +1071,7 @@ class BaseLLMModel:
         self.reset()
         return (
             os.path.basename(self.history_file_path),
-            "",
+            self.system_prompt,
             [],
             self.single_turn,
             self.temperature,
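A standalone sketch of the precedence the merge above produces, using abridged stand-ins for the tables in modules/presets.py (by this point config.py has already completed every MODEL_METADATA entry with the defaults): a value the model's metadata explicitly customizes beats the caller's config, and the caller's config beats an untouched default.

# Abridged stand-ins; the real tables have many more keys.
DEFAULT_METADATA = {"api_key": None, "token_limit": 4096, "temperature": 1.0}
MODEL_METADATA = {
    # Entries are assumed already default-completed, as modules/config.py does at startup.
    "GPT4 Vision": {"api_key": None, "token_limit": 128000, "temperature": 1.0},
}

def resolve(model_name, config=None):
    # Mirrors the merge in BaseLLMModel.__init__ above.
    if config is not None:
        temp = MODEL_METADATA[model_name].copy()
        keys_with_diff_values = {key: temp[key] for key in temp
                                 if key in DEFAULT_METADATA and temp[key] != DEFAULT_METADATA[key]}
        config.update(keys_with_diff_values)  # model-specific overrides beat the caller's config
        temp.update(config)                   # the caller's config beats untouched defaults
        return temp
    return MODEL_METADATA[model_name]

resolved = resolve("GPT4 Vision", config={"api_key": "sk-example", "token_limit": 8000})
assert resolved["api_key"] == "sk-example"  # caller wins: the model never customized api_key
assert resolved["token_limit"] == 128000    # model metadata wins: 128000 differs from the 4096 default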
modules/models/models.py CHANGED
@@ -35,28 +35,18 @@ def get_model(
     model = original_model
     chatbot = gr.Chatbot(label=model_name)
     try:
-        if model_type == ModelType.OpenAI:
-            logging.info(f"正在加载OpenAI模型: {model_name}")
-            from .OpenAI import OpenAIClient
+        if model_type == ModelType.OpenAIVision or model_type == ModelType.OpenAI:
+            logging.info(f"正在加载 OpenAI 模型: {model_name}")
+            from .OpenAIVision import OpenAIVisionClient
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
-            model = OpenAIClient(
-                model_name=model_name,
-                api_key=access_key,
-                system_prompt=system_prompt,
-                user_name=user_name,
-            )
+            model = OpenAIVisionClient(
+                model_name, api_key=access_key, user_name=user_name)
         elif model_type == ModelType.OpenAIInstruct:
             logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
             from .OpenAIInstruct import OpenAI_Instruct_Client
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
             model = OpenAI_Instruct_Client(
                 model_name, api_key=access_key, user_name=user_name)
-        elif model_type == ModelType.OpenAIVision:
-            logging.info(f"正在加载OpenAI Vision模型: {model_name}")
-            from .OpenAIVision import OpenAIVisionClient
-            access_key = os.environ.get("OPENAI_API_KEY", access_key)
-            model = OpenAIVisionClient(
-                model_name, api_key=access_key, user_name=user_name)
         elif model_type == ModelType.ChatGLM:
             logging.info(f"正在加载ChatGLM模型: {model_name}")
             from .ChatGLM import ChatGLM_Client
modules/presets.py CHANGED
@@ -110,6 +110,29 @@ LOCAL_MODELS = [
     "Qwen 14B"
 ]

+DEFAULT_METADATA = {
+    "repo_id": None,        # HuggingFace repo id, used if this model is meant to be downloaded from HuggingFace then run locally
+    "model_name": None,     # api model name, used if this model is meant to be used online
+    "filelist": None,       # file list in the repo to download, now only support .gguf file
+    "description": None,    # description of the model, displayed in the chat area when no message is present
+    "model_type": None,     # model type, used to determine the model's behavior. If not set, the model type is inferred from the model name
+    "multimodal": False,    # whether the model is multimodal
+    "api_host": None,       # base url for the model's api
+    "api_key": None,        # api key for the model's api
+    "system": INITIAL_SYSTEM_PROMPT,  # system prompt for the model
+    "token_limit": 4096,    # context window size
+    "single_turn": False,   # whether the model is single turn
+    "temperature": 1.0,
+    "top_p": 1.0,
+    "n_choices": 1,
+    "stop": [],
+    "max_generation": None, # maximum token limit for a single generation
+    "presence_penalty": 0.0,
+    "frequency_penalty": 0.0,
+    "logit_bias": None,
+    "metadata": {}          # additional metadata for the model
+}
+
 # Additional metadata for online and local models
 MODEL_METADATA = {
     "Llama-2-7B":{
@@ -166,11 +189,8 @@ MODEL_METADATA = {
     "GPT4 Vision": {
         "model_name": "gpt-4-turbo",
         "token_limit": 128000,
-        "multimodal": True
-    },
-    "Claude": {
-        "model_name": "Claude",
-        "token_limit": 4096,
+        "multimodal": True,
+        "max_generation": 4096,
     },
     "Claude 3 Haiku": {
         "model_name": "claude-3-haiku-20240307",
@@ -190,6 +210,9 @@ MODEL_METADATA = {
         "max_generation": 4096,
         "multimodal": True
     },
+    "川虎助理": {"model_name": "川虎助理"},
+    "川虎助理 Pro": {"model_name": "川虎助理 Pro"},
+    "DALL-E 3": {"model_name": "dall-e-3"},
     "ERNIE-Bot-turbo": {
         "model_name": "ERNIE-Bot-turbo",
         "token_limit": 1024,
@@ -243,7 +266,19 @@ MODEL_METADATA = {
     "Groq Gemma 7B": {
         "model_name": "gemma-7b-it",
         "token_limit": 8192,
-    }
+    },
+    "GooglePaLM": {"model_name": "models/chat-bison-001"},
+    "xmchat": {"model_name": "xmchat"},
+    "Azure OpenAI": {"model_name": "azure-openai"},
+    "yuanai-1.0-base_10B": {"model_name": "yuanai-1.0-base_10B"},
+    "yuanai-1.0-translate": {"model_name": "yuanai-1.0-translate"},
+    "yuanai-1.0-dialog": {"model_name": "yuanai-1.0-dialog"},
+    "yuanai-1.0-rhythm_poems": {"model_name": "yuanai-1.0-rhythm_poems"},
+    "minimax-abab5-chat": {"model_name": "minimax-abab5-chat"},
+    "midjourney": {"model_name": "midjourney"},
+    "讯飞星火大模型V3.0": {"model_name": "讯飞星火大模型V3.0"},
+    "讯飞星火大模型V2.0": {"model_name": "讯飞星火大模型V2.0"},
+    "讯飞星火大模型V1.5": {"model_name": "讯飞星火大模型V1.5"},
 }

 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
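As a rough sketch of what the new "model_type" field buys (an abridged stand-in, not the project's full ModelType enum): an explicit "model_type" lets an arbitrary custom name like the entries above map straight to a client class, while unlisted names still fall back to name-based inference.

from enum import Enum

class ModelType(Enum):      # abridged stand-in for the enum in base_model.py
    OpenAI = 1
    OpenAIVision = 2
    Unknown = 99

MODEL_METADATA = {
    "my-company-llm": {"model_name": "my-company-llm", "model_type": "OpenAIVision"},  # made-up entry
}

def get_type(model_name: str) -> ModelType:
    # Step 1, as in ModelType.get_type above: trust explicit metadata first.
    model_type = MODEL_METADATA.get(model_name, {}).get("model_type")
    if model_type is not None:
        for member in ModelType:
            if member.name == model_type:
                return member
    # Step 2: fall back to name-based inference (only the "gpt" rule shown here).
    return ModelType.OpenAI if "gpt" in model_name.lower() else ModelType.Unknown

assert get_type("my-company-llm") is ModelType.OpenAIVision  # explicit metadata wins
assert get_type("gpt-3.5-turbo") is ModelType.OpenAI         # inferred from the name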
modules/shared.py CHANGED
@@ -3,6 +3,19 @@ import os
 import queue
 import openai

+def format_openai_host(api_host: str):
+    api_host = api_host.rstrip("/")
+    if not api_host.startswith("http"):
+        api_host = f"https://{api_host}"
+    if api_host.endswith("/v1"):
+        api_host = api_host[:-3]
+    chat_completion_url = f"{api_host}/v1/chat/completions"
+    images_completion_url = f"{api_host}/v1/images/generations"
+    openai_api_base = f"{api_host}/v1"
+    balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
+    usage_api_url = f"{api_host}/dashboard/billing/usage"
+    return chat_completion_url, images_completion_url, openai_api_base, balance_api_url, usage_api_url
+
 class State:
     interrupted = False
     multi_api_key = False
@@ -11,6 +24,7 @@ class State:
     usage_api_url = USAGE_API_URL
     openai_api_base = OPENAI_API_BASE
     images_completion_url = IMAGES_COMPLETION_URL
+    api_host = API_HOST

     def interrupt(self):
         self.interrupted = True
@@ -19,23 +33,16 @@
         self.interrupted = False

     def set_api_host(self, api_host: str):
-        api_host = api_host.rstrip("/")
-        if not api_host.startswith("http"):
-            api_host = f"https://{api_host}"
-        if api_host.endswith("/v1"):
-            api_host = api_host[:-3]
-        self.chat_completion_url = f"{api_host}/v1/chat/completions"
-        self.images_completion_url = f"{api_host}/v1/images/generations"
-        self.openai_api_base = f"{api_host}/v1"
-        self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
-        self.usage_api_url = f"{api_host}/dashboard/billing/usage"
-        os.environ["OPENAI_API_BASE"] = api_host + "/v1"
+        self.api_host = api_host
+        self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = format_openai_host(api_host)
+        os.environ["OPENAI_API_BASE"] = self.openai_api_base

     def reset_api_host(self):
         self.chat_completion_url = CHAT_COMPLETION_URL
         self.images_completion_url = IMAGES_COMPLETION_URL
         self.balance_api_url = BALANCE_API_URL
         self.usage_api_url = USAGE_API_URL
+        self.api_host = API_HOST
         os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
         return API_HOST
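A quick sanity check of the normalization the new helper performs, with a made-up host (the function body is copied from the hunk above so the snippet runs standalone):

def format_openai_host(api_host: str):
    # Copied from the hunk above for a standalone check.
    api_host = api_host.rstrip("/")
    if not api_host.startswith("http"):
        api_host = f"https://{api_host}"
    if api_host.endswith("/v1"):
        api_host = api_host[:-3]
    chat_completion_url = f"{api_host}/v1/chat/completions"
    images_completion_url = f"{api_host}/v1/images/generations"
    openai_api_base = f"{api_host}/v1"
    balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
    usage_api_url = f"{api_host}/dashboard/billing/usage"
    return chat_completion_url, images_completion_url, openai_api_base, balance_api_url, usage_api_url

# Made-up host; the trailing slash and /v1 suffix are both normalized away.
chat_url, images_url, api_base, balance_url, usage_url = format_openai_host("proxy.example.com/v1/")
assert chat_url == "https://proxy.example.com/v1/chat/completions"
assert api_base == "https://proxy.example.com/v1"
assert usage_url == "https://proxy.example.com/dashboard/billing/usage"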
 
modules/train_func.py CHANGED
@@ -144,6 +144,12 @@ def add_to_models():
                 data['extra_models'].append(i)
     else:
         data['extra_models'] = extra_models
+    if 'extra_model_metadata' in data:
+        for i in extra_models:
+            if i not in data['extra_model_metadata']:
+                data['extra_model_metadata'][i] = {"model_name": i, "model_type": "OpenAIVision"}
+    else:
+        data['extra_model_metadata'] = {i: {"model_name": i, "model_type": "OpenAIVision"} for i in extra_models}
    with open('config.json', 'w') as f:
        commentjson.dump(data, f, indent=4)
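Roughly, after add_to_models() runs for a fine-tuned model, config.json ends up with matching entries in both lists. The hypothetical dict below sketches what commentjson.dump would write; the fine-tuned model id is a made-up placeholder, and "OpenAIVision" is what routes the model to OpenAIVisionClient on load.

# Hypothetical post-run contents of config.json, shown as a Python dict:
expected = {
    "extra_models": ["ft:gpt-3.5-turbo:example-org::abc123"],  # made-up id
    "extra_model_metadata": {
        "ft:gpt-3.5-turbo:example-org::abc123": {
            "model_name": "ft:gpt-3.5-turbo:example-org::abc123",
            "model_type": "OpenAIVision",
        }
    },
}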