Tuchuanhuhuhu commited on
Commit
e99bd71
·
1 Parent(s): fc2938f

feat: 加入gpt-3.5-turbo-instruct模型支持

Browse files
modules/models/OpenAI.py CHANGED
@@ -149,13 +149,13 @@ class OpenAIClient(BaseLLMModel):
149
  timeout = TIMEOUT_ALL
150
 
151
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
152
- if shared.state.completion_url != COMPLETION_URL:
153
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
154
 
155
  with retrieve_proxy():
156
  try:
157
  response = requests.post(
158
- shared.state.completion_url,
159
  headers=headers,
160
  json=payload,
161
  stream=stream,
@@ -237,12 +237,12 @@ class OpenAIClient(BaseLLMModel):
237
  "messages": history,
238
  }
239
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
240
- if shared.state.completion_url != COMPLETION_URL:
241
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
242
 
243
  with retrieve_proxy():
244
  response = requests.post(
245
- shared.state.completion_url,
246
  headers=headers,
247
  json=payload,
248
  stream=False,
 
149
  timeout = TIMEOUT_ALL
150
 
151
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
152
+ if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
153
+ logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
154
 
155
  with retrieve_proxy():
156
  try:
157
  response = requests.post(
158
+ shared.state.chat_completion_url,
159
  headers=headers,
160
  json=payload,
161
  stream=stream,
 
237
  "messages": history,
238
  }
239
  # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
240
+ if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
241
+ logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
242
 
243
  with retrieve_proxy():
244
  response = requests.post(
245
+ shared.state.chat_completion_url,
246
  headers=headers,
247
  json=payload,
248
  stream=False,
modules/models/OpenAIInstruct.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import openai
from .base_model import BaseLLMModel
from .. import shared
from ..config import retrieve_proxy


class OpenAI_Instruct_Client(BaseLLMModel):
    """Client for OpenAI completion-style (instruct) models such as
    gpt-3.5-turbo-instruct, which use the legacy /v1/completions endpoint
    instead of the chat endpoint."""

    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        # Kept per-instance so each request carries its own key
        # (works together with the key-switching decorator below).
        self.api_key = api_key

    def _get_instruct_style_input(self):
        """Flatten the chat history into one prompt string.

        Completion endpoints accept plain text rather than a message list,
        so the content of every history entry is joined with blank lines.
        """
        segments = (entry["content"] for entry in self.history)
        return "\n\n".join(segments)

    @shared.state.switching_api_key
    def get_answer_at_once(self):
        """Issue a single completion request.

        Returns:
            tuple: (answer_text, total_tokens) — the stripped completion text
            and the token count reported by the API.
        """
        flattened_prompt = self._get_instruct_style_input()
        request_kwargs = dict(
            api_key=self.api_key,
            # Honors a custom api-host configured in shared state.
            api_base=shared.state.openai_api_base,
            model=self.model_name,
            prompt=flattened_prompt,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        with retrieve_proxy():
            response = openai.Completion.create(**request_kwargs)
        answer = response.choices[0].text.strip()
        return answer, response.usage["total_tokens"]
modules/models/base_model.py CHANGED
@@ -144,13 +144,17 @@ class ModelType(Enum):
144
  LangchainChat = 10
145
  Midjourney = 11
146
  Spark = 12
 
147
 
148
  @classmethod
149
  def get_type(cls, model_name: str):
150
  model_type = None
151
  model_name_lower = model_name.lower()
152
  if "gpt" in model_name_lower:
153
- model_type = ModelType.OpenAI
 
 
 
154
  elif "chatglm" in model_name_lower:
155
  model_type = ModelType.ChatGLM
156
  elif "llama" in model_name_lower or "alpaca" in model_name_lower:
@@ -247,7 +251,7 @@ class BaseLLMModel:
247
 
248
  def billing_info(self):
249
  """get billing infomation, inplement if needed"""
250
- logging.warning("billing info not implemented, using default")
251
  return BILLING_NOT_APPLICABLE_MSG
252
 
253
  def count_token(self, user_input):
 
144
  LangchainChat = 10
145
  Midjourney = 11
146
  Spark = 12
147
+ OpenAIInstruct = 13
148
 
149
  @classmethod
150
  def get_type(cls, model_name: str):
151
  model_type = None
152
  model_name_lower = model_name.lower()
153
  if "gpt" in model_name_lower:
154
+ if "instruct" in model_name_lower:
155
+ model_type = ModelType.OpenAIInstruct
156
+ else:
157
+ model_type = ModelType.OpenAI
158
  elif "chatglm" in model_name_lower:
159
  model_type = ModelType.ChatGLM
160
  elif "llama" in model_name_lower or "alpaca" in model_name_lower:
 
251
 
252
  def billing_info(self):
253
  """get billing infomation, inplement if needed"""
254
+ # logging.warning("billing info not implemented, using default")
255
  return BILLING_NOT_APPLICABLE_MSG
256
 
257
  def count_token(self, user_input):
modules/models/models.py CHANGED
@@ -47,6 +47,12 @@ def get_model(
47
  top_p=top_p,
48
  user_name=user_name,
49
  )
 
 
 
 
 
 
50
  elif model_type == ModelType.ChatGLM:
51
  logging.info(f"正在加载ChatGLM模型: {model_name}")
52
  from .ChatGLM import ChatGLM_Client
 
47
  top_p=top_p,
48
  user_name=user_name,
49
  )
50
+ elif model_type == ModelType.OpenAIInstruct:
51
+ logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
52
+ from .OpenAIInstruct import OpenAI_Instruct_Client
53
+ access_key = os.environ.get("OPENAI_API_KEY", access_key)
54
+ model = OpenAI_Instruct_Client(
55
+ model_name, api_key=access_key, user_name=user_name)
56
  elif model_type == ModelType.ChatGLM:
57
  logging.info(f"正在加载ChatGLM模型: {model_name}")
58
  from .ChatGLM import ChatGLM_Client
modules/presets.py CHANGED
@@ -14,7 +14,9 @@ LLAMA_INFERENCER = None
14
  # ChatGPT 设置
15
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
16
  API_HOST = "api.openai.com"
17
- COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
 
 
18
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
19
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
20
  HISTORY_DIR = Path("history")
@@ -50,10 +52,11 @@ CHUANHU_DESCRIPTION = i18n("由Bilibili [土川虎虎虎](https://space.bilibili
50
 
51
  ONLINE_MODELS = [
52
  "gpt-3.5-turbo",
 
53
  "gpt-3.5-turbo-16k",
 
54
  "gpt-3.5-turbo-0301",
55
  "gpt-3.5-turbo-0613",
56
- "gpt-4",
57
  "gpt-4-0314",
58
  "gpt-4-0613",
59
  "gpt-4-32k",
 
14
  # ChatGPT 设置
15
  INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
16
  API_HOST = "api.openai.com"
17
+ OPENAI_API_BASE = "https://api.openai.com/v1"
18
+ CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
19
+ COMPLETION_URL = "https://api.openai.com/v1/completions"
20
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
21
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
22
  HISTORY_DIR = Path("history")
 
52
 
53
  ONLINE_MODELS = [
54
  "gpt-3.5-turbo",
55
+ "gpt-3.5-turbo-instruct",
56
  "gpt-3.5-turbo-16k",
57
+ "gpt-4",
58
  "gpt-3.5-turbo-0301",
59
  "gpt-3.5-turbo-0613",
 
60
  "gpt-4-0314",
61
  "gpt-4-0613",
62
  "gpt-4-32k",
modules/shared.py CHANGED
@@ -1,4 +1,4 @@
1
- from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
2
  import os
3
  import queue
4
  import openai
@@ -6,9 +6,10 @@ import openai
6
  class State:
7
  interrupted = False
8
  multi_api_key = False
9
- completion_url = COMPLETION_URL
10
  balance_api_url = BALANCE_API_URL
11
  usage_api_url = USAGE_API_URL
 
12
 
13
  def interrupt(self):
14
  self.interrupted = True
@@ -22,13 +23,14 @@ class State:
22
  api_host = f"https://{api_host}"
23
  if api_host.endswith("/v1"):
24
  api_host = api_host[:-3]
25
- self.completion_url = f"{api_host}/v1/chat/completions"
 
26
  self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
27
  self.usage_api_url = f"{api_host}/dashboard/billing/usage"
28
  os.environ["OPENAI_API_BASE"] = api_host
29
 
30
  def reset_api_host(self):
31
- self.completion_url = COMPLETION_URL
32
  self.balance_api_url = BALANCE_API_URL
33
  self.usage_api_url = USAGE_API_URL
34
  os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
@@ -36,7 +38,7 @@ class State:
36
 
37
  def reset_all(self):
38
  self.interrupted = False
39
- self.completion_url = COMPLETION_URL
40
 
41
  def set_api_key_queue(self, api_key_list):
42
  self.multi_api_key = True
 
1
+ from modules.presets import CHAT_COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST, OPENAI_API_BASE
2
  import os
3
  import queue
4
  import openai
 
6
  class State:
7
  interrupted = False
8
  multi_api_key = False
9
+ chat_completion_url = CHAT_COMPLETION_URL
10
  balance_api_url = BALANCE_API_URL
11
  usage_api_url = USAGE_API_URL
12
+ openai_api_base = OPENAI_API_BASE
13
 
14
  def interrupt(self):
15
  self.interrupted = True
 
23
  api_host = f"https://{api_host}"
24
  if api_host.endswith("/v1"):
25
  api_host = api_host[:-3]
26
+ self.chat_completion_url = f"{api_host}/v1/chat/completions"
27
+ self.openai_api_base = f"{api_host}/v1"
28
  self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
29
  self.usage_api_url = f"{api_host}/dashboard/billing/usage"
30
  os.environ["OPENAI_API_BASE"] = api_host
31
 
32
  def reset_api_host(self):
33
+ self.chat_completion_url = CHAT_COMPLETION_URL
34
  self.balance_api_url = BALANCE_API_URL
35
  self.usage_api_url = USAGE_API_URL
36
  os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
 
38
 
39
  def reset_all(self):
40
  self.interrupted = False
41
+ self.chat_completion_url = CHAT_COMPLETION_URL
42
 
43
  def set_api_key_queue(self, api_key_list):
44
  self.multi_api_key = True