Tuchuanhuhuhu committed
Commit 2c3fb9f · 1 Parent(s): 69f0c41

feature: Add GPT4-Turbo and GPT4-Vision support #927 #929

ChuanhuChatbot.py CHANGED
@@ -578,7 +578,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
 
     # submitBtn.click(auto_name_chat_history, [current_model, user_question, chatbot, user_name], [historySelectList], show_progress=False)
 
-    index_files.change(handle_file_upload, [current_model, index_files, chatbot, language_select_dropdown], [
+    index_files.upload(handle_file_upload, [current_model, index_files, chatbot, language_select_dropdown], [
                        index_files, chatbot, status_display])
     summarize_btn.click(handle_summarize_index, [
                         current_model, index_files, chatbot, language_select_dropdown], [chatbot, status_display])
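For context on the one-line change above: in Gradio, a Files component's change event fires on any value change, including when a handler clears the component programmatically (here the first output resets index_files), which can retrigger the callback; upload fires only on an actual user upload. A minimal sketch of the pattern, with illustrative names (handle_upload is hypothetical, not the project's handler):

import gradio as gr

def handle_upload(files, history):
    # Runs only on a real user upload; returning None to reset the Files
    # component below will not fire this handler again (it would with .change).
    names = [f.name for f in files] if files else []
    return None, history + [("\n".join(names), None)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    index_files = gr.Files()
    index_files.upload(handle_upload, [index_files, chatbot],
                       [index_files, chatbot])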
modules/models/OpenAI.py CHANGED
@@ -26,7 +26,7 @@ class OpenAIClient(BaseLLMModel):
         user_name=""
     ) -> None:
         super().__init__(
-            model_name=model_name,
+            model_name=MODEL_METADATA[model_name]["model_name"],
             temperature=temperature,
             top_p=top_p,
             system_prompt=system_prompt,
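With this change OpenAIClient is constructed with a human-readable display name and resolves the actual API model string through the MODEL_METADATA table added in modules/presets.py below. A minimal sketch of the lookup, with one entry copied from that table:

MODEL_METADATA = {
    "GPT4 Turbo": {"model_name": "gpt-4-1106-preview", "token_limit": 128000},
}

display_name = "GPT4 Turbo"
api_model = MODEL_METADATA[display_name]["model_name"]
print(api_model)  # gpt-4-1106-preview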
modules/models/OpenAIVision.py ADDED
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+import json
+import logging
+import traceback
+import base64
+
+import colorama
+import requests
+from io import BytesIO
+import uuid
+
+from PIL import Image
+
+from .. import shared
+from ..config import retrieve_proxy, sensitive_id, usage_limit
+from ..index_func import *
+from ..presets import *
+from ..utils import *
+from .base_model import BaseLLMModel
+
+
+class OpenAIVisionClient(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        api_key,
+        system_prompt=INITIAL_SYSTEM_PROMPT,
+        temperature=1.0,
+        top_p=1.0,
+        user_name=""
+    ) -> None:
+        super().__init__(
+            model_name=MODEL_METADATA[model_name]["model_name"],
+            temperature=temperature,
+            top_p=top_p,
+            system_prompt=system_prompt,
+            user=user_name
+        )
+        self.api_key = api_key
+        self.need_api_key = True
+        self.max_generation_token = 4096
+        self.images = []
+        self._refresh_header()
+
+    def get_answer_stream_iter(self):
+        response = self._get_response(stream=True)
+        if response is not None:
+            iter = self._decode_chat_response(response)
+            partial_text = ""
+            for i in iter:
+                partial_text += i
+                yield partial_text
+        else:
+            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
+
+    def get_answer_at_once(self):
+        response = self._get_response()
+        response = json.loads(response.text)
+        content = response["choices"][0]["message"]["content"]
+        total_token_count = response["usage"]["total_tokens"]
+        return content, total_token_count
+
+    def try_read_image(self, filepath):
+        def is_image_file(filepath):
+            # Check whether the file is an image
+            valid_image_extensions = [
+                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
+            file_extension = os.path.splitext(filepath)[1].lower()
+            return file_extension in valid_image_extensions
+        def image_to_base64(image_path):
+            # Open and load the image
+            img = Image.open(image_path)
+
+            # Get the image's width and height
+            width, height = img.size
+
+            # Compute the scale ratio so the longest side is at most 2048 pixels
+            max_dimension = 2048
+            scale_ratio = min(max_dimension / width, max_dimension / height)
+
+            if scale_ratio < 1:
+                # Resize the image by the scale ratio
+                new_width = int(width * scale_ratio)
+                new_height = int(height * scale_ratio)
+                img = img.resize((new_width, new_height), Image.LANCZOS)
+
+            # Convert the image to JPEG binary data
+            buffer = BytesIO()
+            if img.mode == "RGBA":
+                img = img.convert("RGB")
+            img.save(buffer, format='JPEG')
+            binary_image = buffer.getvalue()
+
+            # Base64-encode the binary data
+            base64_image = base64.b64encode(binary_image).decode('utf-8')
+
+            return base64_image
+
+        if is_image_file(filepath):
+            logging.info(f"读取图片文件: {filepath}")
+            base64_image = image_to_base64(filepath)
+            self.images.append({
+                "path": filepath,
+                "base64": base64_image,
+            })
+
+    def handle_file_upload(self, files, chatbot, language):
+        """if the model accepts multi modal input, implement this function"""
+        if files:
+            for file in files:
+                if file.name:
+                    self.try_read_image(file.name)
+            if self.images is not None:
+                chatbot = chatbot + [([image["path"] for image in self.images], None)]
+        return None, chatbot, None
+
+    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
+        fake_inputs = real_inputs
+        display_append = ""
+        limited_context = False
+        return limited_context, fake_inputs, display_append, real_inputs, chatbot
+
+
+    def count_token(self, user_input):
+        input_token_count = count_token(construct_user(user_input))
+        if self.system_prompt is not None and len(self.all_token_counts) == 0:
+            system_prompt_token_count = count_token(
+                construct_system(self.system_prompt)
+            )
+            return input_token_count + system_prompt_token_count
+        return input_token_count
+
+    def billing_info(self):
+        try:
+            curr_time = datetime.datetime.now()
+            last_day_of_month = get_last_day_of_month(
+                curr_time).strftime("%Y-%m-%d")
+            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
+            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
+            try:
+                usage_data = self._get_billing_data(usage_url)
+            except Exception as e:
+                # logging.error(f"获取API使用情况失败: " + str(e))
+                if "Invalid authorization header" in str(e):
+                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
+                elif "Incorrect API key provided: sess" in str(e):
+                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
+                return i18n("**获取API使用情况失败**")
+            # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
+            rounded_usage = round(usage_data["total_usage"] / 100, 5)
+            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
+            from ..webui import get_html
+
+            # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
+            return get_html("billing_info.html").format(
+                label=i18n("本月使用金额"),
+                usage_percent=usage_percent,
+                rounded_usage=rounded_usage,
+                usage_limit=usage_limit
+            )
+        except requests.exceptions.ConnectTimeout:
+            status_text = (
+                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            )
+            return status_text
+        except requests.exceptions.ReadTimeout:
+            status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+            return status_text
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            logging.error(i18n("获取API使用情况失败:") + str(e))
+            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
+
+    def set_token_upper_limit(self, new_upper_limit):
+        pass
+
+    @shared.state.switching_api_key  # This decorator is a no-op unless multi-account mode is enabled
+    def _get_response(self, stream=False):
+        openai_api_key = self.api_key
+        system_prompt = self.system_prompt
+        history = self.history
+        if self.images:
+            self.history[-1]["content"] = [
+                {"type": "text", "text": self.history[-1]["content"]},
+                *[{"type": "image_url", "image_url": "data:image/jpeg;base64," + image["base64"]} for image in self.images]
+            ]
+            self.images = []
+        logging.debug(colorama.Fore.YELLOW +
+                      f"{history}" + colorama.Fore.RESET)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {openai_api_key}",
+        }
+
+        if system_prompt is not None:
+            history = [construct_system(system_prompt), *history]
+
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n_choices,
+            "stream": stream,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+        }
+
+        if self.max_generation_token is not None:
+            payload["max_tokens"] = self.max_generation_token
+        if self.stop_sequence is not None:
+            payload["stop"] = self.stop_sequence
+        if self.logit_bias is not None:
+            payload["logit_bias"] = self.logit_bias
+        if self.user_identifier:
+            payload["user"] = self.user_identifier
+
+        if stream:
+            timeout = TIMEOUT_STREAMING
+        else:
+            timeout = TIMEOUT_ALL
+
+        # If a custom api-host is configured, send the request there; otherwise use the default
+        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
+
+        with retrieve_proxy():
+            try:
+                response = requests.post(
+                    shared.state.chat_completion_url,
+                    headers=headers,
+                    json=payload,
+                    stream=stream,
+                    timeout=timeout,
+                )
+            except:
+                traceback.print_exc()
+                return None
+        return response
+
+    def _refresh_header(self):
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {sensitive_id}",
+        }
+
+
+    def _get_billing_data(self, billing_url):
+        with retrieve_proxy():
+            response = requests.get(
+                billing_url,
+                headers=self.headers,
+                timeout=TIMEOUT_ALL,
+            )
+
+        if response.status_code == 200:
+            data = response.json()
+            return data
+        else:
+            raise Exception(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
+
+    def _decode_chat_response(self, response):
+        error_msg = ""
+        for chunk in response.iter_lines():
+            if chunk:
+                chunk = chunk.decode()
+                chunk_length = len(chunk)
+                try:
+                    chunk = json.loads(chunk[6:])
+                except:
+                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
+                    error_msg += chunk
+                    continue
+                try:
+                    if chunk_length > 6 and "delta" in chunk["choices"][0]:
+                        if "finish_details" in chunk["choices"][0]:
+                            finish_reason = chunk["choices"][0]["finish_details"]
+                        else:
+                            finish_reason = chunk["finish_details"]
+                        if finish_reason == "stop":
+                            break
+                        try:
+                            yield chunk["choices"][0]["delta"]["content"]
+                        except Exception as e:
+                            # logging.error(f"Error: {e}")
+                            continue
+                except:
+                    traceback.print_exc()
+                    print(f"ERROR: {chunk}")
+                    continue
+        if error_msg and error_msg != "data: [DONE]":
+            raise Exception(error_msg)
+
+    def set_key(self, new_access_key):
+        ret = super().set_key(new_access_key)
+        self._refresh_header()
+        return ret
+
+    def _single_query_at_once(self, history, temperature=1.0):
+        timeout = TIMEOUT_ALL
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+            "temperature": f"{temperature}",
+        }
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+        }
+        # If a custom api-host is configured, send the request there; otherwise use the default
+        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
+            logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")
+
+        with retrieve_proxy():
+            response = requests.post(
+                shared.state.chat_completion_url,
+                headers=headers,
+                json=payload,
+                stream=False,
+                timeout=timeout,
+            )
+
+        return response
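For reference, once try_read_image has queued an image, _get_response rewrites the last history entry into the content-parts form before posting, so the user message it sends looks roughly like this (base64 data truncated, text illustrative):

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url",
         "image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
    ],
}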
modules/models/base_model.py CHANGED
@@ -147,6 +147,7 @@ class ModelType(Enum):
     OpenAIInstruct = 13
     Claude = 14
     Qwen = 15
+    OpenAIVision = 16
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -155,6 +156,8 @@ class ModelType(Enum):
         if "gpt" in model_name_lower:
             if "instruct" in model_name_lower:
                 model_type = ModelType.OpenAIInstruct
+            elif "vision" in model_name_lower:
+                model_type = ModelType.OpenAIVision
             else:
                 model_type = ModelType.OpenAI
         elif "chatglm" in model_name_lower:
@@ -210,7 +213,7 @@ class BaseLLMModel:
         self.model_name = model_name
         self.model_type = ModelType.get_type(model_name)
         try:
-            self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
+            self.token_upper_limit = MODEL_METADATA[model_name]["token_limit"]
         except KeyError:
             self.token_upper_limit = DEFAULT_TOKEN_LIMIT
         self.interrupted = False
@@ -353,10 +356,12 @@ class BaseLLMModel:
         return chatbot, status
 
     def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot, load_from_cache_if_possible=True):
-        fake_inputs = None
         display_append = []
         limited_context = False
-        fake_inputs = real_inputs
+        if type(real_inputs) == list:
+            fake_inputs = real_inputs[0]['text']
+        else:
+            fake_inputs = real_inputs
         if files:
             from langchain.embeddings.huggingface import HuggingFaceEmbeddings
             from langchain.vectorstores.base import VectorStoreRetriever
@@ -372,24 +377,32 @@ class BaseLLMModel:
                 "k": 6, "score_threshold": 0.5})
             try:
                 relevant_documents = retriever.get_relevant_documents(
-                    real_inputs)
+                    fake_inputs)
             except AssertionError:
-                return self.prepare_inputs(real_inputs, use_websearch, files, reply_language, chatbot, load_from_cache_if_possible=False)
+                return self.prepare_inputs(fake_inputs, use_websearch, files, reply_language, chatbot, load_from_cache_if_possible=False)
             reference_results = [[d.page_content.strip("�"), os.path.basename(
                 d.metadata["source"])] for d in relevant_documents]
             reference_results = add_source_numbers(reference_results)
             display_append = add_details(reference_results)
             display_append = "\n\n" + "".join(display_append)
-            real_inputs = (
-                replace_today(PROMPT_TEMPLATE)
-                .replace("{query_str}", real_inputs)
-                .replace("{context_str}", "\n\n".join(reference_results))
-                .replace("{reply_language}", reply_language)
-            )
+            if type(real_inputs) == list:
+                real_inputs[0]["text"] = (
+                    replace_today(PROMPT_TEMPLATE)
+                    .replace("{query_str}", fake_inputs)
+                    .replace("{context_str}", "\n\n".join(reference_results))
+                    .replace("{reply_language}", reply_language)
+                )
+            else:
+                real_inputs = (
+                    replace_today(PROMPT_TEMPLATE)
+                    .replace("{query_str}", real_inputs)
+                    .replace("{context_str}", "\n\n".join(reference_results))
+                    .replace("{reply_language}", reply_language)
+                )
        elif use_websearch:
            search_results = []
            with DDGS() as ddgs:
-                ddgs_gen = ddgs.text(real_inputs, backend="lite")
+                ddgs_gen = ddgs.text(fake_inputs, backend="lite")
                for r in islice(ddgs_gen, 10):
                    search_results.append(r)
            reference_results = []
@@ -405,12 +418,20 @@ class BaseLLMModel:
             # display_append = "<ol>\n\n" + "".join(display_append) + "</ol>"
             display_append = '<div class = "source-a">' + \
                 "".join(display_append) + '</div>'
-            real_inputs = (
-                replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
-                .replace("{query}", real_inputs)
-                .replace("{web_results}", "\n\n".join(reference_results))
-                .replace("{reply_language}", reply_language)
-            )
+            if type(real_inputs) == list:
+                real_inputs[0]["text"] = (
+                    replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
+                    .replace("{query}", fake_inputs)
+                    .replace("{web_results}", "\n\n".join(reference_results))
+                    .replace("{reply_language}", reply_language)
+                )
+            else:
+                real_inputs = (
+                    replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
+                    .replace("{query}", fake_inputs)
+                    .replace("{web_results}", "\n\n".join(reference_results))
+                    .replace("{reply_language}", reply_language)
+                )
         else:
             display_append = ""
         return limited_context, fake_inputs, display_append, real_inputs, chatbot
@@ -427,12 +448,21 @@ class BaseLLMModel:
     ):  # repetition_penalty, top_k
 
         status_text = "开始生成回答……"
-        logging.info(
-            "用户" + f"{self.user_identifier}" + "的输入为:" +
-            colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
-        )
+        if type(inputs) == list:
+            logging.info(
+                "用户" + f"{self.user_identifier}" + "的输入为:" +
+                colorama.Fore.BLUE + "(" + str(len(inputs)-1) + " images) " + f"{inputs[0]['text']}" + colorama.Style.RESET_ALL
+            )
+        else:
+            logging.info(
+                "用户" + f"{self.user_identifier}" + "的输入为:" +
+                colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
+            )
         if should_check_token_count:
-            yield chatbot + [(inputs, "")], status_text
+            if type(inputs) == list:
+                yield chatbot + [(inputs[0]['text'], "")], status_text
+            else:
+                yield chatbot + [(inputs, "")], status_text
         if reply_language == "跟随问题语言(不稳定)":
             reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
 
@@ -447,25 +477,28 @@ class BaseLLMModel:
         ):
             status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
             logging.info(status_text)
-            chatbot.append((inputs, ""))
+            chatbot.append((fake_inputs, ""))
             if len(self.history) == 0:
-                self.history.append(construct_user(inputs))
+                self.history.append(construct_user(fake_inputs))
                 self.history.append("")
                 self.all_token_counts.append(0)
             else:
-                self.history[-2] = construct_user(inputs)
-            yield chatbot + [(inputs, "")], status_text
+                self.history[-2] = construct_user(fake_inputs)
+            yield chatbot + [(fake_inputs, "")], status_text
             return
-        elif len(inputs.strip()) == 0:
+        elif len(fake_inputs.strip()) == 0:
             status_text = STANDARD_ERROR_MSG + NO_INPUT_MSG
             logging.info(status_text)
-            yield chatbot + [(inputs, "")], status_text
+            yield chatbot + [(fake_inputs, "")], status_text
             return
 
         if self.single_turn:
             self.history = []
             self.all_token_counts = []
-        self.history.append(construct_user(inputs))
+        if type(inputs) == list:
+            self.history.append(inputs)
+        else:
+            self.history.append(construct_user(inputs))
 
         try:
             if stream:
@@ -492,7 +525,7 @@ class BaseLLMModel:
             status_text = STANDARD_ERROR_MSG + beautify_err_msg(str(e))
             yield chatbot, status_text
 
-        if len(self.history) > 1 and self.history[-1]["content"] != inputs:
+        if len(self.history) > 1 and self.history[-1]["content"] != fake_inputs:
             logging.info(
                 "回答为:"
                 + colorama.Fore.BLUE
@@ -702,6 +735,8 @@ class BaseLLMModel:
     def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
         if len(self.history) == 2 and not single_turn_checkbox:
             user_question = self.history[0]["content"]
+            if type(user_question) == list:
+                user_question = user_question[0]["text"]
             filename = replace_special_symbols(user_question)[:16] + ".json"
             return self.rename_chat_history(filename, chatbot, user_name)
         else:
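The recurring type(inputs) == list checks above exist because a multimodal turn now arrives as a list of content parts while a plain chat turn is still a string; fake_inputs carries the extracted text for display, history bookkeeping, token counting, and retrieval or web search. A condensed sketch of the convention (text_of is illustrative, not a project function):

def text_of(inputs):
    # Multimodal turns are lists of content parts; plain turns are strings.
    if isinstance(inputs, list):
        return inputs[0]["text"]
    return inputs

vision_turn = [
    {"type": "text", "text": "Describe this"},
    {"type": "image_url", "image_url": "data:image/jpeg;base64,..."},
]
print(text_of(vision_turn))       # Describe this
print(text_of("plain question"))  # plain question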
modules/models/models.py CHANGED
@@ -53,6 +53,12 @@ def get_model(
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
             model = OpenAI_Instruct_Client(
                 model_name, api_key=access_key, user_name=user_name)
+        elif model_type == ModelType.OpenAIVision:
+            logging.info(f"正在加载OpenAI Vision模型: {model_name}")
+            from .OpenAIVision import OpenAIVisionClient
+            access_key = os.environ.get("OPENAI_API_KEY", access_key)
+            model = OpenAIVisionClient(
+                model_name, api_key=access_key, user_name=user_name)
         elif model_type == ModelType.ChatGLM:
             logging.info(f"正在加载ChatGLM模型: {model_name}")
             from .ChatGLM import ChatGLM_Client
modules/overwrites.py CHANGED
@@ -44,32 +44,36 @@ def postprocess_chat_messages(
     ) -> str | dict | None:
         if chat_message is None:
             return None
-        elif isinstance(chat_message, (tuple, list)):
-            file_uri = chat_message[0]
-            if utils.validate_url(file_uri):
-                filepath = file_uri
-            else:
-                filepath = self.make_temp_copy_if_needed(file_uri)
-
-            mime_type = client_utils.get_mimetype(filepath)
-            return {
-                "name": filepath,
-                "mime_type": mime_type,
-                "alt_text": chat_message[1] if len(chat_message) > 1 else None,
-                "data": None,  # These last two fields are filled in by the frontend
-                "is_file": True,
-            }
-        elif isinstance(chat_message, str):
-            # chat_message = inspect.cleandoc(chat_message)
-            # escape html spaces
-            # chat_message = chat_message.replace(" ", "&nbsp;")
-            if role == "bot":
-                chat_message = convert_bot_before_marked(chat_message)
-            elif role == "user":
-                chat_message = convert_user_before_marked(chat_message)
-            return chat_message
         else:
-            raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
+            if isinstance(chat_message, (tuple, list)):
+                if len(chat_message) > 0 and "text" in chat_message[0]:
+                    chat_message = chat_message[0]["text"]
+                else:
+                    file_uri = chat_message[0]
+                    if utils.validate_url(file_uri):
+                        filepath = file_uri
+                    else:
+                        filepath = self.make_temp_copy_if_needed(file_uri)
+
+                    mime_type = client_utils.get_mimetype(filepath)
+                    return {
+                        "name": filepath,
+                        "mime_type": mime_type,
+                        "alt_text": chat_message[1] if len(chat_message) > 1 else None,
+                        "data": None,  # These last two fields are filled in by the frontend
+                        "is_file": True,
+                    }
+            if isinstance(chat_message, str):
+                # chat_message = inspect.cleandoc(chat_message)
+                # escape html spaces
+                # chat_message = chat_message.replace(" ", "&nbsp;")
+                if role == "bot":
+                    chat_message = convert_bot_before_marked(chat_message)
+                elif role == "user":
+                    chat_message = convert_user_before_marked(chat_message)
+                return chat_message
+            else:
+                raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
 
 
 
@@ -103,4 +107,3 @@ def BlockContext_init(self, *args, **kwargs):
 
 original_BlockContext_init = gr.blocks.BlockContext.__init__
 gr.blocks.BlockContext.__init__ = BlockContext_init
-
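The restructured postprocessor now distinguishes three message shapes before raising (a sketch; values are illustrative):

examples = [
    # Multimodal user turn: a list of content parts whose "text" part is
    # extracted and then rendered through the string branch below it.
    [{"type": "text", "text": "Describe this"},
     {"type": "image_url", "image_url": "data:image/jpeg;base64,..."}],
    # File message: (filepath, alt_text), returned as a file dict for the frontend.
    ("/tmp/cat.jpg", "a cat"),
    # Plain string: converted with convert_bot_before_marked / convert_user_before_marked.
    "Hello!",
]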
 
modules/presets.py CHANGED
@@ -51,17 +51,15 @@ CHUANHU_DESCRIPTION = i18n("由Bilibili [土川虎虎虎](https://space.bilibili
 
 
 ONLINE_MODELS = [
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-instruct",
-    "gpt-3.5-turbo-16k",
-    "gpt-4",
-    "gpt-3.5-turbo-0301",
-    "gpt-3.5-turbo-0613",
-    "gpt-4-0314",
-    "gpt-4-0613",
-    "gpt-4-32k",
-    "gpt-4-32k-0314",
-    "gpt-4-32k-0613",
+    "GPT3.5 Turbo",
+    "GPT3.5 Turbo Instruct",
+    "GPT3.5 Turbo 16K",
+    "GPT3.5 Turbo 0301",
+    "GPT3.5 Turbo 0613",
+    "GPT4",
+    "GPT4 32K",
+    "GPT4 Turbo",
+    "GPT4 Vision",
     "川虎助理",
     "川虎助理 Pro",
     "GooglePaLM",
@@ -92,7 +90,7 @@ LOCAL_MODELS = [
     "Qwen 14B"
 ]
 
-# Additional metadate for local models
+# Additional metadata for online and local models
 MODEL_METADATA = {
     "Llama-2-7B":{
         "repo_id": "TheBloke/Llama-2-7B-GGUF",
@@ -107,7 +105,47 @@ MODEL_METADATA = {
     },
     "Qwen 14B": {
         "repo_id": "Qwen/Qwen-14B-Chat-Int4",
-    }
+    },
+    "GPT3.5 Turbo": {
+        "model_name": "gpt-3.5-turbo",
+        "token_limit": 4096,
+    },
+    "GPT3.5 Turbo Instruct": {
+        "model_name": "gpt-3.5-turbo-instruct",
+        "token_limit": 4096,
+    },
+    "GPT3.5 Turbo 16K": {
+        "model_name": "gpt-3.5-turbo-16k",
+        "token_limit": 16384,
+    },
+    "GPT3.5 Turbo 0301": {
+        "model_name": "gpt-3.5-turbo-0301",
+        "token_limit": 4096,
+    },
+    "GPT3.5 Turbo 0613": {
+        "model_name": "gpt-3.5-turbo-0613",
+        "token_limit": 4096,
+    },
+    "GPT4": {
+        "model_name": "gpt-4",
+        "token_limit": 8192,
+    },
+    "GPT4 32K": {
+        "model_name": "gpt-4-32k",
+        "token_limit": 32768,
+    },
+    "GPT4 Turbo": {
+        "model_name": "gpt-4-1106-preview",
+        "token_limit": 128000,
+    },
+    "GPT4 Vision": {
+        "model_name": "gpt-4-vision-preview",
+        "token_limit": 128000,
+    },
+    "Claude": {
+        "model_name": "Claude",
+        "token_limit": 4096,
+    },
 }
 
 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
@@ -125,20 +163,6 @@ for dir_name in os.listdir("models"):
     if dir_name not in MODELS:
         MODELS.append(dir_name)
 
-MODEL_TOKEN_LIMIT = {
-    "gpt-3.5-turbo": 4096,
-    "gpt-3.5-turbo-16k": 16384,
-    "gpt-3.5-turbo-0301": 4096,
-    "gpt-3.5-turbo-0613": 4096,
-    "gpt-4": 8192,
-    "gpt-4-0314": 8192,
-    "gpt-4-0613": 8192,
-    "gpt-4-32k": 32768,
-    "gpt-4-32k-0314": 32768,
-    "gpt-4-32k-0613": 32768,
-    "Claude": 4096
-}
-
 TOKEN_OFFSET = 1000  # Subtract this from a model's token limit to get its soft limit; once the soft limit is reached, token usage is automatically reduced
 DEFAULT_TOKEN_LIMIT = 3000  # Default token limit
 REDUCE_TOKEN_FACTOR = 0.5  # Multiplied by the model's token limit to get the target token count; when reducing token usage, usage is cut to below this target
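With MODEL_TOKEN_LIMIT folded into MODEL_METADATA, the lookup in BaseLLMModel.__init__ (see the base_model.py hunk above) falls back to DEFAULT_TOKEN_LIMIT whenever a model has no entry or its entry lacks a token_limit key. Roughly (token_limit_for is illustrative):

DEFAULT_TOKEN_LIMIT = 3000
MODEL_METADATA = {
    "GPT4 Vision": {"model_name": "gpt-4-vision-preview", "token_limit": 128000},
    "Llama-2-7B": {"repo_id": "TheBloke/Llama-2-7B-GGUF"},  # no token_limit key
}

def token_limit_for(model_name):
    # KeyError covers both unknown models and entries without "token_limit".
    try:
        return MODEL_METADATA[model_name]["token_limit"]
    except KeyError:
        return DEFAULT_TOKEN_LIMIT

print(token_limit_for("GPT4 Vision"))  # 128000
print(token_limit_for("Llama-2-7B"))   # 3000 (falls back)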
web_assets/javascript/ChuanhuChat.js CHANGED
@@ -45,7 +45,7 @@ let windowWidth = window.innerWidth; // 初始窗口宽度
 
 function addInit() {
     var needInit = {chatbotIndicator, uploaderIndicator};
-
+
     chatbotIndicator = gradioApp().querySelector('#chuanhu-chatbot > div.wrap');
     uploaderIndicator = gradioApp().querySelector('#upload-index-file > div.wrap');
     chatListIndicator = gradioApp().querySelector('#history-select-dropdown > div.wrap');
@@ -60,7 +60,7 @@ function addInit() {
     chatbotObserver.observe(chatbotIndicator, { attributes: true, childList: true, subtree: true });
     chatListObserver.observe(chatListIndicator, { attributes: true });
     setUploader();
-
+
     return true;
 }
 
@@ -124,7 +124,7 @@ function initialize() {
     // setHistroyPanel();
     // trainBody.classList.add('hide-body');
 
-
+
 
     return true;
 }
@@ -213,7 +213,7 @@ function checkModel() {
     checkXMChat();
     function checkGPT() {
         modelValue = model.value;
-        if (modelValue.includes('gpt')) {
+        if (modelValue.toLowerCase().includes('gpt')) {
             gradioApp().querySelector('#header-btn-groups').classList.add('is-gpt');
         } else {
             gradioApp().querySelector('#header-btn-groups').classList.remove('is-gpt');
@@ -365,8 +365,8 @@ function chatbotContentChanged(attempt = 1, force = false) {
             }
         }, 200);
     }
-
-
+
+
     }, i === 0 ? 0 : 200);
 }
 // In theory retries shouldn't be needed, but a gradio bug means the message may not be fully rendered, so try again after 500ms
@@ -414,7 +414,7 @@ window.addEventListener('resize', ()=>{
     updateVH();
     windowWidth = window.innerWidth;
     setPopupBoxPosition();
-    adjustSide();
+    adjustSide();
 });
 window.addEventListener('orientationchange', (event) => {
     updateVH();
@@ -441,13 +441,13 @@ function makeML(str) {
     return l
 }
 let ChuanhuInfo = function () {
-    /*
-   ________                      __                ________          __
+    /*
+   ________                      __                ________          __
    / ____/ /_  __  ______ _____  / /_  __  __     / ____/ /_  ____ _/ /_
   / / / __ \/ / / / __ `/ __ \/ __ \/ / / /     / / / __ \/ __ `/ __/
-  / /___/ / / / /_/ / /_/ / / / / / / / /_/ /   / /___/ / / / /_/ / /_
-  \____/_/ /_/\__,_/\__,_/_/ /_/_/ /_/\__,_/    \____/_/ /_/\__,_/\__/
-
+  / /___/ / / / /_/ / /_/ / / / / / / / /_/ /   / /___/ / / / /_/ / /_
+  \____/_/ /_/\__,_/\__,_/_/ /_/_/ /_/\__,_/    \____/_/ /_/\__,_/\__/
+
    川虎Chat (Chuanhu Chat) - GUI for ChatGPT API and many LLMs
    */
 }