Tuchuanhuhuhu committed
Commit · e888600
Parent(s): 921af92

feat: Added support for custom models
Files changed:
- modules/config.py (+15 -2)
- modules/models/Claude.py (+1 -1)
- modules/models/DALLE3.py (+9 -6)
- modules/models/GoogleGemini.py (+1 -2)
- modules/models/GooglePaLM.py (+2 -3)
- modules/models/Groq.py (+2 -2)
- modules/models/OpenAI.py (+0 -280)
- modules/models/OpenAIInstruct.py (+1 -2)
- modules/models/OpenAIVision.py (+35 -12)
- modules/models/XMChat.py (+2 -0)
- modules/models/base_model.py (+48 -50)
- modules/models/models.py (+5 -15)
- modules/presets.py (+41 -6)
- modules/shared.py (+18 -11)
- modules/train_func.py (+6 -0)
modules/config.py
CHANGED
@@ -5,9 +5,11 @@ import logging
 import sys
 import commentjson as json
 import colorama
+from collections import defaultdict

 from . import shared
 from . import presets
+from .presets import i18n


 __all__ = [
@@ -100,14 +102,25 @@ else:
 sensitive_id = config.get("sensitive_id", "")
 sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)

+if "extra_model_metadata" in config:
+    presets.MODEL_METADATA.update(config["extra_model_metadata"])
+    logging.info(i18n("已添加 {extra_model_quantity} 个额外的模型元数据").format(extra_model_quantity=len(config["extra_model_metadata"])))
+
+_model_metadata = {}
+for k, v in presets.MODEL_METADATA.items():
+    temp_dict = presets.DEFAULT_METADATA.copy()
+    temp_dict.update(v)
+    _model_metadata[k] = temp_dict
+presets.MODEL_METADATA = _model_metadata
+
 if "available_models" in config:
     presets.MODELS = config["available_models"]
-    logging.info(
+    logging.info(i18n("已设置可用模型:{available_models}").format(available_models=config["available_models"]))

 # 模型配置
 if "extra_models" in config:
     presets.MODELS.extend(config["extra_models"])
-    logging.info(
+    logging.info(i18n("已添加额外的模型:{extra_models}").format(extra_models=config["extra_models"]))

 HIDE_MY_KEY = config.get("hide_my_key", False)

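Note: the net effect of the new block in modules/config.py is that every entry in presets.MODEL_METADATA gets back-filled with presets.DEFAULT_METADATA, so downstream code can index any key without try/except. A minimal runnable sketch of that back-fill (the "my-custom-gpt" entry is a hypothetical example, not part of this commit):

    # Sketch of the default back-fill performed in modules/config.py.
    # "my-custom-gpt" stands in for a user-supplied extra_model_metadata entry.
    DEFAULT_METADATA = {"model_name": None, "api_host": None, "token_limit": 4096, "multimodal": False}
    MODEL_METADATA = {
        "my-custom-gpt": {"model_name": "gpt-4o", "api_host": "https://example.com", "token_limit": 128000},
    }

    _model_metadata = {}
    for k, v in MODEL_METADATA.items():
        temp_dict = DEFAULT_METADATA.copy()  # start from the defaults
        temp_dict.update(v)                  # overlay the per-model values
        _model_metadata[k] = temp_dict
    MODEL_METADATA = _model_metadata

    print(MODEL_METADATA["my-custom-gpt"]["multimodal"])  # -> False, filled in from the defaults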
modules/models/Claude.py
CHANGED
@@ -11,7 +11,7 @@ class Claude_Client(BaseLLMModel):
         self.api_secret = api_secret
         if None in [self.api_secret]:
             raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
-        self.claude_client = Anthropic(api_key=self.api_secret)
+        self.claude_client = Anthropic(api_key=self.api_secret, base_url=self.api_host)

     def _get_claude_style_history(self):
         history = []
modules/models/DALLE3.py
CHANGED
@@ -7,8 +7,11 @@ from ..config import retrieve_proxy, sensitive_id

 class OpenAI_DALLE3_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
+        if self.api_host is not None:
+            self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
+        else:
+            self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
         self._refresh_header()

     def _get_dalle3_prompt(self):
@@ -24,7 +27,7 @@ class OpenAI_DALLE3_Client(BaseLLMModel):
             "Authorization": f"Bearer {self.api_key}"
         }
         payload = {
-            "model":
+            "model": self.model_name,
             "prompt": prompt,
             "n": 1,
             "size": "1024x1024",
@@ -35,13 +38,13 @@ class OpenAI_DALLE3_Client(BaseLLMModel):
         else:
             timeout = TIMEOUT_ALL

-        if
-            logging.debug(f"使用自定义API URL: {
+        if self.images_completion_url != IMAGES_COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {self.images_completion_url}")

         with retrieve_proxy():
             try:
                 response = requests.post(
-
+                    self.images_completion_url,
                     headers=headers,
                     json=payload,
                     stream=stream,
modules/models/GoogleGemini.py
CHANGED
@@ -17,8 +17,7 @@ from .base_model import BaseLLMModel

 class GoogleGeminiClient(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})
         if "vision" in model_name.lower():
             self.multimodal = True
         else:
modules/models/GooglePaLM.py
CHANGED
@@ -4,8 +4,7 @@ import google.generativeai as palm

 class Google_PaLM_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})

     def _get_palm_style_input(self):
         new_history = []
@@ -20,7 +19,7 @@ class Google_PaLM_Client(BaseLLMModel):
         palm.configure(api_key=self.api_key)
         messages = self._get_palm_style_input()
         response = palm.chat(context=self.system_prompt, messages=messages,
-                             temperature=self.temperature, top_p=self.top_p)
+                             temperature=self.temperature, top_p=self.top_p, model=self.model_name)
         if response.last is not None:
             return response.last, len(response.last)
         else:
modules/models/Groq.py
CHANGED
@@ -18,10 +18,10 @@ from .base_model import BaseLLMModel

 class Groq_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, api_key=api_key)
         self.client = Groq(
             api_key=os.environ.get("GROQ_API_KEY"),
+            base_url=self.api_host,
         )

     def _get_groq_style_input(self):
modules/models/OpenAI.py
DELETED
@@ -1,280 +0,0 @@
-from __future__ import annotations
-
-import json
-import logging
-import traceback
-
-import colorama
-import requests
-
-from .. import shared
-from ..config import retrieve_proxy, sensitive_id, usage_limit
-from ..index_func import *
-from ..presets import *
-from ..utils import *
-from .base_model import BaseLLMModel
-
-
-class OpenAIClient(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        api_key,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
-        user_name=""
-    ) -> None:
-        super().__init__(
-            model_name=model_name,
-            temperature=temperature,
-            top_p=top_p,
-            system_prompt=system_prompt,
-            user=user_name
-        )
-        self.api_key = api_key
-        self.need_api_key = True
-        self._refresh_header()
-
-    def get_answer_stream_iter(self):
-        if not self.api_key:
-            raise Exception(NO_APIKEY_MSG)
-        response = self._get_response(stream=True)
-        if response is not None:
-            iter = self._decode_chat_response(response)
-            partial_text = ""
-            for i in iter:
-                partial_text += i
-                yield partial_text
-        else:
-            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
-
-    def get_answer_at_once(self):
-        if not self.api_key:
-            raise Exception(NO_APIKEY_MSG)
-        response = self._get_response()
-        response = json.loads(response.text)
-        content = response["choices"][0]["message"]["content"]
-        total_token_count = response["usage"]["total_tokens"]
-        return content, total_token_count
-
-    def count_token(self, user_input):
-        input_token_count = count_token(construct_user(user_input))
-        if self.system_prompt is not None and len(self.all_token_counts) == 0:
-            system_prompt_token_count = count_token(
-                construct_system(self.system_prompt)
-            )
-            return input_token_count + system_prompt_token_count
-        return input_token_count
-
-    def billing_info(self):
-        try:
-            curr_time = datetime.datetime.now()
-            last_day_of_month = get_last_day_of_month(
-                curr_time).strftime("%Y-%m-%d")
-            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
-            try:
-                usage_data = self._get_billing_data(usage_url)
-            except Exception as e:
-                # logging.error(f"获取API使用情况失败: " + str(e))
-                if "Invalid authorization header" in str(e):
-                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
-                elif "Incorrect API key provided: sess" in str(e):
-                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
-                return i18n("**获取API使用情况失败**")
-            # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
-            rounded_usage = round(usage_data["total_usage"] / 100, 5)
-            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
-            from ..webui import get_html
-
-            # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
-            return get_html("billing_info.html").format(
-                label = i18n("本月使用金额"),
-                usage_percent = usage_percent,
-                rounded_usage = rounded_usage,
-                usage_limit = usage_limit
-            )
-        except requests.exceptions.ConnectTimeout:
-            status_text = (
-                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            )
-            return status_text
-        except requests.exceptions.ReadTimeout:
-            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
-            return status_text
-        except Exception as e:
-            import traceback
-            traceback.print_exc()
-            logging.error(i18n("获取API使用情况失败:") + str(e))
-            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
-
-    @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
-    def _get_response(self, stream=False):
-        openai_api_key = self.api_key
-        system_prompt = self.system_prompt
-        history = self.history
-        logging.debug(colorama.Fore.YELLOW +
-                      f"{history}" + colorama.Fore.RESET)
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {openai_api_key}",
-        }
-
-        if system_prompt is not None:
-            history = [construct_system(system_prompt), *history]
-
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "n": self.n_choices,
-            "stream": stream,
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-        }
-
-        if self.max_generation_token is not None:
-            payload["max_tokens"] = self.max_generation_token
-        if self.stop_sequence is not None:
-            payload["stop"] = self.stop_sequence
-        if self.logit_bias is not None:
-            payload["logit_bias"] = self.encoded_logit_bias()
-        if self.user_identifier:
-            payload["user"] = self.user_identifier
-
-        if stream:
-            timeout = TIMEOUT_STREAMING
-        else:
-            timeout = TIMEOUT_ALL
-
-        # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
-        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
-
-        with retrieve_proxy():
-            try:
-                response = requests.post(
-                    shared.state.chat_completion_url,
-                    headers=headers,
-                    json=payload,
-                    stream=stream,
-                    timeout=timeout,
-                )
-            except:
-                traceback.print_exc()
-                return None
-        return response
-
-    def _refresh_header(self):
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {sensitive_id}",
-        }
-
-
-    def _get_billing_data(self, billing_url):
-        with retrieve_proxy():
-            response = requests.get(
-                billing_url,
-                headers=self.headers,
-                timeout=TIMEOUT_ALL,
-            )
-
-        if response.status_code == 200:
-            data = response.json()
-            return data
-        else:
-            raise Exception(
-                f"API request failed with status code {response.status_code}: {response.text}"
-            )
-
-    def _decode_chat_response(self, response):
-        error_msg = ""
-        for chunk in response.iter_lines():
-            if chunk:
-                chunk = chunk.decode()
-                chunk_length = len(chunk)
-                try:
-                    chunk = json.loads(chunk[6:])
-                except:
-                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
-                    error_msg += chunk
-                    continue
-                try:
-                    if chunk_length > 6 and "delta" in chunk["choices"][0]:
-                        if "finish_reason" in chunk["choices"][0]:
-                            finish_reason = chunk["choices"][0]["finish_reason"]
-                        else:
-                            finish_reason = chunk["finish_reason"]
-                        if finish_reason == "stop":
-                            break
-                        try:
-                            yield chunk["choices"][0]["delta"]["content"]
-                        except Exception as e:
-                            # logging.error(f"Error: {e}")
-                            continue
-                except:
-                    print(f"ERROR: {chunk}")
-                    continue
-        if error_msg and not error_msg=="data: [DONE]":
-            raise Exception(error_msg)
-
-    def set_key(self, new_access_key):
-        ret = super().set_key(new_access_key)
-        self._refresh_header()
-        return ret
-
-    def _single_query_at_once(self, history, temperature=1.0):
-        timeout = TIMEOUT_ALL
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "temperature": f"{temperature}",
-        }
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-        }
-        # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
-        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
-            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
-
-        with retrieve_proxy():
-            response = requests.post(
-                shared.state.chat_completion_url,
-                headers=headers,
-                json=payload,
-                stream=False,
-                timeout=timeout,
-            )
-
-        return response
-
-
-    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, single_turn_checkbox):
-        if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
-            user_question = self.history[0]["content"]
-            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
-                ai_answer = self.history[1]["content"]
-                try:
-                    history = [
-                        { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
-                        { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
-                    ]
-                    response = self._single_query_at_once(history, temperature=0.0)
-                    response = json.loads(response.text)
-                    content = response["choices"][0]["message"]["content"]
-                    filename = replace_special_symbols(content) + ".json"
-                except Exception as e:
-                    logging.info(f"自动命名失败。{e}")
-                    filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot)
-            elif name_chat_method == i18n("第一条提问"):
-                filename = replace_special_symbols(user_question)[:16] + ".json"
-                return self.rename_chat_history(filename, chatbot)
-            else:
-                return gr.update()
-        else:
-            return gr.update()
modules/models/OpenAIInstruct.py
CHANGED
@@ -8,8 +8,7 @@ from ..config import retrieve_proxy

 class OpenAI_Instruct_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
-        super().__init__(model_name=model_name, user=user_name)
-        self.api_key = api_key
+        super().__init__(model_name=model_name, user=user_name, config={"api_key": api_key})

     def _get_instruct_style_input(self):
         return "\n\n".join([item["content"] for item in self.history])
modules/models/OpenAIVision.py
CHANGED
@@ -27,22 +27,19 @@ class OpenAIVisionClient(BaseLLMModel):
         self,
         model_name,
         api_key,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
         user_name=""
     ) -> None:
         super().__init__(
             model_name=model_name,
-
-
-
-
+            user=user_name,
+            config={
+                "api_key": api_key
+            }
         )
-        self.
-
-
-
+        if self.api_host is not None:
+            self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.format_openai_host(self.api_host)
+        else:
+            self.api_host, self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = shared.state.api_host, shared.state.chat_completion_url, shared.state.images_completion_url, shared.state.openai_api_base, shared.state.balance_api_url, shared.state.usage_api_url
         self._refresh_header()

     def get_answer_stream_iter(self):
@@ -176,7 +173,7 @@ class OpenAIVisionClient(BaseLLMModel):
             "stream": stream,
             "presence_penalty": self.presence_penalty,
             "frequency_penalty": self.frequency_penalty,
-            "max_tokens":
+            "max_tokens": self.max_generation_token
         }

         if self.stop_sequence:
@@ -296,3 +293,29 @@ class OpenAIVisionClient(BaseLLMModel):
         )

         return response
+
+    def auto_name_chat_history(self, name_chat_method, user_question, chatbot, single_turn_checkbox):
+        if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
+            user_question = self.history[0]["content"]
+            if name_chat_method == i18n("模型自动总结(消耗tokens)"):
+                ai_answer = self.history[1]["content"]
+                try:
+                    history = [
+                        { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
+                        { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
+                    ]
+                    response = self._single_query_at_once(history, temperature=0.0)
+                    response = json.loads(response.text)
+                    content = response["choices"][0]["message"]["content"]
+                    filename = replace_special_symbols(content) + ".json"
+                except Exception as e:
+                    logging.info(f"自动命名失败。{e}")
+                    filename = replace_special_symbols(user_question)[:16] + ".json"
+                return self.rename_chat_history(filename, chatbot)
+            elif name_chat_method == i18n("第一条提问"):
+                filename = replace_special_symbols(user_question)[:16] + ".json"
+                return self.rename_chat_history(filename, chatbot)
+            else:
+                return gr.update()
+        else:
+            return gr.update()
modules/models/XMChat.py
CHANGED
@@ -26,6 +26,8 @@ class XMChat(BaseLLMModel):
         self.image_path = None
         self.xm_history = []
         self.url = "https://xmbot.net/web"
+        if self.api_host is not None:
+            self.url = self.api_host
         self.last_conv_id = None

     def reset(self, remain_system_prompt=False):
modules/models/base_model.py
CHANGED
@@ -159,6 +159,14 @@ class ModelType(Enum):

     @classmethod
     def get_type(cls, model_name: str):
+        # 1. get model type from model metadata (if exists)
+        model_type = MODEL_METADATA[model_name]["model_type"]
+        if model_type is not None:
+            for member in cls:
+                if member.name == model_type:
+                    return member
+
+        # 2. infer model type from model name
         model_type = None
         model_name_lower = model_name.lower()
         if "gpt" in model_name_lower:
@@ -249,66 +257,56 @@ class BaseLLMModel:
     def __init__(
         self,
         model_name,
-        system_prompt=INITIAL_SYSTEM_PROMPT,
-        temperature=1.0,
-        top_p=1.0,
-        n_choices=1,
-        stop=[],
-        max_generation_token=None,
-        presence_penalty=0,
-        frequency_penalty=0,
-        logit_bias=None,
         user="",
-
+        config=None,
     ) -> None:
+
+        if config is not None:
+            temp = MODEL_METADATA[model_name].copy()
+            keys_with_diff_values = {key: temp[key] for key in temp if key in DEFAULT_METADATA and temp[key] != DEFAULT_METADATA[key]}
+            config.update(keys_with_diff_values)
+            temp.update(config)
+            config = temp
+        else:
+            config = MODEL_METADATA[model_name]
+
+        self.model_name = config["model_name"]
+        self.multimodal = config["multimodal"]
+        self.description = config["description"]
+        self.token_upper_limit = config["token_limit"]
+        self.system_prompt = config["system"]
+        self.api_key = config["api_key"]
+        self.api_host = config["api_host"]
+
+        self.interrupted = False
+        self.need_api_key = self.api_key is not None
         self.history = []
         self.all_token_counts = []
         self.model_type = ModelType.get_type(model_name)
-        try:
-            self.model_name = MODEL_METADATA[model_name]["model_name"]
-        except:
-            self.model_name = model_name
-        try:
-            self.multimodal = MODEL_METADATA[model_name]["multimodal"]
-        except:
-            self.multimodal = False
-        if max_generation_token is None:
-            try:
-                max_generation_token = MODEL_METADATA[model_name]["max_generation"]
-            except:
-                pass
-        try:
-            self.token_upper_limit = MODEL_METADATA[model_name]["token_limit"]
-        except KeyError:
-            self.token_upper_limit = DEFAULT_TOKEN_LIMIT
-        self.interrupted = False
-        self.system_prompt = system_prompt
-        self.api_key = None
-        self.need_api_key = False
         self.history_file_path = get_first_history_name(user)
         self.user_name = user
         self.chatbot = []

-        self.default_single_turn = single_turn
-        self.default_temperature = temperature
-        self.default_top_p = top_p
-        self.default_n_choices = n_choices
-        self.default_stop_sequence = stop
-        self.default_max_generation_token =
-        self.default_presence_penalty = presence_penalty
-        self.default_frequency_penalty = frequency_penalty
-        self.default_logit_bias = logit_bias
+        self.default_single_turn = config["single_turn"]
+        self.default_temperature = config["temperature"]
+        self.default_top_p = config["top_p"]
+        self.default_n_choices = config["n_choices"]
+        self.default_stop_sequence = config["stop"]
+        self.default_max_generation_token = config["max_generation"]
+        self.default_presence_penalty = config["presence_penalty"]
+        self.default_frequency_penalty = config["frequency_penalty"]
+        self.default_logit_bias = config["logit_bias"]
         self.default_user_identifier = user

-        self.single_turn =
-        self.temperature =
-        self.top_p =
-        self.n_choices =
-        self.stop_sequence =
-        self.max_generation_token =
-        self.presence_penalty =
-        self.frequency_penalty =
-        self.logit_bias =
+        self.single_turn = self.default_single_turn
+        self.temperature = self.default_temperature
+        self.top_p = self.default_top_p
+        self.n_choices = self.default_n_choices
+        self.stop_sequence = self.default_stop_sequence
+        self.max_generation_token = self.default_max_generation_token
+        self.presence_penalty = self.default_presence_penalty
+        self.frequency_penalty = self.default_frequency_penalty
+        self.logit_bias = self.default_logit_bias
         self.user_identifier = user

         self.metadata = {}
@@ -1073,7 +1071,7 @@ class BaseLLMModel:
         self.reset()
         return (
             os.path.basename(self.history_file_path),
-
+            self.system_prompt,
             [],
             self.single_turn,
             self.temperature,
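Note: in the new BaseLLMModel.__init__, a caller-supplied config normally wins over the model's metadata, but any metadata value that was explicitly changed away from DEFAULT_METADATA is copied back over the caller's config first, so curated per-model settings take priority. A runnable sketch of that resolution order (model name and values are hypothetical):

    # Resolution order of the config merge in BaseLLMModel.__init__ (values hypothetical).
    DEFAULT_METADATA = {"temperature": 1.0, "api_key": None}
    # Entry already back-filled with defaults, as modules/config.py guarantees:
    MODEL_METADATA = {"demo-model": {"temperature": 0.3, "api_key": None}}

    def resolve_config(model_name, config=None):
        if config is not None:
            temp = MODEL_METADATA[model_name].copy()
            # metadata keys that differ from the defaults override the caller
            keys_with_diff_values = {key: temp[key] for key in temp
                                     if key in DEFAULT_METADATA and temp[key] != DEFAULT_METADATA[key]}
            config.update(keys_with_diff_values)
            temp.update(config)
            config = temp
        else:
            config = MODEL_METADATA[model_name]
        return config

    print(resolve_config("demo-model", {"temperature": 0.9, "api_key": "sk-demo"}))
    # -> {'temperature': 0.3, 'api_key': 'sk-demo'}: the curated 0.3 beats the caller's 0.9,
    #    while api_key (left at its default in the metadata) comes from the caller.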
modules/models/models.py
CHANGED
@@ -35,28 +35,18 @@ def get_model(
     model = original_model
     chatbot = gr.Chatbot(label=model_name)
     try:
-        if model_type == ModelType.OpenAI:
-            logging.info(f"正在加载OpenAI模型: {model_name}")
-            from .
+        if model_type == ModelType.OpenAIVision or model_type == ModelType.OpenAI:
+            logging.info(f"正在加载 OpenAI 模型: {model_name}")
+            from .OpenAIVision import OpenAIVisionClient
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
-            model =
-                model_name=
-                api_key=access_key,
-                system_prompt=system_prompt,
-                user_name=user_name,
-            )
+            model = OpenAIVisionClient(
+                model_name, api_key=access_key, user_name=user_name)
         elif model_type == ModelType.OpenAIInstruct:
             logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
             from .OpenAIInstruct import OpenAI_Instruct_Client
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
             model = OpenAI_Instruct_Client(
                 model_name, api_key=access_key, user_name=user_name)
-        elif model_type == ModelType.OpenAIVision:
-            logging.info(f"正在加载OpenAI Vision模型: {model_name}")
-            from .OpenAIVision import OpenAIVisionClient
-            access_key = os.environ.get("OPENAI_API_KEY", access_key)
-            model = OpenAIVisionClient(
-                model_name, api_key=access_key, user_name=user_name)
         elif model_type == ModelType.ChatGLM:
             logging.info(f"正在加载ChatGLM模型: {model_name}")
             from .ChatGLM import ChatGLM_Client
modules/presets.py
CHANGED
@@ -110,6 +110,29 @@ LOCAL_MODELS = [
     "Qwen 14B"
 ]

+DEFAULT_METADATA = {
+    "repo_id": None,  # HuggingFace repo id, used if this model is meant to be downloaded from HuggingFace then run locally
+    "model_name": None,  # api model name, used if this model is meant to be used online
+    "filelist": None,  # file list in the repo to download, now only support .gguf file
+    "description": None,  # description of the model, displayed in the chat area when no message is present
+    "model_type": None,  # model type, used to determine the model's behavior. If not set, the model type is inferred from the model name
+    "multimodal": False,  # whether the model is multimodal
+    "api_host": None,  # base url for the model's api
+    "api_key": None,  # api key for the model's api
+    "system": INITIAL_SYSTEM_PROMPT,  # system prompt for the model
+    "token_limit": 4096,  # context window size
+    "single_turn": False,  # whether the model is single turn
+    "temperature": 1.0,
+    "top_p": 1.0,
+    "n_choices": 1,
+    "stop": [],
+    "max_generation": None,  # maximum token limit for a single generation
+    "presence_penalty": 0.0,
+    "frequency_penalty": 0.0,
+    "logit_bias": None,
+    "metadata": {}  # additional metadata for the model
+}
+
 # Additional metadata for online and local models
 MODEL_METADATA = {
     "Llama-2-7B":{
@@ -166,11 +189,8 @@ MODEL_METADATA = {
     "GPT4 Vision": {
         "model_name": "gpt-4-turbo",
         "token_limit": 128000,
-        "multimodal": True
-
-    "Claude": {
-        "model_name": "Claude",
-        "token_limit": 4096,
+        "multimodal": True,
+        "max_generation": 4096,
     },
     "Claude 3 Haiku": {
         "model_name": "claude-3-haiku-20240307",
@@ -190,6 +210,9 @@ MODEL_METADATA = {
         "max_generation": 4096,
         "multimodal": True
     },
+    "川虎助理": {"model_name": "川虎助理"},
+    "川虎助理 Pro": {"model_name": "川虎助理 Pro"},
+    "DALL-E 3": {"model_name": "dall-e-3"},
     "ERNIE-Bot-turbo": {
         "model_name": "ERNIE-Bot-turbo",
         "token_limit": 1024,
@@ -243,7 +266,19 @@ MODEL_METADATA = {
     "Groq Gemma 7B": {
         "model_name": "gemma-7b-it",
         "token_limit": 8192,
-    }
+    },
+    "GooglePaLM": {"model_name": "models/chat-bison-001"},
+    "xmchat": {"model_name": "xmchat"},
+    "Azure OpenAI": {"model_name": "azure-openai"},
+    "yuanai-1.0-base_10B": {"model_name": "yuanai-1.0-base_10B"},
+    "yuanai-1.0-translate": {"model_name": "yuanai-1.0-translate"},
+    "yuanai-1.0-dialog": {"model_name": "yuanai-1.0-dialog"},
+    "yuanai-1.0-rhythm_poems": {"model_name": "yuanai-1.0-rhythm_poems"},
+    "minimax-abab5-chat": {"model_name": "minimax-abab5-chat"},
+    "midjourney": {"model_name": "midjourney"},
+    "讯飞星火大模型V3.0": {"model_name": "讯飞星火大模型V3.0"},
+    "讯飞星火大模型V2.0": {"model_name": "讯飞星火大模型V2.0"},
+    "讯飞星火大模型V1.5": {"model_name": "讯飞星火大模型V1.5"},
 }

 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
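Note: with DEFAULT_METADATA in place, a custom entry only needs to state what differs from the defaults. A hypothetical entry exercising the new keys (not one of the models added by this commit), in the shape config.py accepts under "extra_model_metadata":

    # Hypothetical custom model entry; any key not given falls back to DEFAULT_METADATA.
    extra_model_metadata = {
        "my-proxy-gpt": {
            "model_name": "gpt-4o-mini",          # api model name sent to the backend
            "model_type": "OpenAIVision",         # skips name-based type inference
            "api_host": "https://my-proxy.example.com",
            "token_limit": 128000,
            "multimodal": True,
        }
    }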
modules/shared.py
CHANGED
@@ -3,6 +3,19 @@ import os
 import queue
 import openai

+def format_openai_host(api_host: str):
+    api_host = api_host.rstrip("/")
+    if not api_host.startswith("http"):
+        api_host = f"https://{api_host}"
+    if api_host.endswith("/v1"):
+        api_host = api_host[:-3]
+    chat_completion_url = f"{api_host}/v1/chat/completions"
+    images_completion_url = f"{api_host}/v1/images/generations"
+    openai_api_base = f"{api_host}/v1"
+    balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
+    usage_api_url = f"{api_host}/dashboard/billing/usage"
+    return chat_completion_url, images_completion_url, openai_api_base, balance_api_url, usage_api_url
+
 class State:
     interrupted = False
     multi_api_key = False
@@ -11,6 +24,7 @@ class State:
     usage_api_url = USAGE_API_URL
     openai_api_base = OPENAI_API_BASE
     images_completion_url = IMAGES_COMPLETION_URL
+    api_host = API_HOST

     def interrupt(self):
         self.interrupted = True
@@ -19,23 +33,16 @@ class State:
         self.interrupted = False

     def set_api_host(self, api_host: str):
-        api_host = api_host
-
-
-        if api_host.endswith("/v1"):
-            api_host = api_host[:-3]
-        self.chat_completion_url = f"{api_host}/v1/chat/completions"
-        self.images_completion_url = f"{api_host}/v1/images/generations"
-        self.openai_api_base = f"{api_host}/v1"
-        self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
-        self.usage_api_url = f"{api_host}/dashboard/billing/usage"
-        os.environ["OPENAI_API_BASE"] = api_host + "/v1"
+        self.api_host = api_host
+        self.chat_completion_url, self.images_completion_url, self.openai_api_base, self.balance_api_url, self.usage_api_url = format_openai_host(api_host)
+        os.environ["OPENAI_API_BASE"] = self.openai_api_base

     def reset_api_host(self):
         self.chat_completion_url = CHAT_COMPLETION_URL
         self.images_completion_url = IMAGES_COMPLETION_URL
         self.balance_api_url = BALANCE_API_URL
         self.usage_api_url = USAGE_API_URL
+        self.api_host = API_HOST
         os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
         return API_HOST

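Note: format_openai_host centralizes the URL normalization that State.set_api_host previously did inline, so per-model api_host values (see DALLE3.py and OpenAIVision.py above) go through the same path. An illustration of its behavior (the host below is a hypothetical example):

    from modules.shared import format_openai_host

    # A bare host gains an https:// scheme; a trailing "/" and "/v1" are stripped.
    urls = format_openai_host("api.example.com/v1/")
    print(urls[0])  # https://api.example.com/v1/chat/completions
    print(urls[2])  # https://api.example.com/v1  (openai_api_base)
    print(urls[4])  # https://api.example.com/dashboard/billing/usage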
modules/train_func.py
CHANGED
@@ -144,6 +144,12 @@ def add_to_models():
             data['extra_models'].append(i)
     else:
         data['extra_models'] = extra_models
+    if 'extra_model_metadata' in data:
+        for i in extra_models:
+            if i not in data['extra_model_metadata']:
+                data['extra_model_metadata'][i] = {"model_name": i, "model_type": "OpenAIVision"}
+    else:
+        data['extra_model_metadata'] = {i: {"model_name": i, "model_type": "OpenAIVision"} for i in extra_models}
     with open('config.json', 'w') as f:
         commentjson.dump(data, f, indent=4)
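Note: for newly fine-tuned models, add_to_models now also records an extra_model_metadata block, which modules/config.py merges into presets.MODEL_METADATA on the next start. A sketch of the resulting config.json content for a hypothetical fine-tune id (shown as the Python dict that commentjson.dump serializes):

    # Hypothetical result; the ft: model id is made up for illustration.
    data = {
        "extra_models": ["ft:gpt-3.5-turbo:my-org::demo123"],
        "extra_model_metadata": {
            "ft:gpt-3.5-turbo:my-org::demo123": {
                "model_name": "ft:gpt-3.5-turbo:my-org::demo123",
                "model_type": "OpenAIVision",
            }
        },
    }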